|
17 | 17 | AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) |
18 | 18 |
|
19 | 19 |
|
class StreamedResponseHandler:
    """Handles streaming HTTP responses by accumulating chunks until complete
    messages are available.

    Chunks arrive from ``response.content.iter_any()``, which may split the
    byte stream at arbitrary positions — including in the middle of a
    multi-byte UTF-8 sequence. An incremental decoder is therefore used
    instead of a plain per-chunk ``bytes.decode()``, which would raise
    ``UnicodeDecodeError`` on such a split.
    """

    def __init__(self):
        # Local import keeps this block self-contained; codecs is stdlib.
        import codecs

        # Incremental decoder buffers trailing partial multi-byte sequences
        # across add_chunk() calls instead of raising UnicodeDecodeError.
        self._decoder = codecs.getincrementaldecoder("utf-8")()
        # Accumulated decoded text that does not yet form a complete message.
        self.buffer = ""

    def add_chunk(self, chunk_bytes: bytes) -> list[str]:
        """Add a chunk of bytes to the buffer and return any complete
        messages.

        Args:
            chunk_bytes: Raw bytes read from the HTTP response stream.

        Returns:
            All complete SSE messages (stripped, prefix included) that can be
            extracted so far; may be empty if more data is needed.
        """
        self.buffer += self._decoder.decode(chunk_bytes)

        messages = []

        # Split by double newlines (the SSE message separator).
        while "\n\n" in self.buffer:
            message, self.buffer = self.buffer.split("\n\n", 1)
            message = message.strip()
            if message:
                messages.append(message)

        # The remaining buffer may already be a complete message even though
        # its "\n\n" terminator has not arrived yet (callers strip chunk
        # whitespace): drop the "data: " prefix and check whether the payload
        # is the "[DONE]" sentinel or parses as complete JSON.
        if self.buffer.startswith("data: "):
            message_content = self.buffer.removeprefix("data: ").strip()
            if message_content == "[DONE]":
                messages.append(self.buffer.strip())
                self.buffer = ""
            elif message_content:
                try:
                    json.loads(message_content)
                    messages.append(self.buffer.strip())
                    self.buffer = ""
                except json.JSONDecodeError:
                    # Incomplete JSON, wait for more chunks.
                    pass

        return messages
