Commit 5a74724

Improve ergonomics of streaming predictions (#269)
A few folks have expressed confusion about how to stream output from a prediction. This PR updates the README with documentation and adds a missing `async_stream` instance method on `prediction`.

---------

Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 0448869 commit 5a74724

File tree

4 files changed: +64 −22 lines


.github/workflows/ci.yaml

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,9 @@ jobs:
 
     name: "Test Python ${{ matrix.python-version }}"
 
+    env:
+      REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
+
     timeout-minutes: 10
 
     strategy:

README.md

Lines changed: 13 additions & 0 deletions
@@ -114,6 +114,19 @@ for event in replicate.stream(
     print(str(event), end="")
 ```
 
+You can also stream the output of a prediction you create.
+This is helpful when you want the ID of the prediction separate from its output.
+
+```python
+version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
+prediction = replicate.predictions.create(version=version, input={
+    "prompt": "Please write a haiku about llamas.",
+})
+
+for event in prediction.stream():
+    print(str(event), end="")
+```
+
 For more information, see
 ["Streaming output"](https://replicate.com/docs/streaming) in Replicate's docs.
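The README addition above covers the synchronous path. As a rough async counterpart, sketched from the `async_stream` method added in `replicate/prediction.py` and the pattern used in `tests/test_stream.py` below, streaming a created prediction from async code might look like this; the model version hash and the `stream=True` argument are copied from those diffs and are illustrative rather than verified against separate docs:

```python
import asyncio

import replicate


async def main() -> None:
    # Assumes REPLICATE_API_TOKEN is set in the environment,
    # as in the CI change above.

    # Version hash taken from the updated tests below; illustrative only.
    version = "30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272"

    # The tests pass stream=True when creating the prediction;
    # the README example above omits it.
    prediction = replicate.predictions.create(
        version=version,
        input={"text": "Hello"},
        stream=True,
    )

    # async_stream() is the async generator this commit adds to predictions.
    async for event in prediction.async_stream():
        print(str(event), end="")


asyncio.run(main())
```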

replicate/prediction.py

Lines changed: 26 additions & 1 deletion
@@ -149,7 +149,7 @@ async def async_wait(self) -> None:
             await asyncio.sleep(self._client.poll_interval)
             await self.async_reload()
 
-    def stream(self) -> Optional[Iterator["ServerSentEvent"]]:
+    def stream(self) -> Iterator["ServerSentEvent"]:
         """
         Stream the prediction output.
 
@@ -168,6 +168,31 @@ def stream(self) -> Optional[Iterator["ServerSentEvent"]]:
         with self._client._client.stream("GET", url, headers=headers) as response:
             yield from EventSource(response)
 
+    async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
+        """
+        Stream the prediction output asynchronously.
+
+        Raises:
+            ReplicateError: If the model does not support streaming.
+        """
+
+        # no-op to enforce the use of 'await' when calling this method
+        await asyncio.sleep(0)
+
+        url = self.urls and self.urls.get("stream", None)
+        if not url or not isinstance(url, str):
+            raise ReplicateError("Model does not support streaming")
+
+        headers = {}
+        headers["Accept"] = "text/event-stream"
+        headers["Cache-Control"] = "no-store"
+
+        async with self._client._async_client.stream(
+            "GET", url, headers=headers
+        ) as response:
+            async for event in EventSource(response):
+                yield event
+
     def cancel(self) -> None:
         """
         Cancels a running prediction.
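To show how a caller might consume the new async generator, here is a minimal, hypothetical sketch that filters events by type, modeled on the assertions in the updated `tests/test_stream.py` below. The `ServerSentEvent.EventType.OUTPUT` and `DONE` names come from that diff; the `collect_output` helper itself is illustrative only:

```python
from replicate.stream import ServerSentEvent


async def collect_output(prediction) -> str:
    """Concatenate OUTPUT events from prediction.async_stream(), stopping at DONE."""
    chunks = []
    async for event in prediction.async_stream():
        if event.event == ServerSentEvent.EventType.OUTPUT:
            chunks.append(str(event))
        elif event.event == ServerSentEvent.EventType.DONE:
            break
    return "".join(chunks)
```

A synchronous caller would do the same with `prediction.stream()` and a plain `for` loop.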

tests/test_stream.py

Lines changed: 22 additions & 21 deletions
@@ -1,55 +1,56 @@
 import pytest
 
 import replicate
+from replicate.stream import ServerSentEvent
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("async_flag", [True, False])
 async def test_stream(async_flag, record_mode):
-    if record_mode == "none":
-        return
-
-    version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
-
+    model = "replicate/canary:30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272"
     input = {
-        "prompt": "Please write a haiku about llamas.",
+        "text": "Hello",
     }
 
     events = []
 
     if async_flag:
         async for event in await replicate.async_stream(
-            f"meta/llama-2-70b-chat:{version}",
+            model,
             input=input,
         ):
             events.append(event)
     else:
         for event in replicate.stream(
-            f"meta/llama-2-70b-chat:{version}",
+            model,
            input=input,
         ):
             events.append(event)
 
     assert len(events) > 0
-    assert events[0].event == "output"
+    assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events)
+    assert any(event.event == ServerSentEvent.EventType.DONE for event in events)
 
 
 @pytest.mark.asyncio
-async def test_stream_prediction(record_mode):
-    if record_mode == "none":
-        return
-
-    version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
-
+@pytest.mark.parametrize("async_flag", [True, False])
+async def test_stream_prediction(async_flag, record_mode):
+    version = "30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272"
     input = {
-        "prompt": "Please write a haiku about llamas.",
+        "text": "Hello",
     }
 
-    prediction = replicate.predictions.create(version=version, input=input)
-
     events = []
-    for event in prediction.stream():
-        events.append(event)
+
+    if async_flag:
+        async for event in replicate.predictions.create(
+            version=version, input=input, stream=True
+        ).async_stream():
+            events.append(event)
+    else:
+        for event in replicate.predictions.create(
+            version=version, input=input, stream=True
+        ).stream():
+            events.append(event)
 
     assert len(events) > 0
-    assert events[0].event == "output"
