1- import asyncio
21from typing import TYPE_CHECKING , Any , Dict , Iterator , List , Optional , Union
32
43from typing_extensions import Unpack
54
5+ from replicate import identifier
66from replicate .exceptions import ModelError
7- from replicate .identifier import ModelVersionIdentifier
7+ from replicate .model import Model
8+ from replicate .prediction import Prediction
89from replicate .schema import make_schema_backwards_compatible
9- from replicate .version import Versions
10+ from replicate .version import Version , Versions
1011
1112if TYPE_CHECKING :
1213 from replicate .client import Client
14+ from replicate .identifier import ModelVersionIdentifier
1315 from replicate .prediction import Predictions
1416
1517
1618def run (
1719 client : "Client" ,
18- ref : str ,
20+ ref : Union [ "Model" , "Version" , "ModelVersionIdentifier" , str ] ,
1921 input : Optional [Dict [str , Any ]] = None ,
2022 ** params : Unpack ["Predictions.CreatePredictionParams" ],
2123) -> Union [Any , Iterator [Any ]]: # noqa: ANN401
2224 """
2325 Run a model and wait for its output.
2426 """
2527
26- owner , name , version_id = ModelVersionIdentifier . parse (ref )
28+ version , owner , name , version_id = identifier . _resolve (ref )
2729
28- prediction = client .predictions .create (
29- version = version_id , input = input or {}, ** params
30- )
30+ if version_id is not None :
31+ prediction = client .predictions .create (
32+ version = version_id , input = input or {}, ** params
33+ )
34+ elif owner and name :
35+ prediction = client .models .predictions .create (
36+ model = (owner , name ), input = input or {}, ** params
37+ )
38+ else :
39+ raise ValueError (
40+ f"Invalid argument: { ref } . Expected model, version, or reference in the format owner/name or owner/name:version"
41+ )
3142
32- if owner and name :
43+ if not version and ( owner and name and version_id ) :
3344 version = Versions (client , model = (owner , name )).get (version_id )
3445
35- # Return an iterator of the output
36- schema = make_schema_backwards_compatible (
37- version .openapi_schema , version .cog_version
38- )
39- output = schema ["components" ]["schemas" ]["Output" ]
40- if (
41- output .get ("type" ) == "array"
42- and output .get ("x-cog-array-type" ) == "iterator"
43- ):
44- return prediction .output_iterator ()
46+ if version and (iterator := _make_output_iterator (version , prediction )):
47+ return iterator
4548
4649 prediction .wait ()
4750
@@ -53,42 +56,54 @@ def run(
5356
5457async def async_run (
5558 client : "Client" ,
56- ref : str ,
59+ ref : Union [ "Model" , "Version" , "ModelVersionIdentifier" , str ] ,
5760 input : Optional [Dict [str , Any ]] = None ,
5861 ** params : Unpack ["Predictions.CreatePredictionParams" ],
5962) -> Union [Any , Iterator [Any ]]: # noqa: ANN401
6063 """
6164 Run a model and wait for its output asynchronously.
6265 """
6366
64- owner , name , version_id = ModelVersionIdentifier . parse (ref )
67+ version , owner , name , version_id = identifier . _resolve (ref )
6568
66- prediction = await client .predictions .async_create (
67- version = version_id , input = input or {}, ** params
68- )
69+ if version or version_id :
70+ prediction = await client .predictions .async_create (
71+ version = (version or version_id ), input = input or {}, ** params
72+ )
73+ elif owner and name :
74+ prediction = await client .models .predictions .async_create (
75+ model = (owner , name ), input = input or {}, ** params
76+ )
77+ else :
78+ raise ValueError (
79+ f"Invalid argument: { ref } . Expected model, version, or reference in the format owner/name or owner/name:version"
80+ )
6981
70- if owner and name :
71- version = await Versions (client , model = (owner , name )).async_get (version_id )
82+ if not version and ( owner and name and version_id ) :
83+ version = Versions (client , model = (owner , name )).get (version_id )
7284
73- # Return an iterator of the output
74- schema = make_schema_backwards_compatible (
75- version .openapi_schema , version .cog_version
76- )
77- output = schema ["components" ]["schemas" ]["Output" ]
78- if (
79- output .get ("type" ) == "array"
80- and output .get ("x-cog-array-type" ) == "iterator"
81- ):
82- return prediction .output_iterator ()
85+ if version and (iterator := _make_output_iterator (version , prediction )):
86+ return iterator
8387
84- while prediction .status not in ["succeeded" , "failed" , "canceled" ]:
85- await asyncio .sleep (client .poll_interval )
86- prediction = await client .predictions .async_get (prediction .id )
88+ prediction .wait ()
8789
8890 if prediction .status == "failed" :
8991 raise ModelError (prediction .error )
9092
9193 return prediction .output
9294
9395
96+ def _make_output_iterator (
97+ version : Version , prediction : Prediction
98+ ) -> Optional [Iterator [Any ]]:
99+ schema = make_schema_backwards_compatible (
100+ version .openapi_schema , version .cog_version
101+ )
102+ output = schema ["components" ]["schemas" ]["Output" ]
103+ if output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator" :
104+ return prediction .output_iterator ()
105+
106+ return None
107+
108+
94109__all__ : List = []
0 commit comments