
from replicate import ReplicateClient, AsyncReplicateClient
from replicate.lib._files import FileOutput, AsyncFileOutput
-from replicate._exceptions import ModelError
+from replicate._exceptions import ModelError, NotFoundError, BadRequestError

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
bearer_token = "My Bearer Token"


# Mock prediction data for testing
def create_mock_prediction(
-    status: str = "succeeded", output: Any = "test output", error: Optional[str] = None
+    status: str = "succeeded",
+    output: Any = "test output",
+    error: Optional[str] = None,
+    logs: Optional[str] = None,
+    urls: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
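+    # A None default keeps the parameter safe between calls; a dict literal in the
+    # signature would be shared (and mutable) across invocations.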
+    if urls is None:
+        urls = {
+            "get": "https://api.replicate.com/v1/predictions/test_prediction_id",
+            "cancel": "https://api.replicate.com/v1/predictions/test_prediction_id/cancel",
+        }
+
    return {
        "id": "test_prediction_id",
        "version": "test_version",
        "status": status,
        "input": {"prompt": "test prompt"},
        "output": output,
        "error": error,
+        "logs": logs,
        "created_at": "2023-01-01T00:00:00Z",
        "started_at": "2023-01-01T00:00:01Z",
        "completed_at": "2023-01-01T00:00:02Z" if status in ["succeeded", "failed"] else None,
-        "urls": {
-            "get": "https://api.replicate.com/v1/predictions/test_prediction_id",
-            "cancel": "https://api.replicate.com/v1/predictions/test_prediction_id/cancel",
-        },
+        "urls": urls,
        "model": "test-model",
        "data_removed": False,
    }
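
# Example usage (illustrative): create_mock_prediction(status="processing", output=None)
# returns an in-flight payload whose completed_at is None; the polling tests below
# rely on this shape.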
@@ -45,7 +53,6 @@ class TestRun:
    @pytest.mark.respx(base_url=base_url)
    def test_run_basic(self, respx_mock: MockRouter) -> None:
        """Test basic model run functionality."""
-        # Mock the prediction creation
        respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))

        output: Any = self.client.run("some-model-ref", input={"prompt": "test prompt"})
@@ -55,7 +62,6 @@ def test_run_basic(self, respx_mock: MockRouter) -> None:
    @pytest.mark.respx(base_url=base_url)
    def test_run_with_wait_true(self, respx_mock: MockRouter) -> None:
        """Test run with wait=True parameter."""
-        # Mock the prediction creation
        respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))

        output: Any = self.client.run("some-model-ref", wait=True, input={"prompt": "test prompt"})
@@ -65,7 +71,6 @@ def test_run_with_wait_true(self, respx_mock: MockRouter) -> None:
    @pytest.mark.respx(base_url=base_url)
    def test_run_with_wait_int(self, respx_mock: MockRouter) -> None:
        """Test run with wait as an integer value."""
-        # Mock the prediction creation
        respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))

        output: Any = self.client.run("some-model-ref", wait=10, input={"prompt": "test prompt"})
@@ -167,14 +172,78 @@ def test_run_with_prefer_conflict(self) -> None:
        with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
            self.client.run("some-model-ref", wait=True, prefer="nowait", input={"prompt": "test"})

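+    # The wait parameter is presumably translated into the Prefer header, which is
+    # why mixing wait and prefer above raises a TypeError.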
+    @pytest.mark.respx(base_url=base_url)
+    def test_run_with_iterator(self, respx_mock: MockRouter) -> None:
+        """Test run with an iterator-style (list) output."""
+        # Mock a prediction whose output is a list of chunks
+        output_iterator = ["chunk1", "chunk2", "chunk3"]
+        respx_mock.post("/predictions").mock(
+            return_value=httpx.Response(201, json=create_mock_prediction(output=output_iterator))
+        )
+
+        output = self.client.run("some-model-ref", input={"prompt": "generate iterator"})
+
+        assert isinstance(output, list)
+        assert len(output) == 3
+        assert output == output_iterator
+
+    @pytest.mark.respx(base_url=base_url)
+    def test_run_with_invalid_identifier(self, respx_mock: MockRouter) -> None:
+        """Test run with an invalid model identifier."""
+        # Mock a 404 response for an invalid model identifier
+        respx_mock.post("/predictions").mock(return_value=httpx.Response(404, json={"detail": "Model not found"}))
+
+        with pytest.raises(NotFoundError):
+            self.client.run("invalid-model-ref", input={"prompt": "test prompt"})
+
+    @pytest.mark.respx(base_url=base_url)
+    def test_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
+        """Test run with an invalid Cog version."""
+        # Mock an error response for an invalid Cog version
+        respx_mock.post("/predictions").mock(return_value=httpx.Response(400, json={"detail": "Invalid Cog version"}))
+
+        with pytest.raises(BadRequestError):
+            self.client.run("model-with-invalid-cog", input={"prompt": "test prompt"})
+
+    @pytest.mark.respx(base_url=base_url)
+    def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
+        """Test run with file output iterator."""
+        # Mock URLs for file outputs
+        file_urls = [
+            "https://replicate.delivery/output1.png",
+            "https://replicate.delivery/output2.png",
+            "https://replicate.delivery/output3.png",
+        ]
+
+        # Initial response with processing status and no output
+        respx_mock.post("/predictions").mock(
+            return_value=httpx.Response(201, json=create_mock_prediction(status="processing", output=None))
+        )
+
+        # respx treats a re-registered identical route as a replacement rather than
+        # a queue, so both poll responses go on one route via side_effect:
+        # the first poll is still processing, the second succeeds with file URLs.
+        respx_mock.get("/predictions/test_prediction_id").mock(
+            side_effect=[
+                httpx.Response(200, json=create_mock_prediction(status="processing", output=None)),
+                httpx.Response(200, json=create_mock_prediction(output=file_urls)),
+            ]
+        )
+
+        output = self.client.run("some-model-ref", input={"prompt": "generate file iterator"})
+
+        assert isinstance(output, list)
+        assert len(output) == 3
+        assert all(isinstance(item, FileOutput) for item in output)
+        assert [item.url for item in output] == file_urls
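+        # Only the .url attribute is asserted; the test never downloads the file
+        # bodies themselves.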
+

class TestAsyncRun:
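    # _strict_response_validation=True is assumed to make the client raise when a
    # mocked payload does not match the declared response schema.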
    client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
        """Test basic async model run functionality."""
-        # Mock the prediction creation
        respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))

        output: Any = await self.client.run("some-model-ref", input={"prompt": "test prompt"})
@@ -184,7 +253,6 @@ async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
    @pytest.mark.respx(base_url=base_url)
    async def test_async_run_with_wait_true(self, respx_mock: MockRouter) -> None:
        """Test async run with wait=True parameter."""
-        # Mock the prediction creation
        respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))

        output: Any = await self.client.run("some-model-ref", wait=True, input={"prompt": "test prompt"})
@@ -194,7 +262,6 @@ async def test_async_run_with_wait_true(self, respx_mock: MockRouter) -> None:
    @pytest.mark.respx(base_url=base_url)
    async def test_async_run_with_wait_int(self, respx_mock: MockRouter) -> None:
        """Test async run with wait as an integer value."""
-        # Mock the prediction creation
        respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))

        output: Any = await self.client.run("some-model-ref", wait=10, input={"prompt": "test prompt"})
@@ -299,3 +366,68 @@ async def test_async_run_with_prefer_conflict(self) -> None:
299366 """Test async run with conflicting wait and prefer parameters."""
300367 with pytest .raises (TypeError , match = "cannot mix and match prefer and wait" ):
301368 await self .client .run ("some-model-ref" , wait = True , prefer = "nowait" , input = {"prompt" : "test" })
369+
+    @pytest.mark.respx(base_url=base_url)
+    async def test_async_run_with_iterator(self, respx_mock: MockRouter) -> None:
+        """Test async run with an iterator-style (list) output."""
+        # Mock a prediction whose output is a list of chunks
+        output_iterator = ["chunk1", "chunk2", "chunk3"]
+        respx_mock.post("/predictions").mock(
+            return_value=httpx.Response(201, json=create_mock_prediction(output=output_iterator))
+        )
+
+        output = await self.client.run("some-model-ref", input={"prompt": "generate iterator"})
+
+        assert isinstance(output, list)
+        assert len(output) == 3
+        assert output == output_iterator
+
+    @pytest.mark.respx(base_url=base_url)
+    async def test_async_run_with_invalid_identifier(self, respx_mock: MockRouter) -> None:
+        """Test async run with an invalid model identifier."""
+        # Mock a 404 response for an invalid model identifier
+        respx_mock.post("/predictions").mock(return_value=httpx.Response(404, json={"detail": "Model not found"}))
+
+        with pytest.raises(NotFoundError):
+            await self.client.run("invalid-model-ref", input={"prompt": "test prompt"})
+
+    @pytest.mark.respx(base_url=base_url)
+    async def test_async_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
+        """Test async run with an invalid Cog version."""
+        # Mock an error response for an invalid Cog version
+        respx_mock.post("/predictions").mock(return_value=httpx.Response(400, json={"detail": "Invalid Cog version"}))
+
+        with pytest.raises(BadRequestError):
+            await self.client.run("model-with-invalid-cog", input={"prompt": "test prompt"})
+
+    @pytest.mark.respx(base_url=base_url)
+    async def test_async_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
+        """Test async run with file output iterator."""
+        # Mock URLs for file outputs
+        file_urls = [
+            "https://replicate.delivery/output1.png",
+            "https://replicate.delivery/output2.png",
+            "https://replicate.delivery/output3.png",
+        ]
+
+        # Initial response with processing status and no output
+        respx_mock.post("/predictions").mock(
+            return_value=httpx.Response(201, json=create_mock_prediction(status="processing", output=None))
+        )
+
+        # As in the sync test, queue both poll responses on a single route with
+        # side_effect, since respx replaces a re-registered identical route:
+        # the first poll is still processing, the second succeeds with file URLs.
+        respx_mock.get("/predictions/test_prediction_id").mock(
+            side_effect=[
+                httpx.Response(200, json=create_mock_prediction(status="processing", output=None)),
+                httpx.Response(200, json=create_mock_prediction(output=file_urls)),
+            ]
+        )
+
+        output = await self.client.run("some-model-ref", input={"prompt": "generate file iterator"})
+
+        assert isinstance(output, list)
+        assert len(output) == 3
+        assert all(isinstance(item, AsyncFileOutput) for item in output)
+        assert [item.url for item in output] == file_urls
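+        # As in the sync variant, only .url is asserted; reading the bodies of
+        # AsyncFileOutput objects would presumably require awaiting a read call.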