11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Dict , Union , Iterable
3+ from typing import TYPE_CHECKING , Dict , Union , Iterable , Optional
44from typing_extensions import Unpack
55
66from replicate .types .prediction_create_params import PredictionCreateParamsWithoutVersion
77
88from ..types import PredictionOutput , PredictionCreateParams
99from .._types import NOT_GIVEN , NotGiven
1010from .._utils import is_given
11- from .. _client import ReplicateClient , AsyncReplicateClient
11+ from ._models import Model , Version , ModelVersionIdentifier , resolve_reference
1212from .._exceptions import ModelError
1313
1414if TYPE_CHECKING :
1515 from ._files import FileOutput
16+ from .._client import ReplicateClient , AsyncReplicateClient
1617
1718
1819def run (
19- client : ReplicateClient ,
20- ref : str ,
21- # TODO: support these types
22- # ref: Union["Model", "Version", "ModelVersionIdentifier", str],
20+ client : "ReplicateClient" ,
21+ ref : Union [Model , Version , ModelVersionIdentifier , str ],
2322 * ,
2423 wait : Union [int , bool , NotGiven ] = NOT_GIVEN ,
25- # use_file_output : Optional[bool] = True,
24+ _use_file_output : Optional [bool ] = True ,
2625 ** params : Unpack [PredictionCreateParamsWithoutVersion ],
2726) -> PredictionOutput | FileOutput | Iterable [FileOutput ] | Dict [str , FileOutput ]:
27+ """
28+ Run a model prediction.
29+
30+ Args:
31+ client: The ReplicateClient instance to use for API calls
32+ ref: Reference to the model or version to run. Can be:
33+ - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
34+ - A string with owner/name format (e.g. "replicate/hello-world")
35+ - A string with owner/name/version format (e.g. "replicate/hello-world/5c7d5dc6...")
36+ - A Model instance with owner and name attributes
37+ - A Version instance with id attribute
38+ - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
39+ input: Dictionary of input parameters for the model
40+ wait: If True (default), wait for the prediction to complete. If False, return immediately.
41+ If an integer, wait up to that many seconds.
42+ _use_file_output: If True (default), convert output URLs to FileOutput objects
43+ **params: Additional parameters to pass to the prediction creation endpoint
44+
45+ Returns:
46+ The prediction output, which could be a basic type (str, int, etc.), a FileOutput object,
47+ a list of FileOutput objects, or a dictionary of FileOutput objects, depending on what
48+ the model returns.
49+
50+ Raises:
51+ ModelError: If the model run fails
52+ ValueError: If the reference format is invalid
53+ TypeError: If both wait and prefer parameters are provided
54+ """
2855 from ._files import transform_output
2956
3057 if is_given (wait ) and "prefer" in params :
@@ -41,9 +68,27 @@ def run(
4168 else :
4269 params .setdefault ("prefer" , f"wait={ wait } " )
4370
44- # TODO: support more ref types
45- params_with_version : PredictionCreateParams = {** params , "version" : ref }
46- prediction = client .predictions .create (** params_with_version )
71+ # Resolve ref to its components
72+ _version , owner , name , version_id = resolve_reference (ref )
73+
74+ prediction = None
75+ if version_id is not None :
76+ # Create prediction with the specific version ID
77+ params_with_version : PredictionCreateParams = {** params , "version" : version_id }
78+ prediction = client .predictions .create (** params_with_version )
79+ elif owner and name :
80+ # Create prediction via models resource with owner/name
81+ prediction = client .models .predictions .create (model_owner = owner , model_name = name , ** params )
82+ else :
83+ # If ref is a string but doesn't match expected patterns
84+ if isinstance (ref , str ):
85+ params_with_version = {** params , "version" : ref }
86+ prediction = client .predictions .create (** params_with_version )
87+ else :
88+ raise ValueError (
89+ f"Invalid reference format: { ref } . Expected a model name ('owner/name'), "
90+ "a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
91+ )
4792
4893 # Currently the "Prefer: wait" interface will return a prediction with a status
4994 # of "processing" rather than a terminal state because it returns before the
@@ -68,15 +113,41 @@ def run(
68113
69114
70115async def async_run (
71- client : AsyncReplicateClient ,
72- ref : str ,
73- # TODO: support these types
74- # ref: Union["Model", "Version", "ModelVersionIdentifier", str],
116+ client : "AsyncReplicateClient" ,
117+ ref : Union [Model , Version , ModelVersionIdentifier , str ],
75118 * ,
76119 wait : Union [int , bool , NotGiven ] = NOT_GIVEN ,
77- # use_file_output : Optional[bool] = True,
120+ _use_file_output : Optional [bool ] = True ,
78121 ** params : Unpack [PredictionCreateParamsWithoutVersion ],
79122) -> PredictionOutput | FileOutput | Iterable [FileOutput ] | Dict [str , FileOutput ]:
123+ """
124+ Run a model prediction asynchronously.
125+
126+ Args:
127+ client: The AsyncReplicateClient instance to use for API calls
128+ ref: Reference to the model or version to run. Can be:
129+ - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
130+ - A string with owner/name format (e.g. "replicate/hello-world")
131+ - A string with owner/name/version format (e.g. "replicate/hello-world/5c7d5dc6...")
132+ - A Model instance with owner and name attributes
133+ - A Version instance with id attribute
134+ - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
135+ input: Dictionary of input parameters for the model
136+ wait: If True (default), wait for the prediction to complete. If False, return immediately.
137+ If an integer, wait up to that many seconds.
138+ _use_file_output: If True (default), convert output URLs to AsyncFileOutput objects
139+ **params: Additional parameters to pass to the prediction creation endpoint
140+
141+ Returns:
142+ The prediction output, which could be a basic type (str, int, etc.), an AsyncFileOutput object,
143+ a list of AsyncFileOutput objects, or a dictionary of AsyncFileOutput objects, depending on what
144+ the model returns.
145+
146+ Raises:
147+ ModelError: If the model run fails
148+ ValueError: If the reference format is invalid
149+ TypeError: If both wait and prefer parameters are provided
150+ """
80151 from ._files import transform_output
81152
82153 if is_given (wait ) and "prefer" in params :
@@ -93,9 +164,27 @@ async def async_run(
93164 else :
94165 params .setdefault ("prefer" , f"wait={ wait } " )
95166
96- # TODO: support more ref types
97- params_with_version : PredictionCreateParams = {** params , "version" : ref }
98- prediction = await client .predictions .create (** params_with_version )
167+ # Resolve ref to its components
168+ _version , owner , name , version_id = resolve_reference (ref )
169+
170+ prediction = None
171+ if version_id is not None :
172+ # Create prediction with the specific version ID
173+ params_with_version : PredictionCreateParams = {** params , "version" : version_id }
174+ prediction = await client .predictions .create (** params_with_version )
175+ elif owner and name :
176+ # Create prediction via models resource with owner/name
177+ prediction = await client .models .predictions .create (model_owner = owner , model_name = name , ** params )
178+ else :
179+ # If ref is a string but doesn't match expected patterns
180+ if isinstance (ref , str ):
181+ params_with_version = {** params , "version" : ref }
182+ prediction = await client .predictions .create (** params_with_version )
183+ else :
184+ raise ValueError (
185+ f"Invalid reference format: { ref } . Expected a model name ('owner/name'), "
186+ "a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
187+ )
99188
100189 # Currently the "Prefer: wait" interface will return a prediction with a status
101190 # of "processing" rather than a terminal state because it returns before the
0 commit comments