1111 List ,
1212 Literal ,
1313 Optional ,
14+ Tuple ,
1415 Union ,
16+ overload ,
1517)
1618
1719from typing_extensions import NotRequired , TypedDict , Unpack
3133
3234if TYPE_CHECKING :
3335 from replicate .client import Client
36+ from replicate .deployment import Deployment
37+ from replicate .model import Model
3438 from replicate .stream import ServerSentEvent
3539
3640
@@ -380,21 +384,82 @@ class CreatePredictionParams(TypedDict):
380384 stream : NotRequired [bool ]
381385 """Enable streaming of prediction output."""
382386
387+ @overload
383388 def create (
384389 self ,
385390 version : Union [Version , str ],
386391 input : Optional [Dict [str , Any ]],
387392 ** params : Unpack ["Predictions.CreatePredictionParams" ],
393+ ) -> Prediction : ...
394+
395+ @overload
396+ def create (
397+ self ,
398+ * ,
399+ model : Union [str , Tuple [str , str ], "Model" ],
400+ input : Optional [Dict [str , Any ]],
401+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
402+ ) -> Prediction : ...
403+
404+ @overload
405+ def create (
406+ self ,
407+ * ,
408+ deployment : Union [str , Tuple [str , str ], "Deployment" ],
409+ input : Optional [Dict [str , Any ]],
410+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
411+ ) -> Prediction : ...
412+
413+ def create ( # type: ignore
414+ self ,
415+ * args ,
416+ model : Optional [Union [str , Tuple [str , str ], "Model" ]] = None ,
417+ version : Optional [Union [Version , str , "Version" ]] = None ,
418+ deployment : Optional [Union [str , Tuple [str , str ], "Deployment" ]] = None ,
419+ input : Optional [Dict [str , Any ]] = None ,
420+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
388421 ) -> Prediction :
389422 """
390- Create a new prediction for the specified model version.
423+ Create a new prediction for the specified model, version, or deployment .
391424 """
392425
426+ if args :
427+ version = args [0 ] if len (args ) > 0 else None
428+ input = args [1 ] if len (args ) > 1 else input
429+
430+ if sum (bool (x ) for x in [model , version , deployment ]) != 1 :
431+ raise ValueError (
432+ "Exactly one of 'model', 'version', or 'deployment' must be specified."
433+ )
434+
435+ if model is not None :
436+ from replicate .model import ( # pylint: disable=import-outside-toplevel
437+ Models ,
438+ )
439+
440+ return Models (self ._client ).predictions .create (
441+ model = model ,
442+ input = input or {},
443+ ** params ,
444+ )
445+
446+ if deployment is not None :
447+ from replicate .deployment import ( # pylint: disable=import-outside-toplevel
448+ Deployments ,
449+ )
450+
451+ return Deployments (self ._client ).predictions .create (
452+ deployment = deployment ,
453+ input = input or {},
454+ ** params ,
455+ )
456+
393457 body = _create_prediction_body (
394458 version ,
395459 input ,
396460 ** params ,
397461 )
462+
398463 resp = self ._client ._request (
399464 "POST" ,
400465 "/v1/predictions" ,
@@ -403,21 +468,82 @@ def create(
403468
404469 return _json_to_prediction (self ._client , resp .json ())
405470
471+ @overload
406472 async def async_create (
407473 self ,
408474 version : Union [Version , str ],
409475 input : Optional [Dict [str , Any ]],
410476 ** params : Unpack ["Predictions.CreatePredictionParams" ],
477+ ) -> Prediction : ...
478+
479+ @overload
480+ async def async_create (
481+ self ,
482+ * ,
483+ model : Union [str , Tuple [str , str ], "Model" ],
484+ input : Optional [Dict [str , Any ]],
485+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
486+ ) -> Prediction : ...
487+
488+ @overload
489+ async def async_create (
490+ self ,
491+ * ,
492+ deployment : Union [str , Tuple [str , str ], "Deployment" ],
493+ input : Optional [Dict [str , Any ]],
494+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
495+ ) -> Prediction : ...
496+
497+ async def async_create ( # type: ignore
498+ self ,
499+ * args ,
500+ model : Optional [Union [str , Tuple [str , str ], "Model" ]] = None ,
501+ version : Optional [Union [Version , str , "Version" ]] = None ,
502+ deployment : Optional [Union [str , Tuple [str , str ], "Deployment" ]] = None ,
503+ input : Optional [Dict [str , Any ]] = None ,
504+ ** params : Unpack ["Predictions.CreatePredictionParams" ],
411505 ) -> Prediction :
412506 """
413- Create a new prediction for the specified model version.
507+ Create a new prediction for the specified model, version, or deployment .
414508 """
415509
510+ if args :
511+ version = args [0 ] if len (args ) > 0 else None
512+ input = args [1 ] if len (args ) > 1 else input
513+
514+ if sum (bool (x ) for x in [model , version , deployment ]) != 1 :
515+ raise ValueError (
516+ "Exactly one of 'model', 'version', or 'deployment' must be specified."
517+ )
518+
519+ if model is not None :
520+ from replicate .model import ( # pylint: disable=import-outside-toplevel
521+ Models ,
522+ )
523+
524+ return await Models (self ._client ).predictions .async_create (
525+ model = model ,
526+ input = input or {},
527+ ** params ,
528+ )
529+
530+ if deployment is not None :
531+ from replicate .deployment import ( # pylint: disable=import-outside-toplevel
532+ Deployments ,
533+ )
534+
535+ return await Deployments (self ._client ).predictions .async_create (
536+ deployment = deployment ,
537+ input = input or {},
538+ ** params ,
539+ )
540+
416541 body = _create_prediction_body (
417542 version ,
418543 input ,
419544 ** params ,
420545 )
546+
421547 resp = await self ._client ._async_request (
422548 "POST" ,
423549 "/v1/predictions" ,
0 commit comments