|
1 | 1 | import asyncio |
2 | | -from typing import Any, AsyncGenerator, Union |
| 2 | +import threading |
| 3 | +from collections import deque |
| 4 | +from typing import Any, AsyncGenerator, Literal, Union |
3 | 5 |
|
4 | 6 |
|
5 | 7 | class CoreResponse: |
@@ -29,24 +31,52 @@ def consume(self) -> Any: |
29 | 31 |
|
30 | 32 |
|
31 | 33 | class BaseStreamResponse(CoreResponse): |
32 | | - def __init__(self): |
33 | | - self.first_chunk_event = asyncio.Event() |
| 34 | + def __init__(self, mode: Literal["sync", "async"] = "sync"): |
| 35 | + if mode not in {"sync", "async"}: |
| 36 | + raise ValueError("`mode` must be `sync` or `async`") |
| 37 | + self.mode = mode |
| 38 | + if mode == "async": |
| 39 | + self.first_chunk_event = asyncio.Event() |
| 40 | + else: |
| 41 | + self.first_chunk_event = threading.Event() |
34 | 42 | self.data = None |
35 | | - self.queue = asyncio.Queue() |
| 43 | + self._queue = None |
| 44 | + self._queue_loop = None |
| 45 | + self._pending_chunks = deque() |
| 46 | + self._queue_lock = threading.Lock() |
36 | 47 | self.metadata = None |
37 | 48 | self.response_type = None |
38 | 49 |
|
39 | 50 | def add(self, data: Any): |
40 | | - """Add data to the stream queue (async).""" |
41 | | - self.queue.put_nowait(data) |
| 51 | + """Add data to the stream queue in a thread-safe way.""" |
| 52 | + with self._queue_lock: |
| 53 | + queue = self._queue |
| 54 | + loop = self._queue_loop |
| 55 | + if queue is None or loop is None or loop.is_closed(): |
| 56 | + self._pending_chunks.append(data) |
| 57 | + return |
| 58 | + |
| 59 | + loop.call_soon_threadsafe(queue.put_nowait, data) |
| 60 | + |
| 61 | + def _bind_consumer_queue(self) -> asyncio.Queue: |
| 62 | + loop = asyncio.get_running_loop() |
| 63 | + with self._queue_lock: |
| 64 | + if self._queue is None: |
| 65 | + self._queue = asyncio.Queue() |
| 66 | + self._queue_loop = loop |
| 67 | + while self._pending_chunks: |
| 68 | + self._queue.put_nowait(self._pending_chunks.popleft()) |
| 69 | + elif self._queue_loop is not loop: |
| 70 | + raise RuntimeError( |
| 71 | + "BaseStreamResponse.consume() must run on the same event loop." |
| 72 | + ) |
| 73 | + return self._queue |
42 | 74 |
|
43 | 75 | async def consume(self) -> AsyncGenerator[Union[bytes, str], None]: |
44 | 76 | """Async generator that yields chunks from the queue until None is received.""" |
| 77 | + queue = self._bind_consumer_queue() |
45 | 78 | while True: |
46 | | - try: |
47 | | - chunk = await asyncio.wait_for(self.queue.get(), timeout=1.0) |
48 | | - if chunk is None: |
49 | | - break |
50 | | - yield chunk |
51 | | - except asyncio.TimeoutError: |
52 | | - continue |
| 79 | + chunk = await queue.get() |
| 80 | + if chunk is None: |
| 81 | + break |
| 82 | + yield chunk |
0 commit comments