11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Union , Iterator , Optional
3+ from typing import TYPE_CHECKING , Tuple , Union , Iterator , Optional
44from collections .abc import AsyncIterator
55from typing_extensions import Unpack
66
1414 from .._client import Replicate , AsyncReplicate
1515
1616
17+ def _resolve_reference (
18+ ref : Union [Model , Version , ModelVersionIdentifier , str ],
19+ ) -> Tuple [Optional [Version ], Optional [str ], Optional [str ], Optional [str ]]:
20+ """Resolve a model reference to its components, with fallback for plain version IDs."""
21+ try :
22+ return resolve_reference (ref )
23+ except ValueError :
24+ # If resolution fails, treat it as a version ID if it's a string
25+ if isinstance (ref , str ):
26+ return None , None , None , ref
27+ else :
28+ raise
29+
30+
1731def stream (
1832 client : "Replicate" ,
1933 ref : Union [Model , Version , ModelVersionIdentifier , str ],
@@ -46,19 +60,9 @@ def stream(
4660 ValueError: If the reference format is invalid
4761 ReplicateError: If the prediction fails or streaming is not available
4862 """
49- # Resolve ref to its components
50- try :
51- version , owner , name , version_id = resolve_reference (ref )
52- except ValueError :
53- # If resolution fails, treat it as a version ID if it's a string
54- if isinstance (ref , str ):
55- version_id = ref
56- owner = name = None
57- else :
58- raise
63+ version , owner , name , version_id = _resolve_reference (ref )
5964
6065 # Create prediction
61- prediction = None
6266 if version_id is not None :
6367 params_with_version : PredictionCreateParams = {** params , "version" : version_id }
6468 prediction = client .predictions .create (file_encoding_strategy = file_encoding_strategy , ** params_with_version )
@@ -80,7 +84,6 @@ def stream(
8084 if not prediction .urls or not prediction .urls .stream :
8185 raise ValueError ("Model does not support streaming. The prediction URLs do not include a stream endpoint." )
8286
83- # Make SSE request to the stream URL
8487 stream_url = prediction .urls .stream
8588
8689 with client ._client .stream (
@@ -128,19 +131,9 @@ async def async_stream(
128131 ValueError: If the reference format is invalid
129132 ReplicateError: If the prediction fails or streaming is not available
130133 """
131- # Resolve ref to its components
132- try :
133- version , owner , name , version_id = resolve_reference (ref )
134- except ValueError :
135- # If resolution fails, treat it as a version ID if it's a string
136- if isinstance (ref , str ):
137- version_id = ref
138- owner = name = None
139- else :
140- raise
134+ version , owner , name , version_id = _resolve_reference (ref )
141135
142136 # Create prediction
143- prediction = None
144137 if version_id is not None :
145138 params_with_version : PredictionCreateParams = {** params , "version" : version_id }
146139 prediction = await client .predictions .create (
@@ -166,7 +159,6 @@ async def async_stream(
166159 if not prediction .urls or not prediction .urls .stream :
167160 raise ValueError ("Model does not support streaming. The prediction URLs do not include a stream endpoint." )
168161
169- # Make SSE request to the stream URL
170162 stream_url = prediction .urls .stream
171163
172164 async with client ._client .stream (
0 commit comments