Skip to content

Commit 11e09d3

Browse files
committed
clean up testing a bit more
1 parent 24d30f4 commit 11e09d3

File tree

1 file changed

+118
-35
lines changed

1 file changed

+118
-35
lines changed

tests/lib/test_run.py

Lines changed: 118 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import io
44
import os
55
import datetime
6-
from typing import Any, Dict, Optional
6+
from typing import Any, Dict, List, Optional
77

88
import httpx
99
import pytest
@@ -87,12 +87,15 @@ def _version_with_schema(id: str = "v1", output_schema: Optional[object] = None)
8787
class TestRun:
8888
client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
8989

90+
# Common model reference format that will work with the new SDK
91+
model_ref = "owner/name:version"
92+
9093
@pytest.mark.respx(base_url=base_url)
9194
def test_run_basic(self, respx_mock: MockRouter) -> None:
9295
"""Test basic model run functionality."""
9396
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
9497

95-
output: Any = self.client.run("some-model-ref", input={"prompt": "test prompt"})
98+
output: Any = self.client.run(self.model_ref, input={"prompt": "test prompt"})
9699

97100
assert output == "test output"
98101

@@ -101,7 +104,7 @@ def test_run_with_wait_true(self, respx_mock: MockRouter) -> None:
101104
"""Test run with wait=True parameter."""
102105
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
103106

104-
output: Any = self.client.run("some-model-ref", wait=True, input={"prompt": "test prompt"})
107+
output: Any = self.client.run(self.model_ref, wait=True, input={"prompt": "test prompt"})
105108

106109
assert output == "test output"
107110

@@ -110,7 +113,7 @@ def test_run_with_wait_int(self, respx_mock: MockRouter) -> None:
110113
"""Test run with wait as an integer value."""
111114
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
112115

113-
output: Any = self.client.run("some-model-ref", wait=10, input={"prompt": "test prompt"})
116+
output: Any = self.client.run(self.model_ref, wait=10, input={"prompt": "test prompt"})
114117

115118
assert output == "test output"
116119

@@ -127,7 +130,7 @@ def test_run_without_wait(self, respx_mock: MockRouter) -> None:
127130
return_value=httpx.Response(200, json=create_mock_prediction(status="succeeded"))
128131
)
129132

130-
output: Any = self.client.run("some-model-ref", wait=False, input={"prompt": "test prompt"})
133+
output: Any = self.client.run(self.model_ref, wait=False, input={"prompt": "test prompt"})
131134

132135
assert output == "test output"
133136

@@ -140,11 +143,38 @@ def test_run_with_file_output(self, respx_mock: MockRouter) -> None:
140143
return_value=httpx.Response(201, json=create_mock_prediction(output=file_url))
141144
)
142145

143-
output: Any = self.client.run("some-model-ref", input={"prompt": "generate image"})
146+
output: Any = self.client.run(self.model_ref, input={"prompt": "generate image"})
144147

145148
assert isinstance(output, FileOutput)
146149
assert output.url == file_url
147150

151+
@pytest.mark.respx(base_url=base_url)
152+
def test_run_with_data_uri_output(self, respx_mock: MockRouter) -> None:
153+
"""Test run with data URI output."""
154+
# Create a data URI for a small PNG image (1x1 transparent pixel)
155+
data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
156+
157+
# Mock prediction with data URI output
158+
respx_mock.post("/predictions").mock(
159+
return_value=httpx.Response(201, json=create_mock_prediction(output=data_uri))
160+
)
161+
162+
# Use a valid model version ID format
163+
output: Any = self.client.run("owner/name:version", input={"prompt": "generate small image"})
164+
165+
assert isinstance(output, FileOutput)
166+
assert output.url == data_uri
167+
168+
# Test that we can read the data
169+
image_data = output.read()
170+
assert isinstance(image_data, bytes)
171+
assert len(image_data) > 0
172+
173+
# Test that we can iterate over the data
174+
chunks = list(output)
175+
assert len(chunks) == 1
176+
assert chunks[0] == image_data
177+
148178
@pytest.mark.respx(base_url=base_url)
149179
def test_run_with_file_list_output(self, respx_mock: MockRouter) -> None:
150180
"""Test run with list of file outputs."""
@@ -157,7 +187,7 @@ def test_run_with_file_list_output(self, respx_mock: MockRouter) -> None:
157187
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=mock_prediction))
158188

159189
output: list[FileOutput] = self.client.run(
160-
"some-model-ref", use_file_output=True, input={"prompt": "generate multiple images"}
190+
self.model_ref, use_file_output=True, input={"prompt": "generate multiple images"}
161191
)
162192

163193
assert isinstance(output, list)
@@ -176,7 +206,7 @@ def test_run_with_dict_file_output(self, respx_mock: MockRouter) -> None:
176206
return_value=httpx.Response(201, json=create_mock_prediction(output=file_urls))
177207
)
178208

