Skip to content

Commit 9776bfd

Browse files
committed
refactor: DRY up duplicate reference resolution logic in stream functions
1 parent 5e6be60 commit 9776bfd

File tree

1 file changed

+17
-25
lines changed

1 file changed

+17
-25
lines changed

src/replicate/lib/_predictions_stream.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Union, Iterator, Optional
3+
from typing import TYPE_CHECKING, Tuple, Union, Iterator, Optional
44
from collections.abc import AsyncIterator
55
from typing_extensions import Unpack
66

@@ -14,6 +14,20 @@
1414
from .._client import Replicate, AsyncReplicate
1515

1616

17+
def _resolve_reference(
18+
ref: Union[Model, Version, ModelVersionIdentifier, str],
19+
) -> Tuple[Optional[Version], Optional[str], Optional[str], Optional[str]]:
20+
"""Resolve a model reference to its components, with fallback for plain version IDs."""
21+
try:
22+
return resolve_reference(ref)
23+
except ValueError:
24+
# If resolution fails, treat it as a version ID if it's a string
25+
if isinstance(ref, str):
26+
return None, None, None, ref
27+
else:
28+
raise
29+
30+
1731
def stream(
1832
client: "Replicate",
1933
ref: Union[Model, Version, ModelVersionIdentifier, str],
@@ -46,19 +60,9 @@ def stream(
4660
ValueError: If the reference format is invalid
4761
ReplicateError: If the prediction fails or streaming is not available
4862
"""
49-
# Resolve ref to its components
50-
try:
51-
version, owner, name, version_id = resolve_reference(ref)
52-
except ValueError:
53-
# If resolution fails, treat it as a version ID if it's a string
54-
if isinstance(ref, str):
55-
version_id = ref
56-
owner = name = None
57-
else:
58-
raise
63+
version, owner, name, version_id = _resolve_reference(ref)
5964

6065
# Create prediction
61-
prediction = None
6266
if version_id is not None:
6367
params_with_version: PredictionCreateParams = {**params, "version": version_id}
6468
prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
@@ -80,7 +84,6 @@ def stream(
8084
if not prediction.urls or not prediction.urls.stream:
8185
raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.")
8286

83-
# Make SSE request to the stream URL
8487
stream_url = prediction.urls.stream
8588

8689
with client._client.stream(
@@ -128,19 +131,9 @@ async def async_stream(
128131
ValueError: If the reference format is invalid
129132
ReplicateError: If the prediction fails or streaming is not available
130133
"""
131-
# Resolve ref to its components
132-
try:
133-
version, owner, name, version_id = resolve_reference(ref)
134-
except ValueError:
135-
# If resolution fails, treat it as a version ID if it's a string
136-
if isinstance(ref, str):
137-
version_id = ref
138-
owner = name = None
139-
else:
140-
raise
134+
version, owner, name, version_id = _resolve_reference(ref)
141135

142136
# Create prediction
143-
prediction = None
144137
if version_id is not None:
145138
params_with_version: PredictionCreateParams = {**params, "version": version_id}
146139
prediction = await client.predictions.create(
@@ -166,7 +159,6 @@ async def async_stream(
166159
if not prediction.urls or not prediction.urls.stream:
167160
raise ValueError("Model does not support streaming. The prediction URLs do not include a stream endpoint.")
168161

169-
# Make SSE request to the stream URL
170162
stream_url = prediction.urls.stream
171163

172164
async with client._client.stream(

0 commit comments

Comments
 (0)