1- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union , overload
2-
3- from typing_extensions import Unpack
1+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union
42
53from replicate .base_model import BaseModel
64from replicate .collection import Collection
75from replicate .files import upload_file
86from replicate .json import encode_json
9- from replicate .prediction import Prediction , PredictionCollection
7+ from replicate .prediction import Prediction
108
119if TYPE_CHECKING :
1210 from replicate .client import Client
@@ -17,6 +15,8 @@ class Deployment(BaseModel):
1715 A deployment of a model hosted on Replicate.
1816 """
1917
18+ _collection : "DeploymentCollection"
19+
2020 username : str
2121 """
2222 The name of the user or organization that owns the deployment.
@@ -43,15 +43,6 @@ class DeploymentCollection(Collection):
4343
4444 model = Deployment
4545
46- def list (self ) -> List [Deployment ]:
47- """
48- List deployments.
49-
50- Raises:
51- NotImplementedError: This method is not implemented.
52- """
53- raise NotImplementedError ()
54-
5546 def get (self , name : str ) -> Deployment :
5647 """
5748 Get a deployment by name.
@@ -65,89 +56,35 @@ def get(self, name: str) -> Deployment:
6556 # TODO: fetch model from server
6657 # TODO: support permanent IDs
6758 username , name = name .split ("/" )
68- return self .prepare_model ({"username" : username , "name" : name })
69-
70- def create (
71- self ,
72- * args ,
73- ** kwargs ,
74- ) -> Deployment :
75- """
76- Create a deployment.
77-
78- Raises:
79- NotImplementedError: This method is not implemented.
80- """
81- raise NotImplementedError ()
59+ return self ._prepare_model ({"username" : username , "name" : name })
8260
83- def prepare_model (self , attrs : Union [Deployment , Dict ]) -> Deployment :
61+ def _prepare_model (self , attrs : Union [Deployment , Dict ]) -> Deployment :
8462 if isinstance (attrs , BaseModel ):
8563 attrs .id = f"{ attrs .username } /{ attrs .name } "
8664 elif isinstance (attrs , dict ):
8765 attrs ["id" ] = f"{ attrs ['username' ]} /{ attrs ['name' ]} "
88- return super ().prepare_model (attrs )
66+ return super ()._prepare_model (attrs )
8967
9068
9169class DeploymentPredictionCollection (Collection ):
70+ """
71+ Namespace for operations related to predictions in a deployment.
72+ """
73+
9274 model = Prediction
9375
9476 def __init__ (self , client : "Client" , deployment : Deployment ) -> None :
9577 super ().__init__ (client = client )
9678 self ._deployment = deployment
9779
98- def list (self ) -> List [Prediction ]:
99- """
100- List predictions in a deployment.
101-
102- Raises:
103- NotImplementedError: This method is not implemented.
104- """
105- raise NotImplementedError ()
106-
107- def get (self , id : str ) -> Prediction :
108- """
109- Get a prediction by ID.
110-
111- Args:
112- id: The ID of the prediction.
113- Returns:
114- Prediction: The prediction object.
115- """
116-
117- resp = self ._client ._request ("GET" , f"/v1/predictions/{ id } " )
118- obj = resp .json ()
119- # HACK: resolve this? make it lazy somehow?
120- del obj ["version" ]
121- return self .prepare_model (obj )
122-
123- @overload
124- def create ( # pylint: disable=arguments-differ disable=too-many-arguments
80+ def create (
12581 self ,
12682 input : Dict [str , Any ],
12783 * ,
12884 webhook : Optional [str ] = None ,
12985 webhook_completed : Optional [str ] = None ,
13086 webhook_events_filter : Optional [List [str ]] = None ,
13187 stream : Optional [bool ] = None ,
132- ) -> Prediction :
133- ...
134-
135- @overload
136- def create ( # pylint: disable=arguments-differ disable=too-many-arguments
137- self ,
138- * ,
139- input : Dict [str , Any ],
140- webhook : Optional [str ] = None ,
141- webhook_completed : Optional [str ] = None ,
142- webhook_events_filter : Optional [List [str ]] = None ,
143- stream : Optional [bool ] = None ,
144- ) -> Prediction :
145- ...
146-
147- def create (
148- self ,
149- * args ,
150- ** kwargs : Unpack [PredictionCollection .CreateParams ], # type: ignore[misc]
15188 ) -> Prediction :
15289 """
15390 Create a new prediction with the deployment.
@@ -163,20 +100,21 @@ def create(
163100 Prediction: The created prediction object.
164101 """
165102
166- input = args [0 ] if len (args ) > 0 else kwargs .get ("input" )
167- if input is None :
168- raise ValueError (
169- "An input must be provided as a positional or keyword argument."
170- )
171-
172103 body = {
173104 "input" : encode_json (input , upload_file = upload_file ),
174105 }
175106
176- for key in ["webhook" , "webhook_completed" , "webhook_events_filter" , "stream" ]:
177- value = kwargs .get (key )
178- if value is not None :
179- body [key ] = value
107+ if webhook is not None :
108+ body ["webhook" ] = webhook
109+
110+ if webhook_completed is not None :
111+ body ["webhook_completed" ] = webhook_completed
112+
113+ if webhook_events_filter is not None :
114+ body ["webhook_events_filter" ] = webhook_events_filter
115+
116+ if stream is not None :
117+ body ["stream" ] = stream
180118
181119 resp = self ._client ._request (
182120 "POST" ,
@@ -186,4 +124,4 @@ def create(
186124 obj = resp .json ()
187125 obj ["deployment" ] = self ._deployment
188126 del obj ["version" ]
189- return self .prepare_model (obj )
127+ return self ._prepare_model (obj )
0 commit comments