179-
output: Dict[str, FileOutput] = self.client.run("some-model-ref", input={"prompt": "structured output"})
209+
output: Dict[str, FileOutput] = self.client.run(self.model_ref, input={"prompt": "structured output"})
180210

181211
assert isinstance(output, dict)
182212
assert len(output) == 2
@@ -191,7 +221,7 @@ def test_run_with_error(self, respx_mock: MockRouter) -> None:
191221
)
192222

193223
with pytest.raises(ModelError):
194-
self.client.run("error-model-ref", input={"prompt": "trigger error"})
224+
self.client.run(self.model_ref, input={"prompt": "trigger error"})
195225

196226
@pytest.mark.respx(base_url=base_url)
197227
def test_run_with_base64_file(self, respx_mock: MockRouter) -> None:
@@ -202,14 +232,14 @@ def test_run_with_base64_file(self, respx_mock: MockRouter) -> None:
202232
# Mock the prediction response
203233
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
204234

205-
output: Any = self.client.run("some-model-ref", input={"file": file_obj}, file_encoding_strategy="base64")
235+
output: Any = self.client.run(self.model_ref, input={"file": file_obj}, file_encoding_strategy="base64")
206236

207237
assert output == "test output"
208238

209239
def test_run_with_prefer_conflict(self) -> None:
210240
"""Test run with conflicting wait and prefer parameters."""
211241
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
212-
self.client.run("some-model-ref", wait=True, prefer="nowait", input={"prompt": "test"})
242+
self.client.run(self.model_ref, wait=True, prefer="nowait", input={"prompt": "test"})
213243

214244
@pytest.mark.respx(base_url=base_url)
215245
def test_run_with_iterator(self, respx_mock: MockRouter) -> None:
@@ -220,7 +250,7 @@ def test_run_with_iterator(self, respx_mock: MockRouter) -> None:
220250
return_value=httpx.Response(201, json=create_mock_prediction(output=output_iterator))
221251
)
222252

223-
output: list[str] = self.client.run("some-model-ref", input={"prompt": "generate iterator"})
253+
output: list[str] = self.client.run(self.model_ref, input={"prompt": "generate iterator"})
224254

225255
assert isinstance(output, list)
226256
assert len(output) == 3
@@ -233,7 +263,7 @@ def test_run_with_invalid_identifier(self, respx_mock: MockRouter) -> None:
233263
respx_mock.post("/predictions").mock(return_value=httpx.Response(404, json={"detail": "Model not found"}))
234264

235265
with pytest.raises(NotFoundError):
236-
self.client.run("invalid-model-ref", input={"prompt": "test prompt"})
266+
self.client.run("invalid/model:ref", input={"prompt": "test prompt"})
237267

238268
@pytest.mark.respx(base_url=base_url)
239269
def test_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
@@ -242,7 +272,7 @@ def test_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
242272
respx_mock.post("/predictions").mock(return_value=httpx.Response(400, json={"detail": "Invalid Cog version"}))
243273

244274
with pytest.raises(BadRequestError):
245-
self.client.run("model-with-invalid-cog", input={"prompt": "test prompt"})
275+
self.client.run("invalid/cog:model", input={"prompt": "test prompt"})
246276

247277
@pytest.mark.respx(base_url=base_url)
248278
def test_run_with_model_object(self, respx_mock: MockRouter) -> None:
@@ -274,9 +304,7 @@ def test_run_with_model_version_identifier(self, respx_mock: MockRouter) -> None
274304
# Case where version ID is provided
275305
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
276306

277-
identifier = ModelVersionIdentifier(
278-
owner="test-owner", name="test-model", version="test-version-id"
279-
)
307+
identifier = ModelVersionIdentifier(owner="test-owner", name="test-model", version="test-version-id")
280308
output = self.client.run(identifier, input={"prompt": "test prompt"})
281309

282310
assert output == "test output"
@@ -307,7 +335,7 @@ def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
307335
)
308336

309337
output: list[FileOutput] = self.client.run(
310-
"some-model-ref", use_file_output=True, wait=False, input={"prompt": "generate file iterator"}
338+
self.model_ref, use_file_output=True, wait=False, input={"prompt": "generate file iterator"}
311339
)
312340

313341
assert isinstance(output, list)
@@ -319,12 +347,15 @@ def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
319347
class TestAsyncRun:
320348
client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
321349

350+
# Common model reference format that will work with the new SDK
351+
model_ref = "owner/name:version"
352+
322353
@pytest.mark.respx(base_url=base_url)
323354
async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
324355
"""Test basic async model run functionality."""
325356
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
326357