| 60 | + |
20 | 61 | @dataclass |
21 | 62 | class RequestFuncInput: |
22 | 63 | """The input for the request function.""" |
@@ -102,46 +143,50 @@ async def async_request_openai_completions( |
102 | 143 | headers=headers) as response: |
103 | 144 | if response.status == 200: |
104 | 145 | first_chunk_received = False |
105 | | - async for chunk_bytes in response.content: |
| 146 | + handler = StreamedResponseHandler() |
| 147 | + |
| 148 | + async for chunk_bytes in response.content.iter_any(): |
106 | 149 | chunk_bytes = chunk_bytes.strip() |
107 | 150 | if not chunk_bytes: |
108 | 151 | continue |
109 | | - chunk_bytes = chunk_bytes.decode("utf-8") |
110 | | - # NOTE: SSE comments (often used as pings) start with |
111 | | - # a colon. These are not JSON data payload and should |
112 | | - # be skipped. |
113 | | - if chunk_bytes.startswith(":"): |
114 | | - continue |
115 | 152 |
|
116 | | - chunk = chunk_bytes.removeprefix("data: ") |
| 153 | + messages = handler.add_chunk(chunk_bytes) |
| 154 | + for message in messages: |
| 155 | + # NOTE: SSE comments (often used as pings) start with |
| 156 | + # a colon. These are not JSON data payload and should |
| 157 | + # be skipped. |
| 158 | + if message.startswith(":"): |
| 159 | + continue |
117 | 160 |
|
118 | | - if chunk != "[DONE]": |
119 | | - data = json.loads(chunk) |
| 161 | + chunk = message.removeprefix("data: ") |
120 | 162 |
|
121 | | - # NOTE: Some completion API might have a last |
122 | | - # usage summary response without a token so we |
123 | | - # want to check a token was generated |
124 | | - if choices := data.get("choices"): |
125 | | - # Note that text could be empty here |
126 | | - # e.g. for special tokens |
127 | | - text = choices[0].get("text") |
128 | | - timestamp = time.perf_counter() |
129 | | - # First token |
130 | | - if not first_chunk_received: |
131 | | - first_chunk_received = True |
132 | | - ttft = time.perf_counter() - st |
133 | | - output.ttft = ttft |
| 163 | + if chunk != "[DONE]": |
| 164 | + data = json.loads(chunk) |
134 | 165 |
|
135 | | - # Decoding phase |
136 | | - else: |
137 | | - output.itl.append(timestamp - |
138 | | - most_recent_timestamp) |
| 166 | + # NOTE: Some completion API might have a last |
| 167 | + # usage summary response without a token so we |
| 168 | + # want to check a token was generated |
| 169 | + if choices := data.get("choices"): |
| 170 | + # Note that text could be empty here |
| 171 | + # e.g. for special tokens |
| 172 | + text = choices[0].get("text") |
| 173 | + timestamp = time.perf_counter() |
| 174 | + # First token |
| 175 | + if not first_chunk_received: |
| 176 | + first_chunk_received = True |
| 177 | + ttft = time.perf_counter() - st |
| 178 | + output.ttft = ttft |
139 | 179 |
|
140 | | - most_recent_timestamp = timestamp |
141 | | - generated_text += text or "" |
142 | | - elif usage := data.get("usage"): |
143 | | - output.output_tokens = usage.get( |
144 | | - "completion_tokens") |
| 180 | + # Decoding phase |
| 181 | + else: |
| 182 | + output.itl.append(timestamp - |
| 183 | + most_recent_timestamp) |
| 184 | + |
| 185 | + most_recent_timestamp = timestamp |
| 186 | + generated_text += text or "" |
| 187 | + elif usage := data.get("usage"): |
| 188 | + output.output_tokens = usage.get( |
| 189 | + "completion_tokens") |
145 | 190 | if first_chunk_received: |
146 | 191 | output.success = True |
147 | 192 | else: |
@@ -227,41 +272,44 @@ async def async_request_openai_chat_completions( |
227 | 272 | async with session.post(url=api_url, json=payload, |
228 | 273 | headers=headers) as response: |
229 | 274 | if response.status == 200: |
230 | | - async for chunk_bytes in response.content: |
| 275 | + handler = StreamedResponseHandler() |
| 276 | + async for chunk_bytes in response.content.iter_any(): |
231 | 277 | chunk_bytes = chunk_bytes.strip() |
232 | 278 | if not chunk_bytes: |
233 | 279 | continue |
234 | | - chunk_bytes = chunk_bytes.decode("utf-8") |
235 | | - # NOTE: SSE comments (often used as pings) start with |
236 | | - # a colon. These are not JSON data payload and should |
237 | | - # be skipped. |
238 | | - if chunk_bytes.startswith(":"): |
239 | | - continue |
240 | 280 |
|
241 | | - chunk = chunk_bytes.removeprefix("data: ") |
| 281 | + messages = handler.add_chunk(chunk_bytes) |
| 282 | + for message in messages: |
| 283 | + # NOTE: SSE comments (often used as pings) start with |
| 284 | + # a colon. These are not JSON data payload and should |
| 285 | + # be skipped. |
| 286 | + if message.startswith(":"): |
| 287 | + continue |
| 288 | + |
| 289 | + chunk = message.removeprefix("data: ") |
242 | 290 |
|
243 | | - if chunk != "[DONE]": |
244 | | - timestamp = time.perf_counter() |
245 | | - data = json.loads(chunk) |
| 291 | + if chunk != "[DONE]": |
| 292 | + timestamp = time.perf_counter() |
| 293 | + data = json.loads(chunk) |
246 | 294 |
|
247 | | - if choices := data.get("choices"): |
248 | | - content = choices[0]["delta"].get("content") |
249 | | - # First token |
250 | | - if ttft == 0.0: |
251 | | - ttft = timestamp - st |
252 | | - output.ttft = ttft |
| 295 | + if choices := data.get("choices"): |
| 296 | + content = choices[0]["delta"].get("content") |
| 297 | + # First token |
| 298 | + if ttft == 0.0: |
| 299 | + ttft = timestamp - st |
| 300 | + output.ttft = ttft |
253 | 301 |
|
254 | | - # Decoding phase |
255 | | - else: |
256 | | - output.itl.append(timestamp - |
257 | | - most_recent_timestamp) |
| 302 | + # Decoding phase |
| 303 | + else: |
| 304 | + output.itl.append(timestamp - |
| 305 | + most_recent_timestamp) |
258 | 306 |
|
259 | | - generated_text += content or "" |
260 | | - elif usage := data.get("usage"): |
261 | | - output.output_tokens = usage.get( |
262 | | - "completion_tokens") |
| 307 | + generated_text += content or "" |
| 308 | + elif usage := data.get("usage"): |
| 309 | + output.output_tokens = usage.get( |
| 310 | + "completion_tokens") |
263 | 311 |
|
264 | | - most_recent_timestamp = timestamp |
| 312 | + most_recent_timestamp = timestamp |
265 | 313 |
|
266 | 314 | output.generated_text = generated_text |
267 | 315 | output.success = True |
@@ -347,36 +395,40 @@ def to_bytes(y, sr): |
347 | 395 | data=form, |
348 | 396 | headers=headers) as response: |
349 | 397 | if response.status == 200: |
350 | | - async for chunk_bytes in response.content: |
| 398 | + handler = StreamedResponseHandler() |
| 399 | + |
| 400 | + async for chunk_bytes in response.content.iter_any(): |
351 | 401 | chunk_bytes = chunk_bytes.strip() |
352 | 402 | if not chunk_bytes: |
353 | 403 | continue |
354 | 404 |
|
355 | | - chunk = chunk_bytes.decode("utf-8").removeprefix( |
356 | | - "data: ") |
357 | | - if chunk != "[DONE]": |
358 | | - timestamp = time.perf_counter() |
359 | | - data = json.loads(chunk) |
360 | | - |
361 | | - if choices := data.get("choices"): |
362 | | - content = choices[0]["delta"].get( |
363 | | - "content") |
364 | | - # First token |
365 | | - if ttft == 0.0: |
366 | | - ttft = timestamp - st |
367 | | - output.ttft = ttft |
368 | | - |
369 | | - # Decoding phase |
370 | | - else: |
371 | | - output.itl.append( |
372 | | - timestamp - most_recent_timestamp) |
373 | | - |
374 | | - generated_text += content or "" |
375 | | - elif usage := data.get("usage"): |
376 | | - output.output_tokens = usage.get( |
377 | | - "completion_tokens") |
378 | | - |
379 | | - most_recent_timestamp = timestamp |
| 405 | + messages = handler.add_chunk(chunk_bytes) |
| 406 | + for message in messages: |
| 407 | +                        chunk = message.removeprefix(
| 408 | +                            "data: ")
| 409 | + if chunk != "[DONE]": |
| 410 | + timestamp = time.perf_counter() |
| 411 | + data = json.loads(chunk) |
| 412 | + |
| 413 | + if choices := data.get("choices"): |
| 414 | + content = choices[0]["delta"].get( |
| 415 | + "content") |
| 416 | + # First token |
| 417 | + if ttft == 0.0: |
| 418 | + ttft = timestamp - st |
| 419 | + output.ttft = ttft |
| 420 | + |
| 421 | + # Decoding phase |
| 422 | + else: |
| 423 | + output.itl.append( |
| 424 | + timestamp - most_recent_timestamp) |
| 425 | + |
| 426 | + generated_text += content or "" |
| 427 | + elif usage := data.get("usage"): |
| 428 | + output.output_tokens = usage.get( |
| 429 | + "completion_tokens") |
| 430 | + |
| 431 | + most_recent_timestamp = timestamp |
380 | 432 |
|
381 | 433 | output.generated_text = generated_text |
382 | 434 | output.success = True |
|
0 commit comments