Skip to content

Commit 1ac6c47

Browse files
Merge branch 'fix/streaming-response'
- Thread-safe BaseStreamResponse with sync/async mode - Fix _stream_generate deadlock (async→sync) - Refactor OpenAI provider response processing
2 parents 2484b71 + b3ca0d8 commit 1ac6c47

File tree

3 files changed

+288
-178
lines changed

3 files changed

+288
-178
lines changed

src/msgflux/_private/response.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
2-
from typing import Any, AsyncGenerator, Union
2+
import threading
3+
from collections import deque
4+
from typing import Any, AsyncGenerator, Literal, Union
35

46

57
class CoreResponse:
@@ -29,24 +31,52 @@ def consume(self) -> Any:
2931

3032

3133
class BaseStreamResponse(CoreResponse):
32-
def __init__(self):
33-
self.first_chunk_event = asyncio.Event()
34+
def __init__(self, mode: Literal["sync", "async"] = "sync"):
35+
if mode not in {"sync", "async"}:
36+
raise ValueError("`mode` must be `sync` or `async`")
37+
self.mode = mode
38+
if mode == "async":
39+
self.first_chunk_event = asyncio.Event()
40+
else:
41+
self.first_chunk_event = threading.Event()
3442
self.data = None
35-
self.queue = asyncio.Queue()
43+
self._queue = None
44+
self._queue_loop = None
45+
self._pending_chunks = deque()
46+
self._queue_lock = threading.Lock()
3647
self.metadata = None
3748
self.response_type = None
3849

3950
def add(self, data: Any):
40-
"""Add data to the stream queue (async)."""
41-
self.queue.put_nowait(data)
51+
"""Add data to the stream queue in a thread-safe way."""
52+
with self._queue_lock:
53+
queue = self._queue
54+
loop = self._queue_loop
55+
if queue is None or loop is None or loop.is_closed():
56+
self._pending_chunks.append(data)
57+
return
58+
59+
loop.call_soon_threadsafe(queue.put_nowait, data)
60+
61+
def _bind_consumer_queue(self) -> asyncio.Queue:
62+
loop = asyncio.get_running_loop()
63+
with self._queue_lock:
64+
if self._queue is None:
65+
self._queue = asyncio.Queue()
66+
self._queue_loop = loop
67+
while self._pending_chunks:
68+
self._queue.put_nowait(self._pending_chunks.popleft())
69+
elif self._queue_loop is not loop:
70+
raise RuntimeError(
71+
"BaseStreamResponse.consume() must run on the same event loop."
72+
)
73+
return self._queue
4274

4375
async def consume(self) -> AsyncGenerator[Union[bytes, str], None]:
4476
"""Async generator that yields chunks from the queue until None is received."""
77+
queue = self._bind_consumer_queue()
4578
while True:
46-
try:
47-
chunk = await asyncio.wait_for(self.queue.get(), timeout=1.0)
48-
if chunk is None:
49-
break
50-
yield chunk
51-
except asyncio.TimeoutError:
52-
continue
79+
chunk = await queue.get()
80+
if chunk is None:
81+
break
82+
yield chunk

0 commit comments

Comments
 (0)