327-
output: Any = await self.client.run("some-model-ref", input={"prompt": "test prompt"})
358+
output: Any = await self.client.run(self.model_ref, input={"prompt": "test prompt"})
328359

329360
assert output == "test output"
330361

@@ -333,7 +364,7 @@ async def test_async_run_with_wait_true(self, respx_mock: MockRouter) -> None:
333364
"""Test async run with wait=True parameter."""
334365
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
335366

336-
output: Any = await self.client.run("some-model-ref", wait=True, input={"prompt": "test prompt"})
367+
output: Any = await self.client.run(self.model_ref, wait=True, input={"prompt": "test prompt"})
337368

338369
assert output == "test output"
339370

@@ -342,7 +373,7 @@ async def test_async_run_with_wait_int(self, respx_mock: MockRouter) -> None:
342373
"""Test async run with wait as an integer value."""
343374
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
344375

345-
output: Any = await self.client.run("some-model-ref", wait=10, input={"prompt": "test prompt"})
376+
output: Any = await self.client.run(self.model_ref, wait=10, input={"prompt": "test prompt"})
346377

347378
assert output == "test output"
348379

@@ -359,7 +390,7 @@ async def test_async_run_without_wait(self, respx_mock: MockRouter) -> None:
359390
return_value=httpx.Response(200, json=create_mock_prediction(status="succeeded"))
360391
)
361392

362-
output: Any = await self.client.run("some-model-ref", wait=False, input={"prompt": "test prompt"})
393+
output: Any = await self.client.run(self.model_ref, wait=False, input={"prompt": "test prompt"})
363394

364395
assert output == "test output"
365396

@@ -372,11 +403,41 @@ async def test_async_run_with_file_output(self, respx_mock: MockRouter) -> None:
372403
return_value=httpx.Response(201, json=create_mock_prediction(output=file_url))
373404
)
374405

375-
output: Any = await self.client.run("some-model-ref", input={"prompt": "generate image"})
406+
output: Any = await self.client.run(self.model_ref, input={"prompt": "generate image"})
376407

377408
assert isinstance(output, AsyncFileOutput)
378409
assert output.url == file_url
379410

411+
@pytest.mark.respx(base_url=base_url)
412+
async def test_async_run_with_data_uri_output(self, respx_mock: MockRouter) -> None:
413+
"""Test async run with data URI output."""
414+
# Create a data URI for a small PNG image (1x1 transparent pixel)
415+
data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
416+
417+
# Mock prediction with data URI output
418+
respx_mock.post("/predictions").mock(
419+
return_value=httpx.Response(201, json=create_mock_prediction(output=data_uri))
420+
)
421+
422+
# Use a valid model version ID format
423+
output: Any = await self.client.run("owner/name:version", input={"prompt": "generate small image"})
424+
425+
assert isinstance(output, AsyncFileOutput)
426+
assert output.url == data_uri
427+
428+
# Test that we can read the data asynchronously
429+
image_data = await output.read()
430+
assert isinstance(image_data, bytes)
431+
assert len(image_data) > 0
432+
433+
# Test that we can iterate over the data asynchronously
434+
chunks: List[Any] = []
435+
async for chunk in output:
436+
chunks.append(chunk)
437+
438+
assert len(chunks) == 1
439+
assert chunks[0] == image_data
440+
380441
@pytest.mark.respx(base_url=base_url)
381442
async def test_async_run_with_file_list_output(self, respx_mock: MockRouter) -> None:
382443
"""Test async run with list of file outputs."""
@@ -389,7 +450,7 @@ async def test_async_run_with_file_list_output(self, respx_mock: MockRouter) ->
389450
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=mock_prediction))
390451

391452
output: list[AsyncFileOutput] = await self.client.run(
392-
"some-model-ref", input={"prompt": "generate multiple images"}
453+
self.model_ref, input={"prompt": "generate multiple images"}
393454
)
394455

395456
assert isinstance(output, list)
@@ -409,7 +470,7 @@ async def test_async_run_with_dict_file_output(self, respx_mock: MockRouter) ->
409470
)
410471

411472
output: Dict[str, AsyncFileOutput] = await self.client.run(
412-
"some-model-ref", input={"prompt": "structured output"}
473+
self.model_ref, input={"prompt": "structured output"}
413474
)
414475

415476
assert isinstance(output, dict)
@@ -425,7 +486,7 @@ async def test_async_run_with_error(self, respx_mock: MockRouter) -> None:
425486
)
426487

427488
with pytest.raises(ModelError):
428-
await self.client.run("error-model-ref", input={"prompt": "trigger error"})
489+
await self.client.run(self.model_ref, input={"prompt": "trigger error"})
429490

430491
@pytest.mark.respx(base_url=base_url)
431492
async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None:
@@ -436,14 +497,14 @@ async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None:
436497
# Mock the prediction response
437498
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
438499

