Skip to content

Commit f2d2683

Browse files
committed
refactor: DRY up duplicate docstrings in stream functions
1 parent 9776bfd commit f2d2683

File tree

2 files changed

+37
-90
lines changed

2 files changed

+37
-90
lines changed

src/replicate/_client.py

Lines changed: 10 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -330,34 +330,16 @@ def stream(
330330
"""
331331
Stream output from a model prediction.
332332
333-
This creates a prediction and returns an iterator that yields output chunks
334-
as strings as they become available from the streaming API.
335-
336-
Args:
337-
ref: Reference to the model or version to run. Can be:
338-
- A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
339-
- A string with owner/name format (e.g. "replicate/hello-world")
340-
- A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
341-
- A Model instance with owner and name attributes
342-
- A Version instance with id attribute
343-
- A ModelVersionIdentifier dictionary with owner, name, and/or version keys
344-
file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
345-
**params: Additional parameters to pass to the prediction creation endpoint including
346-
the required "input" dictionary with model-specific parameters
347-
348-
Yields:
349-
str: Output chunks from the model as they become available
350-
351-
Raises:
352-
ValueError: If the reference format is invalid or model doesn't support streaming
353-
ReplicateError: If the prediction fails
354-
355333
Example:
356-
for event in replicate.stream(
334+
```python
335+
for event in client.stream(
357336
"meta/meta-llama-3-70b-instruct",
358337
input={"prompt": "Write a haiku about coding"},
359338
):
360339
print(str(event), end="")
340+
```
341+
342+
See `replicate.lib._predictions_stream.stream` for full documentation.
361343
"""
362344
from .lib._predictions_stream import stream
363345

@@ -753,34 +735,16 @@ async def stream(
753735
"""
754736
Stream output from a model prediction asynchronously.
755737
756-
This creates a prediction and returns an async iterator that yields output chunks
757-
as strings as they become available from the streaming API.
758-
759-
Args:
760-
ref: Reference to the model or version to run. Can be:
761-
- A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
762-
- A string with owner/name format (e.g. "replicate/hello-world")
763-
- A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
764-
- A Model instance with owner and name attributes
765-
- A Version instance with id attribute
766-
- A ModelVersionIdentifier dictionary with owner, name, and/or version keys
767-
file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
768-
**params: Additional parameters to pass to the prediction creation endpoint including
769-
the required "input" dictionary with model-specific parameters
770-
771-
Yields:
772-
str: Output chunks from the model as they become available
773-
774-
Raises:
775-
ValueError: If the reference format is invalid or model doesn't support streaming
776-
ReplicateError: If the prediction fails
777-
778738
Example:
779-
async for event in replicate.stream(
739+
```python
740+
async for event in client.stream(
780741
"meta/meta-llama-3-70b-instruct",
781742
input={"prompt": "Write a haiku about coding"},
782743
):
783744
print(str(event), end="")
745+
```
746+
747+
See `replicate.lib._predictions_stream.async_stream` for full documentation.
784748
"""
785749
from .lib._predictions_stream import async_stream
786750

src/replicate/lib/_predictions_stream.py

Lines changed: 27 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,31 @@
1313
if TYPE_CHECKING:
1414
from .._client import Replicate, AsyncReplicate
1515

16+
_STREAM_DOCSTRING = """
17+
Stream output from a model prediction.
18+
19+
This creates a prediction and returns an iterator that yields output chunks
20+
as strings as they become available from the streaming API.
21+
22+
Args:
23+
ref: Reference to the model or version to run. Can be:
24+
- A string containing a version ID
25+
- A string with owner/name format (e.g. "replicate/hello-world")
26+
- A string with owner/name:version format
27+
- A Model instance
28+
- A Version instance
29+
- A ModelVersionIdentifier dictionary
30+
file_encoding_strategy: Strategy for encoding file inputs
31+
**params: Additional parameters including the required "input" dictionary
32+
33+
Yields:
34+
str: Output chunks from the model as they become available
35+
36+
Raises:
37+
ValueError: If the reference format is invalid
38+
ReplicateError: If the prediction fails or streaming is not available
39+
"""
40+
1641

1742
def _resolve_reference(
1843
ref: Union[Model, Version, ModelVersionIdentifier, str],
@@ -35,31 +60,7 @@ def stream(
3560
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
3661
**params: Unpack[PredictionCreateParamsWithoutVersion],
3762
) -> Iterator[str]:
38-
"""
39-
Stream output from a model prediction.
40-
41-
This creates a prediction and returns an iterator that yields output chunks
42-
as strings as they become available from the streaming API.
43-
44-
Args:
45-
client: The Replicate client instance
46-
ref: Reference to the model or version to run. Can be:
47-
- A string containing a version ID
48-
- A string with owner/name format (e.g. "replicate/hello-world")
49-
- A string with owner/name:version format
50-
- A Model instance
51-
- A Version instance
52-
- A ModelVersionIdentifier dictionary
53-
file_encoding_strategy: Strategy for encoding file inputs
54-
**params: Additional parameters including the required "input" dictionary
55-
56-
Yields:
57-
str: Output chunks from the model as they become available
58-
59-
Raises:
60-
ValueError: If the reference format is invalid
61-
ReplicateError: If the prediction fails or streaming is not available
62-
"""
63+
__doc__ = _STREAM_DOCSTRING
6364
version, owner, name, version_id = _resolve_reference(ref)
6465

6566
# Create prediction
@@ -112,25 +113,7 @@ async def async_stream(
112113
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
113114
**params: Unpack[PredictionCreateParamsWithoutVersion],
114115
) -> AsyncIterator[str]:
115-
"""
116-
Async stream output from a model prediction.
117-
118-
This creates a prediction and returns an async iterator that yields output chunks
119-
as strings as they become available from the streaming API.
120-
121-
Args:
122-
client: The AsyncReplicate client instance
123-
ref: Reference to the model or version to run
124-
file_encoding_strategy: Strategy for encoding file inputs
125-
**params: Additional parameters including the required "input" dictionary
126-
127-
Yields:
128-
str: Output chunks from the model as they become available
129-
130-
Raises:
131-
ValueError: If the reference format is invalid
132-
ReplicateError: If the prediction fails or streaming is not available
133-
"""
116+
__doc__ = _STREAM_DOCSTRING
134117
version, owner, name, version_id = _resolve_reference(ref)
135118

136119
# Create prediction

0 commit comments

Comments (0)