Skip to content

Commit f1ad107

Browse files
committed
fix(streaming): invert logic for assistant stream parsing
1 parent 37d0b25 commit f1ad107

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

src/openai/_streaming.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __stream__(self) -> Iterator[_T]:
5959
if sse.data.startswith("[DONE]"):
6060
break
6161

62-
if sse.event is None or sse.event.startswith("response.") or sse.event.startswith("transcript."):
62+
if sse.event == "error":
6363
data = sse.json()
6464
if is_mapping(data) and data.get("error"):
6565
message = None
@@ -75,12 +75,13 @@ def __stream__(self) -> Iterator[_T]:
7575
body=data["error"],
7676
)
7777

78-
yield process_data(data=data, cast_to=cast_to, response=response)
7978

79+
if sse.event and sse.event.startswith('thread.'):
80+
# the assistants API uses a different event shape structure
81+
yield process_data(data={"data": sse.json(), "event": sse.event}, cast_to=cast_to, response=response)
8082
else:
8183
data = sse.json()
82-
83-
if sse.event == "error" and is_mapping(data) and data.get("error"):
84+
if is_mapping(data) and data.get("error"):
8485
message = None
8586
error = data.get("error")
8687
if is_mapping(error):
@@ -94,7 +95,7 @@ def __stream__(self) -> Iterator[_T]:
9495
body=data["error"],
9596
)
9697

97-
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
98+
yield process_data(data=data, cast_to=cast_to, response=response)
9899

99100
# Ensure the entire stream is consumed
100101
for _sse in iterator:
@@ -161,7 +162,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
161162
if sse.data.startswith("[DONE]"):
162163
break
163164

164-
if sse.event is None or sse.event.startswith("response.") or sse.event.startswith("transcript."):
165+
if sse.event == "error":
165166
data = sse.json()
166167
if is_mapping(data) and data.get("error"):
167168
message = None
@@ -177,12 +178,13 @@ async def __stream__(self) -> AsyncIterator[_T]:
177178
body=data["error"],
178179
)
179180

180-
yield process_data(data=data, cast_to=cast_to, response=response)
181181

182+
if sse.event and sse.event.startswith('thread.'):
183+
# the assistants API uses a different event shape structure
184+
yield process_data(data={"data": sse.json(), "event": sse.event}, cast_to=cast_to, response=response)
182185
else:
183186
data = sse.json()
184-
185-
if sse.event == "error" and is_mapping(data) and data.get("error"):
187+
if is_mapping(data) and data.get("error"):
186188
message = None
187189
error = data.get("error")
188190
if is_mapping(error):
@@ -196,7 +198,7 @@ async def __stream__(self) -> AsyncIterator[_T]:
196198
body=data["error"],
197199
)
198200

199-
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
201+
yield process_data(data=data, cast_to=cast_to, response=response)
200202

201203
# Ensure the entire stream is consumed
202204
async for _sse in iterator:

0 commit comments

Comments
 (0)