439-
output: Any = await self.client.run("some-model-ref", input={"file": file_obj}, file_encoding_strategy="base64")
500+
output: Any = await self.client.run(self.model_ref, input={"file": file_obj}, file_encoding_strategy="base64")
440501

441502
assert output == "test output"
442503

443504
async def test_async_run_with_prefer_conflict(self) -> None:
444505
"""Test async run with conflicting wait and prefer parameters."""
445506
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
446-
await self.client.run("some-model-ref", wait=True, prefer="nowait", input={"prompt": "test"})
507+
await self.client.run(self.model_ref, wait=True, prefer="nowait", input={"prompt": "test"})
447508

448509
@pytest.mark.respx(base_url=base_url)
449510
async def test_async_run_with_iterator(self, respx_mock: MockRouter) -> None:
@@ -454,7 +515,7 @@ async def test_async_run_with_iterator(self, respx_mock: MockRouter) -> None:
454515
return_value=httpx.Response(201, json=create_mock_prediction(output=output_iterator))
455516
)
456517

457-
output: list[str] = await self.client.run("some-model-ref", input={"prompt": "generate iterator"})
518+
output: list[str] = await self.client.run(self.model_ref, input={"prompt": "generate iterator"})
458519

459520
assert isinstance(output, list)
460521
assert len(output) == 3
@@ -467,7 +528,7 @@ async def test_async_run_with_invalid_identifier(self, respx_mock: MockRouter) -
467528
respx_mock.post("/predictions").mock(return_value=httpx.Response(404, json={"detail": "Model not found"}))
468529

469530
with pytest.raises(NotFoundError):
470-
await self.client.run("invalid-model-ref", input={"prompt": "test prompt"})
531+
await self.client.run("invalid/model:ref", input={"prompt": "test prompt"})
471532

472533
@pytest.mark.respx(base_url=base_url)
473534
async def test_async_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
@@ -476,7 +537,7 @@ async def test_async_run_with_invalid_cog_version(self, respx_mock: MockRouter)
476537
respx_mock.post("/predictions").mock(return_value=httpx.Response(400, json={"detail": "Invalid Cog version"}))
477538

478539
with pytest.raises(BadRequestError):
479-
await self.client.run("model-with-invalid-cog", input={"prompt": "test prompt"})
540+
await self.client.run("invalid/cog:model", input={"prompt": "test prompt"})
480541

481542
@pytest.mark.respx(base_url=base_url)
482543
async def test_async_run_with_model_object(self, respx_mock: MockRouter) -> None:
@@ -508,9 +569,7 @@ async def test_async_run_with_model_version_identifier(self, respx_mock: MockRou
508569
# Case where version ID is provided
509570
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
510571

511-
identifier = ModelVersionIdentifier(
512-
owner="test-owner", name="test-model", version="test-version-id"
513-
)
572+
identifier = ModelVersionIdentifier(owner="test-owner", name="test-model", version="test-version-id")
514573
output = await self.client.run(identifier, input={"prompt": "test prompt"})
515574

516575
assert output == "test output"
@@ -541,10 +600,34 @@ async def test_async_run_with_file_output_iterator(self, respx_mock: MockRouter)
541600
)
542601

543602
output: list[AsyncFileOutput] = await self.client.run(
544-
"some-model-ref", use_file_output=True, wait=False, input={"prompt": "generate file iterator"}
603+
self.model_ref, use_file_output=True, wait=False, input={"prompt": "generate file iterator"}
545604
)
546605

547606
assert isinstance(output, list)
548607
assert len(output) == 3
549608
assert all(isinstance(item, AsyncFileOutput) for item in output)
550609
assert [item.url for item in output] == file_urls
610+
611+
@pytest.mark.respx(base_url=base_url)
612+
async def test_async_run_concurrently(self, respx_mock: MockRouter) -> None:
613+
"""Test running multiple models concurrently with asyncio."""
614+
import asyncio
615+
616+
# Mock three different prediction responses
617+
mock_outputs = ["output1", "output2", "output3"]
618+
prompts = ["prompt1", "prompt2", "prompt3"]
619+
620+
# Set up mocks for each request (using side_effect to allow multiple matches)
621+
# Note: This will match any POST to /predictions but return different responses
622+
route = respx_mock.post("/predictions")
623+
route.side_effect = [httpx.Response(201, json=create_mock_prediction(output=output)) for output in mock_outputs]
624+
625+
# Run three predictions concurrently
626+
tasks = [self.client.run("owner/name:version", input={"prompt": prompt}) for prompt in prompts]
627+
628+
results = await asyncio.gather(*tasks)
629+
630+
# Verify each result matches expected output
631+
assert len(results) == 3
632+
for i, result in enumerate(results):
633+
assert result == mock_outputs[i]

0 commit comments

Comments
 (0)