
Commit ce20b68

feat: add support for replicate.stream()

This PR adds support for streaming predictions via the `replicate.stream()` method.

Changes:
- Add a `stream()` method to both the Replicate and AsyncReplicate clients
- Add a module-level `stream()` function for convenience
- Create a new `lib/_predictions_stream.py` module with the streaming logic
- Add comprehensive tests for sync and async streaming
- Update the README with documentation and examples using anthropic/claude-4-sonnet

The stream method creates a prediction and returns an iterator that yields output chunks as they become available via Server-Sent Events (SSE). This is useful for language models where you want to display output as it's generated.

1 parent 2804bd6 commit ce20b68
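
For reference, a minimal sketch of the new surface on the async client, as added in this commit. The no-argument `AsyncReplicate()` construction assumes the API token is picked up from the environment; the model ref and prompt are illustrative:

```python
import asyncio

from replicate import AsyncReplicate

async def main() -> None:
    # Assumes the API token is configured in the environment.
    client = AsyncReplicate()
    # stream() yields output chunks as SSE events arrive.
    async for event in client.stream(
        "anthropic/claude-4-sonnet",
        input={"prompt": "Write a haiku about coding"},
    ):
        print(str(event), end="")

asyncio.run(main())
```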

File tree

6 files changed: +525 additions, -4 deletions


README.md

Lines changed: 11 additions & 3 deletions

@@ -118,14 +118,18 @@ For models that support streaming (particularly language models), you can use `r
 import replicate
 
 for event in replicate.stream(
-    "meta/meta-llama-3-70b-instruct",
+    "anthropic/claude-4-sonnet",
     input={
-        "prompt": "Please write a haiku about llamas.",
+        "prompt": "Give me a recipe for tasty smashed avocado on sourdough toast.",
+        "max_tokens": 8192,
+        "system_prompt": "You are a helpful assistant",
     },
 ):
     print(str(event), end="")
 ```
 
+The `stream()` method creates a prediction and returns an iterator that yields output chunks as they become available via Server-Sent Events (SSE). This is useful for language models where you want to display output as it's generated rather than waiting for the entire response.
+
 ## Async usage
 
 Simply import `AsyncReplicate` instead of `Replicate` and use `await` with each API call:
@@ -172,7 +176,11 @@ async def main():
 
     # Stream a model's output
     async for event in replicate.stream(
-        "meta/meta-llama-3-70b-instruct", input={"prompt": "Write a haiku about coding"}
+        "anthropic/claude-4-sonnet",
+        input={
+            "prompt": "Write a haiku about coding",
+            "system_prompt": "You are a helpful assistant",
+        },
     ):
         print(str(event), end="")
 
src/replicate/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -109,7 +109,7 @@
     if not __name.startswith("__"):
         try:
             # Skip symbols that are imported later from _module_client
-            if __name in ("run", "use"):
+            if __name in ("run", "use", "stream"):
                 continue
             __locals[__name].__module__ = "replicate"
         except (TypeError, AttributeError):
@@ -253,6 +253,7 @@ def _reset_client() -> None:  # type: ignore[reportUnusedFunction]
     use as use,
     files as files,
     models as models,
+    stream as stream,
     account as account,
     hardware as hardware,
     webhooks as webhooks,
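
With the `stream as stream` re-export in place, the function can also be imported directly from the package root, mirroring `run` and `use`. A minimal sketch (model ref and input are illustrative):

```python
# `stream` is importable from the package root after this change.
from replicate import stream

for event in stream(
    "anthropic/claude-4-sonnet",
    input={"prompt": "Write a haiku about coding"},
):
    print(str(event), end="")
```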

src/replicate/_client.py

Lines changed: 97 additions & 0 deletions

@@ -320,6 +320,54 @@ def use(
         # TODO: Fix mypy overload matching for streaming parameter
         return _use(self, ref, hint=hint, streaming=streaming)  # type: ignore[call-overload, no-any-return]
 
+    def stream(
+        self,
+        ref: Union[Model, Version, ModelVersionIdentifier, str],
+        *,
+        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+        **params: Unpack[PredictionCreateParamsWithoutVersion],
+    ) -> Iterator[str]:
+        """
+        Stream output from a model prediction.
+
+        This creates a prediction and returns an iterator that yields output chunks
+        as they become available via Server-Sent Events (SSE).
+
+        Args:
+            ref: Reference to the model or version to run. Can be:
+                - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
+                - A string with owner/name format (e.g. "replicate/hello-world")
+                - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
+                - A Model instance with owner and name attributes
+                - A Version instance with id attribute
+                - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
+            file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
+            **params: Additional parameters to pass to the prediction creation endpoint, including
+                the required "input" dictionary with model-specific parameters
+
+        Yields:
+            str: Output chunks from the model as they become available
+
+        Raises:
+            ValueError: If the reference format is invalid or the model doesn't support streaming
+            ReplicateError: If the prediction fails
+
+        Example:
+            for event in replicate.stream(
+                "meta/meta-llama-3-70b-instruct",
+                input={"prompt": "Write a haiku about coding"},
+            ):
+                print(str(event), end="")
+        """
+        from .lib._predictions_stream import stream
+
+        return stream(
+            self,
+            ref,
+            file_encoding_strategy=file_encoding_strategy,
+            **params,
+        )
+
     def copy(
         self,
         *,
@@ -695,6 +743,55 @@ def use(
         # TODO: Fix mypy overload matching for streaming parameter
         return _use(self, ref, hint=hint, streaming=streaming)  # type: ignore[call-overload, no-any-return]
 
+    async def stream(
+        self,
+        ref: Union[Model, Version, ModelVersionIdentifier, str],
+        *,
+        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+        **params: Unpack[PredictionCreateParamsWithoutVersion],
+    ) -> AsyncIterator[str]:
+        """
+        Stream output from a model prediction asynchronously.
+
+        This creates a prediction and returns an async iterator that yields output chunks
+        as they become available via Server-Sent Events (SSE).
+
+        Args:
+            ref: Reference to the model or version to run. Can be:
+                - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
+                - A string with owner/name format (e.g. "replicate/hello-world")
+                - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
+                - A Model instance with owner and name attributes
+                - A Version instance with id attribute
+                - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
+            file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
+            **params: Additional parameters to pass to the prediction creation endpoint, including
+                the required "input" dictionary with model-specific parameters
+
+        Yields:
+            str: Output chunks from the model as they become available
+
+        Raises:
+            ValueError: If the reference format is invalid or the model doesn't support streaming
+            ReplicateError: If the prediction fails
+
+        Example:
+            async for event in replicate.stream(
+                "meta/meta-llama-3-70b-instruct",
+                input={"prompt": "Write a haiku about coding"},
+            ):
+                print(str(event), end="")
+        """
+        from .lib._predictions_stream import async_stream
+
+        async for chunk in async_stream(
+            self,
+            ref,
+            file_encoding_strategy=file_encoding_strategy,
+            **params,
+        ):
+            yield chunk
+
     def copy(
         self,
         *,
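
A hedged sketch of calling the new method on an explicit client instance; the model ref is illustrative, and any of the ref formats listed in the docstring should be accepted:

```python
from replicate import Replicate

# Assumes the API token is configured in the environment.
client = Replicate()

# Equivalent ref forms (illustrative): "owner/name", "owner/name:version",
# a bare version ID, or Model / Version / ModelVersionIdentifier objects.
for chunk in client.stream(
    "replicate/hello-world",
    input={"prompt": "Write a haiku about coding"},
):
    print(chunk, end="")
```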

src/replicate/_module_client.py

Lines changed: 5 additions & 0 deletions

@@ -82,6 +82,7 @@ def __load__(self) -> PredictionsResource:
     __client: Replicate = cast(Replicate, {})
     run = __client.run
     use = __client.use
+    stream = __client.stream
 else:
 
     def _run(*args, **kwargs):
@@ -100,8 +101,12 @@ def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):
 
         return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs)
 
+    def _stream(*args, **kwargs):
+        return _load_client().stream(*args, **kwargs)
+
     run = _run
     use = _use
+    stream = _stream
 
 files: FilesResource = FilesResourceProxy().__as_proxied__()
 models: ModelsResource = ModelsResourceProxy().__as_proxied__()
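
Because `_stream` defers to `_load_client()`, no client is constructed at import time; the default client is created on the first call. A minimal sketch of the resulting module-level usage (model ref and input are illustrative):

```python
import replicate

# The default client is created lazily on this first call.
for event in replicate.stream(
    "replicate/hello-world",
    input={"prompt": "hello"},
):
    print(str(event), end="")
```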
src/replicate/lib/_predictions_stream.py

Lines changed: 188 additions & 0 deletions

@@ -0,0 +1,188 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Union, Iterator, Optional
+from collections.abc import AsyncIterator
+from typing_extensions import Unpack
+
+from replicate.lib._files import FileEncodingStrategy
+from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
+
+from ..types import PredictionCreateParams
+from ._models import Model, Version, ModelVersionIdentifier, resolve_reference
+
+if TYPE_CHECKING:
+    from .._client import Replicate, AsyncReplicate
+
+
+def stream(
+    client: "Replicate",
+    ref: Union[Model, Version, ModelVersionIdentifier, str],
+    *,
+    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+    **params: Unpack[PredictionCreateParamsWithoutVersion],
+) -> Iterator[str]:
+    """
+    Stream output from a model prediction.
+
+    This creates a prediction and returns an iterator that yields output chunks
+    as they become available via Server-Sent Events (SSE).
+
+    Args:
+        client: The Replicate client instance
+        ref: Reference to the model or version to run. Can be:
+            - A string containing a version ID
+            - A string with owner/name format (e.g. "replicate/hello-world")
+            - A string with owner/name:version format
+            - A Model instance
+            - A Version instance
+            - A ModelVersionIdentifier dictionary
+        file_encoding_strategy: Strategy for encoding file inputs
+        **params: Additional parameters including the required "input" dictionary
+
+    Yields:
+        str: Output chunks from the model as they become available
+
+    Raises:
+        ValueError: If the reference format is invalid
+        ReplicateError: If the prediction fails or streaming is not available
+    """
+    # Resolve ref to its components
+    try:
+        version, owner, name, version_id = resolve_reference(ref)
+    except ValueError:
+        # If resolution fails, treat it as a version ID if it's a string
+        if isinstance(ref, str):
+            version_id = ref
+            owner = name = None
+        else:
+            raise
+
+    # Create prediction
+    prediction = None
+    if version_id is not None:
+        params_with_version: PredictionCreateParams = {**params, "version": version_id}
+        prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
+    elif owner and name:
+        prediction = client.models.predictions.create(
+            file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params
+        )
+    else:
+        if isinstance(ref, str):
+            params_with_version = {**params, "version": ref}
+            prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
+        else:
+            raise ValueError(
+                f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
+                "a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
+            )
+
+    # Check if streaming URL is available
+    if not prediction.urls or not prediction.urls.stream:
+        raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.")
+
+    # Make SSE request to the stream URL
+    stream_url = prediction.urls.stream
+
+    with client._client.stream(
+        "GET",
+        stream_url,
+        headers={
+            "Accept": "text/event-stream",
+            "Cache-Control": "no-store",
+        },
+        timeout=None,  # No timeout for streaming
+    ) as response:
+        response.raise_for_status()
+
+        # Parse SSE events and yield output chunks
+        decoder = client._make_sse_decoder()
+        for sse in decoder.iter_bytes(response.iter_bytes()):
+            # The SSE data contains the output chunks
+            if sse.data:
+                yield sse.data
+
+
+async def async_stream(
+    client: "AsyncReplicate",
+    ref: Union[Model, Version, ModelVersionIdentifier, str],
+    *,
+    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+    **params: Unpack[PredictionCreateParamsWithoutVersion],
+) -> AsyncIterator[str]:
+    """
+    Stream output from a model prediction asynchronously.
+
+    This creates a prediction and returns an async iterator that yields output chunks
+    as they become available via Server-Sent Events (SSE).
+
+    Args:
+        client: The AsyncReplicate client instance
+        ref: Reference to the model or version to run
+        file_encoding_strategy: Strategy for encoding file inputs
+        **params: Additional parameters including the required "input" dictionary
+
+    Yields:
+        str: Output chunks from the model as they become available
+
+    Raises:
+        ValueError: If the reference format is invalid
+        ReplicateError: If the prediction fails or streaming is not available
+    """
+    # Resolve ref to its components
+    try:
+        version, owner, name, version_id = resolve_reference(ref)
+    except ValueError:
+        # If resolution fails, treat it as a version ID if it's a string
+        if isinstance(ref, str):
+            version_id = ref
+            owner = name = None
+        else:
+            raise
+
+    # Create prediction
+    prediction = None
+    if version_id is not None:
+        params_with_version: PredictionCreateParams = {**params, "version": version_id}
+        prediction = await client.predictions.create(
+            file_encoding_strategy=file_encoding_strategy, **params_with_version
+        )
+    elif owner and name:
+        prediction = await client.models.predictions.create(
+            file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params
+        )
+    else:
+        if isinstance(ref, str):
+            params_with_version = {**params, "version": ref}
+            prediction = await client.predictions.create(
+                file_encoding_strategy=file_encoding_strategy, **params_with_version
+            )
+        else:
+            raise ValueError(
+                f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
+                "a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
+            )
+
+    # Check if streaming URL is available
+    if not prediction.urls or not prediction.urls.stream:
+        raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.")
+
+    # Make SSE request to the stream URL
+    stream_url = prediction.urls.stream
+
+    async with client._client.stream(
+        "GET",
+        stream_url,
+        headers={
+            "Accept": "text/event-stream",
+            "Cache-Control": "no-store",
+        },
+        timeout=None,  # No timeout for streaming
+    ) as response:
+        response.raise_for_status()
+
+        # Parse SSE events and yield output chunks
+        decoder = client._make_sse_decoder()
+        async for sse in decoder.aiter_bytes(response.aiter_bytes()):
+            # The SSE data contains the output chunks
+            if sse.data:
+                yield sse.data
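
Since `stream()` raises `ValueError` when a prediction's URLs include no stream endpoint, callers may want to guard against models that don't support streaming. A hedged sketch (model ref is illustrative):

```python
import replicate

try:
    for event in replicate.stream(
        "replicate/hello-world",  # illustrative; not every model exposes a stream URL
        input={"prompt": "hello"},
    ):
        print(str(event), end="")
except ValueError as err:
    # Raised for an invalid ref format or when the prediction's
    # URLs do not include a stream endpoint.
    print(f"Streaming unavailable: {err}")
```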
