33import io
44import os
55import datetime
6- from typing import Any , Dict , Optional
6+ from typing import Any , Dict , List , Optional
77
88import httpx
99import pytest
@@ -87,12 +87,15 @@ def _version_with_schema(id: str = "v1", output_schema: Optional[object] = None)
8787class TestRun :
8888 client = Replicate (base_url = base_url , bearer_token = bearer_token , _strict_response_validation = True )
8989
90+ # Common model reference format that will work with the new SDK
91+ model_ref = "owner/name:version"
92+
9093 @pytest .mark .respx (base_url = base_url )
9194 def test_run_basic (self , respx_mock : MockRouter ) -> None :
9295 """Test basic model run functionality."""
9396 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
9497
95- output : Any = self .client .run ("some-model-ref" , input = {"prompt" : "test prompt" })
98+ output : Any = self .client .run (self . model_ref , input = {"prompt" : "test prompt" })
9699
97100 assert output == "test output"
98101
@@ -101,7 +104,7 @@ def test_run_with_wait_true(self, respx_mock: MockRouter) -> None:
101104 """Test run with wait=True parameter."""
102105 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
103106
104- output : Any = self .client .run ("some-model-ref" , wait = True , input = {"prompt" : "test prompt" })
107+ output : Any = self .client .run (self . model_ref , wait = True , input = {"prompt" : "test prompt" })
105108
106109 assert output == "test output"
107110
@@ -110,7 +113,7 @@ def test_run_with_wait_int(self, respx_mock: MockRouter) -> None:
110113 """Test run with wait as an integer value."""
111114 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
112115
113- output : Any = self .client .run ("some-model-ref" , wait = 10 , input = {"prompt" : "test prompt" })
116+ output : Any = self .client .run (self . model_ref , wait = 10 , input = {"prompt" : "test prompt" })
114117
115118 assert output == "test output"
116119
@@ -127,7 +130,7 @@ def test_run_without_wait(self, respx_mock: MockRouter) -> None:
127130 return_value = httpx .Response (200 , json = create_mock_prediction (status = "succeeded" ))
128131 )
129132
130- output : Any = self .client .run ("some-model-ref" , wait = False , input = {"prompt" : "test prompt" })
133+ output : Any = self .client .run (self . model_ref , wait = False , input = {"prompt" : "test prompt" })
131134
132135 assert output == "test output"
133136
@@ -140,11 +143,38 @@ def test_run_with_file_output(self, respx_mock: MockRouter) -> None:
140143 return_value = httpx .Response (201 , json = create_mock_prediction (output = file_url ))
141144 )
142145
143- output : Any = self .client .run ("some-model-ref" , input = {"prompt" : "generate image" })
146+ output : Any = self .client .run (self . model_ref , input = {"prompt" : "generate image" })
144147
145148 assert isinstance (output , FileOutput )
146149 assert output .url == file_url
147150
151+ @pytest .mark .respx (base_url = base_url )
152+ def test_run_with_data_uri_output (self , respx_mock : MockRouter ) -> None :
153+ """Test run with data URI output."""
154+ # Create a data URI for a small PNG image (1x1 transparent pixel)
155+ data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
156+
157+ # Mock prediction with data URI output
158+ respx_mock .post ("/predictions" ).mock (
159+ return_value = httpx .Response (201 , json = create_mock_prediction (output = data_uri ))
160+ )
161+
162+ # Use a valid model version ID format
163+ output : Any = self .client .run ("owner/name:version" , input = {"prompt" : "generate small image" })
164+
165+ assert isinstance (output , FileOutput )
166+ assert output .url == data_uri
167+
168+ # Test that we can read the data
169+ image_data = output .read ()
170+ assert isinstance (image_data , bytes )
171+ assert len (image_data ) > 0
172+
173+ # Test that we can iterate over the data
174+ chunks = list (output )
175+ assert len (chunks ) == 1
176+ assert chunks [0 ] == image_data
177+
148178 @pytest .mark .respx (base_url = base_url )
149179 def test_run_with_file_list_output (self , respx_mock : MockRouter ) -> None :
150180 """Test run with list of file outputs."""
@@ -157,7 +187,7 @@ def test_run_with_file_list_output(self, respx_mock: MockRouter) -> None:
157187 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = mock_prediction ))
158188
159189 output : list [FileOutput ] = self .client .run (
160- "some-model-ref" , use_file_output = True , input = {"prompt" : "generate multiple images" }
190+ self . model_ref , use_file_output = True , input = {"prompt" : "generate multiple images" }
161191 )
162192
163193 assert isinstance (output , list )
@@ -176,7 +206,7 @@ def test_run_with_dict_file_output(self, respx_mock: MockRouter) -> None:
176206 return_value = httpx .Response (201 , json = create_mock_prediction (output = file_urls ))
177207 )
178208
179- output : Dict [str , FileOutput ] = self .client .run ("some-model-ref" , input = {"prompt" : "structured output" })
209+ output : Dict [str , FileOutput ] = self .client .run (self . model_ref , input = {"prompt" : "structured output" })
180210
181211 assert isinstance (output , dict )
182212 assert len (output ) == 2
@@ -191,7 +221,7 @@ def test_run_with_error(self, respx_mock: MockRouter) -> None:
191221 )
192222
193223 with pytest .raises (ModelError ):
194- self .client .run ("error-model-ref" , input = {"prompt" : "trigger error" })
224+ self .client .run (self . model_ref , input = {"prompt" : "trigger error" })
195225
196226 @pytest .mark .respx (base_url = base_url )
197227 def test_run_with_base64_file (self , respx_mock : MockRouter ) -> None :
@@ -202,14 +232,14 @@ def test_run_with_base64_file(self, respx_mock: MockRouter) -> None:
202232 # Mock the prediction response
203233 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
204234
205- output : Any = self .client .run ("some-model-ref" , input = {"file" : file_obj }, file_encoding_strategy = "base64" )
235+ output : Any = self .client .run (self . model_ref , input = {"file" : file_obj }, file_encoding_strategy = "base64" )
206236
207237 assert output == "test output"
208238
209239 def test_run_with_prefer_conflict (self ) -> None :
210240 """Test run with conflicting wait and prefer parameters."""
211241 with pytest .raises (TypeError , match = "cannot mix and match prefer and wait" ):
212- self .client .run ("some-model-ref" , wait = True , prefer = "nowait" , input = {"prompt" : "test" })
242+ self .client .run (self . model_ref , wait = True , prefer = "nowait" , input = {"prompt" : "test" })
213243
214244 @pytest .mark .respx (base_url = base_url )
215245 def test_run_with_iterator (self , respx_mock : MockRouter ) -> None :
@@ -220,7 +250,7 @@ def test_run_with_iterator(self, respx_mock: MockRouter) -> None:
220250 return_value = httpx .Response (201 , json = create_mock_prediction (output = output_iterator ))
221251 )
222252
223- output : list [str ] = self .client .run ("some-model-ref" , input = {"prompt" : "generate iterator" })
253+ output : list [str ] = self .client .run (self . model_ref , input = {"prompt" : "generate iterator" })
224254
225255 assert isinstance (output , list )
226256 assert len (output ) == 3
@@ -233,7 +263,7 @@ def test_run_with_invalid_identifier(self, respx_mock: MockRouter) -> None:
233263 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (404 , json = {"detail" : "Model not found" }))
234264
235265 with pytest .raises (NotFoundError ):
236- self .client .run ("invalid- model- ref" , input = {"prompt" : "test prompt" })
266+ self .client .run ("invalid/ model: ref" , input = {"prompt" : "test prompt" })
237267
238268 @pytest .mark .respx (base_url = base_url )
239269 def test_run_with_invalid_cog_version (self , respx_mock : MockRouter ) -> None :
@@ -242,7 +272,7 @@ def test_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
242272 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (400 , json = {"detail" : "Invalid Cog version" }))
243273
244274 with pytest .raises (BadRequestError ):
245- self .client .run ("model-with- invalid- cog" , input = {"prompt" : "test prompt" })
275+ self .client .run ("invalid/ cog:model " , input = {"prompt" : "test prompt" })
246276
247277 @pytest .mark .respx (base_url = base_url )
248278 def test_run_with_model_object (self , respx_mock : MockRouter ) -> None :
@@ -274,9 +304,7 @@ def test_run_with_model_version_identifier(self, respx_mock: MockRouter) -> None
274304 # Case where version ID is provided
275305 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
276306
277- identifier = ModelVersionIdentifier (
278- owner = "test-owner" , name = "test-model" , version = "test-version-id"
279- )
307+ identifier = ModelVersionIdentifier (owner = "test-owner" , name = "test-model" , version = "test-version-id" )
280308 output = self .client .run (identifier , input = {"prompt" : "test prompt" })
281309
282310 assert output == "test output"
@@ -307,7 +335,7 @@ def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
307335 )
308336
309337 output : list [FileOutput ] = self .client .run (
310- "some-model-ref" , use_file_output = True , wait = False , input = {"prompt" : "generate file iterator" }
338+ self . model_ref , use_file_output = True , wait = False , input = {"prompt" : "generate file iterator" }
311339 )
312340
313341 assert isinstance (output , list )
@@ -319,12 +347,15 @@ def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
319347class TestAsyncRun :
320348 client = AsyncReplicate (base_url = base_url , bearer_token = bearer_token , _strict_response_validation = True )
321349
350+ # Common model reference format that will work with the new SDK
351+ model_ref = "owner/name:version"
352+
322353 @pytest .mark .respx (base_url = base_url )
323354 async def test_async_run_basic (self , respx_mock : MockRouter ) -> None :
324355 """Test basic async model run functionality."""
325356 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
326357
327- output : Any = await self .client .run ("some-model-ref" , input = {"prompt" : "test prompt" })
358+ output : Any = await self .client .run (self . model_ref , input = {"prompt" : "test prompt" })
328359
329360 assert output == "test output"
330361
@@ -333,7 +364,7 @@ async def test_async_run_with_wait_true(self, respx_mock: MockRouter) -> None:
333364 """Test async run with wait=True parameter."""
334365 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
335366
336- output : Any = await self .client .run ("some-model-ref" , wait = True , input = {"prompt" : "test prompt" })
367+ output : Any = await self .client .run (self . model_ref , wait = True , input = {"prompt" : "test prompt" })
337368
338369 assert output == "test output"
339370
@@ -342,7 +373,7 @@ async def test_async_run_with_wait_int(self, respx_mock: MockRouter) -> None:
342373 """Test async run with wait as an integer value."""
343374 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
344375
345- output : Any = await self .client .run ("some-model-ref" , wait = 10 , input = {"prompt" : "test prompt" })
376+ output : Any = await self .client .run (self . model_ref , wait = 10 , input = {"prompt" : "test prompt" })
346377
347378 assert output == "test output"
348379
@@ -359,7 +390,7 @@ async def test_async_run_without_wait(self, respx_mock: MockRouter) -> None:
359390 return_value = httpx .Response (200 , json = create_mock_prediction (status = "succeeded" ))
360391 )
361392
362- output : Any = await self .client .run ("some-model-ref" , wait = False , input = {"prompt" : "test prompt" })
393+ output : Any = await self .client .run (self . model_ref , wait = False , input = {"prompt" : "test prompt" })
363394
364395 assert output == "test output"
365396
@@ -372,11 +403,41 @@ async def test_async_run_with_file_output(self, respx_mock: MockRouter) -> None:
372403 return_value = httpx .Response (201 , json = create_mock_prediction (output = file_url ))
373404 )
374405
375- output : Any = await self .client .run ("some-model-ref" , input = {"prompt" : "generate image" })
406+ output : Any = await self .client .run (self . model_ref , input = {"prompt" : "generate image" })
376407
377408 assert isinstance (output , AsyncFileOutput )
378409 assert output .url == file_url
379410
411+ @pytest .mark .respx (base_url = base_url )
412+ async def test_async_run_with_data_uri_output (self , respx_mock : MockRouter ) -> None :
413+ """Test async run with data URI output."""
414+ # Create a data URI for a small PNG image (1x1 transparent pixel)
415+ data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
416+
417+ # Mock prediction with data URI output
418+ respx_mock .post ("/predictions" ).mock (
419+ return_value = httpx .Response (201 , json = create_mock_prediction (output = data_uri ))
420+ )
421+
422+ # Use a valid model version ID format
423+ output : Any = await self .client .run ("owner/name:version" , input = {"prompt" : "generate small image" })
424+
425+ assert isinstance (output , AsyncFileOutput )
426+ assert output .url == data_uri
427+
428+ # Test that we can read the data asynchronously
429+ image_data = await output .read ()
430+ assert isinstance (image_data , bytes )
431+ assert len (image_data ) > 0
432+
433+ # Test that we can iterate over the data asynchronously
434+ chunks : List [Any ] = []
435+ async for chunk in output :
436+ chunks .append (chunk )
437+
438+ assert len (chunks ) == 1
439+ assert chunks [0 ] == image_data
440+
380441 @pytest .mark .respx (base_url = base_url )
381442 async def test_async_run_with_file_list_output (self , respx_mock : MockRouter ) -> None :
382443 """Test async run with list of file outputs."""
@@ -389,7 +450,7 @@ async def test_async_run_with_file_list_output(self, respx_mock: MockRouter) ->
389450 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = mock_prediction ))
390451
391452 output : list [AsyncFileOutput ] = await self .client .run (
392- "some-model-ref" , input = {"prompt" : "generate multiple images" }
453+ self . model_ref , input = {"prompt" : "generate multiple images" }
393454 )
394455
395456 assert isinstance (output , list )
@@ -409,7 +470,7 @@ async def test_async_run_with_dict_file_output(self, respx_mock: MockRouter) ->
409470 )
410471
411472 output : Dict [str , AsyncFileOutput ] = await self .client .run (
412- "some-model-ref" , input = {"prompt" : "structured output" }
473+ self . model_ref , input = {"prompt" : "structured output" }
413474 )
414475
415476 assert isinstance (output , dict )
@@ -425,7 +486,7 @@ async def test_async_run_with_error(self, respx_mock: MockRouter) -> None:
425486 )
426487
427488 with pytest .raises (ModelError ):
428- await self .client .run ("error-model-ref" , input = {"prompt" : "trigger error" })
489+ await self .client .run (self . model_ref , input = {"prompt" : "trigger error" })
429490
430491 @pytest .mark .respx (base_url = base_url )
431492 async def test_async_run_with_base64_file (self , respx_mock : MockRouter ) -> None :
@@ -436,14 +497,14 @@ async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None:
436497 # Mock the prediction response
437498 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
438499
439- output : Any = await self .client .run ("some-model-ref" , input = {"file" : file_obj }, file_encoding_strategy = "base64" )
500+ output : Any = await self .client .run (self . model_ref , input = {"file" : file_obj }, file_encoding_strategy = "base64" )
440501
441502 assert output == "test output"
442503
443504 async def test_async_run_with_prefer_conflict (self ) -> None :
444505 """Test async run with conflicting wait and prefer parameters."""
445506 with pytest .raises (TypeError , match = "cannot mix and match prefer and wait" ):
446- await self .client .run ("some-model-ref" , wait = True , prefer = "nowait" , input = {"prompt" : "test" })
507+ await self .client .run (self . model_ref , wait = True , prefer = "nowait" , input = {"prompt" : "test" })
447508
448509 @pytest .mark .respx (base_url = base_url )
449510 async def test_async_run_with_iterator (self , respx_mock : MockRouter ) -> None :
@@ -454,7 +515,7 @@ async def test_async_run_with_iterator(self, respx_mock: MockRouter) -> None:
454515 return_value = httpx .Response (201 , json = create_mock_prediction (output = output_iterator ))
455516 )
456517
457- output : list [str ] = await self .client .run ("some-model-ref" , input = {"prompt" : "generate iterator" })
518+ output : list [str ] = await self .client .run (self . model_ref , input = {"prompt" : "generate iterator" })
458519
459520 assert isinstance (output , list )
460521 assert len (output ) == 3
@@ -467,7 +528,7 @@ async def test_async_run_with_invalid_identifier(self, respx_mock: MockRouter) -
467528 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (404 , json = {"detail" : "Model not found" }))
468529
469530 with pytest .raises (NotFoundError ):
470- await self .client .run ("invalid- model- ref" , input = {"prompt" : "test prompt" })
531+ await self .client .run ("invalid/ model: ref" , input = {"prompt" : "test prompt" })
471532
472533 @pytest .mark .respx (base_url = base_url )
473534 async def test_async_run_with_invalid_cog_version (self , respx_mock : MockRouter ) -> None :
@@ -476,7 +537,7 @@ async def test_async_run_with_invalid_cog_version(self, respx_mock: MockRouter)
476537 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (400 , json = {"detail" : "Invalid Cog version" }))
477538
478539 with pytest .raises (BadRequestError ):
479- await self .client .run ("model-with- invalid- cog" , input = {"prompt" : "test prompt" })
540+ await self .client .run ("invalid/ cog:model " , input = {"prompt" : "test prompt" })
480541
481542 @pytest .mark .respx (base_url = base_url )
482543 async def test_async_run_with_model_object (self , respx_mock : MockRouter ) -> None :
@@ -508,9 +569,7 @@ async def test_async_run_with_model_version_identifier(self, respx_mock: MockRou
508569 # Case where version ID is provided
509570 respx_mock .post ("/predictions" ).mock (return_value = httpx .Response (201 , json = create_mock_prediction ()))
510571
511- identifier = ModelVersionIdentifier (
512- owner = "test-owner" , name = "test-model" , version = "test-version-id"
513- )
572+ identifier = ModelVersionIdentifier (owner = "test-owner" , name = "test-model" , version = "test-version-id" )
514573 output = await self .client .run (identifier , input = {"prompt" : "test prompt" })
515574
516575 assert output == "test output"
@@ -541,10 +600,34 @@ async def test_async_run_with_file_output_iterator(self, respx_mock: MockRouter)
541600 )
542601
543602 output : list [AsyncFileOutput ] = await self .client .run (
544- "some-model-ref" , use_file_output = True , wait = False , input = {"prompt" : "generate file iterator" }
603+ self . model_ref , use_file_output = True , wait = False , input = {"prompt" : "generate file iterator" }
545604 )
546605
547606 assert isinstance (output , list )
548607 assert len (output ) == 3
549608 assert all (isinstance (item , AsyncFileOutput ) for item in output )
550609 assert [item .url for item in output ] == file_urls
610+
611+ @pytest .mark .respx (base_url = base_url )
612+ async def test_async_run_concurrently (self , respx_mock : MockRouter ) -> None :
613+ """Test running multiple models concurrently with asyncio."""
614+ import asyncio
615+
616+ # Mock three different prediction responses
617+ mock_outputs = ["output1" , "output2" , "output3" ]
618+ prompts = ["prompt1" , "prompt2" , "prompt3" ]
619+
620+ # Set up mocks for each request (using side_effect to allow multiple matches)
621+ # Note: This will match any POST to /predictions but return different responses
622+ route = respx_mock .post ("/predictions" )
623+ route .side_effect = [httpx .Response (201 , json = create_mock_prediction (output = output )) for output in mock_outputs ]
624+
625+ # Run three predictions concurrently
626+ tasks = [self .client .run ("owner/name:version" , input = {"prompt" : prompt }) for prompt in prompts ]
627+
628+ results = await asyncio .gather (* tasks )
629+
630+ # Verify each result matches expected output
631+ assert len (results ) == 3
632+ for i , result in enumerate (results ):
633+ assert result == mock_outputs [i ]
0 commit comments