33from __future__ import annotations
44
55import os
6- from typing import TYPE_CHECKING , Any , Union , Mapping
7- from typing_extensions import Self , override
6+ from typing import TYPE_CHECKING , Any , Union , Mapping , Optional
7+ from typing_extensions import Self , Unpack , override
88
99import httpx
1010
11+ from replicate .lib ._files import FileEncodingStrategy
12+ from replicate .lib ._predictions import Model , Version , ModelVersionIdentifier
13+ from replicate .types .prediction_create_params import PredictionCreateParamsWithoutVersion
14+
1115from . import _exceptions
1216from ._qs import Querystring
1317from ._types import (
@@ -171,6 +175,10 @@ def with_raw_response(self) -> ReplicateWithRawResponse:
171175 def with_streaming_response (self ) -> ReplicateWithStreamedResponse :
172176 return ReplicateWithStreamedResponse (self )
173177
178+ @cached_property
179+ def poll_interval (self ) -> float :
180+ return float (os .environ .get ("REPLICATE_POLL_INTERVAL" , "0.5" ))
181+
174182 @property
175183 @override
176184 def qs (self ) -> Querystring :
@@ -191,6 +199,54 @@ def default_headers(self) -> dict[str, str | Omit]:
191199 ** self ._custom_headers ,
192200 }
193201
202+ def run (
203+ self ,
204+ ref : Union [Model , Version , ModelVersionIdentifier , str ],
205+ * ,
206+ file_encoding_strategy : Optional ["FileEncodingStrategy" ] = None ,
207+ use_file_output : bool = True ,
208+ wait : Union [int , bool , NotGiven ] = NOT_GIVEN ,
209+ ** params : Unpack [PredictionCreateParamsWithoutVersion ],
210+ ) -> Any :
211+ """
212+ Run a model prediction.
213+
214+ Args:
215+ ref: Reference to the model or version to run. Can be:
216+ - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
217+ - A string with owner/name format (e.g. "replicate/hello-world")
218+ - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
219+ - A Model instance with owner and name attributes
220+ - A Version instance with id attribute
221+ - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
222+ file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
223+ use_file_output: If True (default), convert output URLs to FileOutput objects
224+ wait: If True (default), wait for the prediction to complete. If False, return immediately.
225+ If an integer, wait up to that many seconds.
226+ **params: Additional parameters to pass to the prediction creation endpoint including
227+ the required "input" dictionary with model-specific parameters
228+
229+ Returns:
230+ The prediction output, which could be a basic type (str, int, etc.), a FileOutput object,
231+ a list of FileOutput objects, or a dictionary of FileOutput objects, depending on what
232+ the model returns.
233+
234+ Raises:
235+ ModelError: If the model run fails
236+ ValueError: If the reference format is invalid
237+ TypeError: If both wait and prefer parameters are provided
238+ """
239+ from .lib ._predictions import run
240+
241+ return run (
242+ self ,
243+ ref ,
244+ wait = wait ,
245+ use_file_output = use_file_output ,
246+ file_encoding_strategy = file_encoding_strategy ,
247+ ** params ,
248+ )
249+
194250 def copy (
195251 self ,
196252 * ,
@@ -393,6 +449,10 @@ def with_raw_response(self) -> AsyncReplicateWithRawResponse:
393449 def with_streaming_response (self ) -> AsyncReplicateWithStreamedResponse :
394450 return AsyncReplicateWithStreamedResponse (self )
395451
452+ @cached_property
453+ def poll_interval (self ) -> float :
454+ return float (os .environ .get ("REPLICATE_POLL_INTERVAL" , "0.5" ))
455+
396456 @property
397457 @override
398458 def qs (self ) -> Querystring :
@@ -413,6 +473,54 @@ def default_headers(self) -> dict[str, str | Omit]:
413473 ** self ._custom_headers ,
414474 }
415475
476+ async def run (
477+ self ,
478+ ref : Union [Model , Version , ModelVersionIdentifier , str ],
479+ * ,
480+ use_file_output : bool = True ,
481+ file_encoding_strategy : Optional ["FileEncodingStrategy" ] = None ,
482+ wait : Union [int , bool , NotGiven ] = NOT_GIVEN ,
483+ ** params : Unpack [PredictionCreateParamsWithoutVersion ],
484+ ) -> Any :
485+ """
486+ Run a model prediction asynchronously.
487+
488+ Args:
489+ ref: Reference to the model or version to run. Can be:
490+ - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
491+ - A string with owner/name format (e.g. "replicate/hello-world")
492+ - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
493+ - A Model instance with owner and name attributes
494+ - A Version instance with id attribute
495+ - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
496+ use_file_output: If True (default), convert output URLs to AsyncFileOutput objects
497+ file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
498+ wait: If True (default), wait for the prediction to complete. If False, return immediately.
499+ If an integer, wait up to that many seconds.
500+ **params: Additional parameters to pass to the prediction creation endpoint including
501+ the required "input" dictionary with model-specific parameters
502+
503+ Returns:
504+ The prediction output, which could be a basic type (str, int, etc.), an AsyncFileOutput object,
505+ a list of AsyncFileOutput objects, or a dictionary of AsyncFileOutput objects, depending on what
506+ the model returns.
507+
508+ Raises:
509+ ModelError: If the model run fails
510+ ValueError: If the reference format is invalid
511+ TypeError: If both wait and prefer parameters are provided
512+ """
513+ from .lib ._predictions import async_run
514+
515+ return await async_run (
516+ self ,
517+ ref ,
518+ wait = wait ,
519+ use_file_output = use_file_output ,
520+ file_encoding_strategy = file_encoding_strategy ,
521+ ** params ,
522+ )
523+
416524 def copy (
417525 self ,
418526 * ,
0 commit comments