Skip to content

Commit 4a25c43

Browse files
committed
Add run() tests
1 parent c6f76bb commit 4a25c43

File tree

1 file changed

+144
-12
lines changed

1 file changed

+144
-12
lines changed

tests/lib/test_run.py

Lines changed: 144 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,38 @@
1010

1111
from replicate import ReplicateClient, AsyncReplicateClient
1212
from replicate.lib._files import FileOutput, AsyncFileOutput
13-
from replicate._exceptions import ModelError
13+
from replicate._exceptions import ModelError, NotFoundError, BadRequestError
1414

1515
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
1616
bearer_token = "My Bearer Token"
1717

1818

1919
# Mock prediction data for testing
2020
def create_mock_prediction(
21-
status: str = "succeeded", output: Any = "test output", error: Optional[str] = None
21+
status: str = "succeeded",
22+
output: Any = "test output",
23+
error: Optional[str] = None,
24+
logs: Optional[str] = None,
25+
urls: Optional[Dict[str, str]] = None,
2226
) -> Dict[str, Any]:
27+
if urls is None:
28+
urls = {
29+
"get": "https://api.replicate.com/v1/predictions/test_prediction_id",
30+
"cancel": "https://api.replicate.com/v1/predictions/test_prediction_id/cancel",
31+
}
32+
2333
return {
2434
"id": "test_prediction_id",
2535
"version": "test_version",
2636
"status": status,
2737
"input": {"prompt": "test prompt"},
2838
"output": output,
2939
"error": error,
40+
"logs": logs,
3041
"created_at": "2023-01-01T00:00:00Z",
3142
"started_at": "2023-01-01T00:00:01Z",
3243
"completed_at": "2023-01-01T00:00:02Z" if status in ["succeeded", "failed"] else None,
33-
"urls": {
34-
"get": "https://api.replicate.com/v1/predictions/test_prediction_id",
35-
"cancel": "https://api.replicate.com/v1/predictions/test_prediction_id/cancel",
36-
},
44+
"urls": urls,
3745
"model": "test-model",
3846
"data_removed": False,
3947
}
@@ -45,7 +53,6 @@ class TestRun:
4553
@pytest.mark.respx(base_url=base_url)
4654
def test_run_basic(self, respx_mock: MockRouter) -> None:
4755
"""Test basic model run functionality."""
48-
# Mock the prediction creation
4956
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
5057

5158
output: Any = self.client.run("some-model-ref", input={"prompt": "test prompt"})
@@ -55,7 +62,6 @@ def test_run_basic(self, respx_mock: MockRouter) -> None:
5562
@pytest.mark.respx(base_url=base_url)
5663
def test_run_with_wait_true(self, respx_mock: MockRouter) -> None:
5764
"""Test run with wait=True parameter."""
58-
# Mock the prediction creation
5965
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
6066

6167
output: Any = self.client.run("some-model-ref", wait=True, input={"prompt": "test prompt"})
@@ -65,7 +71,6 @@ def test_run_with_wait_true(self, respx_mock: MockRouter) -> None:
6571
@pytest.mark.respx(base_url=base_url)
6672
def test_run_with_wait_int(self, respx_mock: MockRouter) -> None:
6773
"""Test run with wait as an integer value."""
68-
# Mock the prediction creation
6974
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
7075

7176
output: Any = self.client.run("some-model-ref", wait=10, input={"prompt": "test prompt"})
@@ -167,14 +172,78 @@ def test_run_with_prefer_conflict(self) -> None:
167172
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
168173
self.client.run("some-model-ref", wait=True, prefer="nowait", input={"prompt": "test"})
169174

175+
@pytest.mark.respx(base_url=base_url)
176+
def test_run_with_iterator(self, respx_mock: MockRouter) -> None:
177+
"""Test run with an iterator output."""
178+
# Create a mock prediction with an iterator output
179+
output_iterator = ["chunk1", "chunk2", "chunk3"]
180+
respx_mock.post("/predictions").mock(
181+
return_value=httpx.Response(201, json=create_mock_prediction(output=output_iterator))
182+
)
183+
184+
output = self.client.run("some-model-ref", input={"prompt": "generate iterator"})
185+
186+
assert isinstance(output, list)
187+
assert len(output) == 3
188+
assert output == output_iterator
189+
190+
@pytest.mark.respx(base_url=base_url)
191+
def test_run_with_invalid_identifier(self, respx_mock: MockRouter) -> None:
192+
"""Test run with an invalid model identifier."""
193+
# Mock a 404 response for an invalid model identifier
194+
respx_mock.post("/predictions").mock(return_value=httpx.Response(404, json={"detail": "Model not found"}))
195+
196+
with pytest.raises(NotFoundError):
197+
self.client.run("invalid-model-ref", input={"prompt": "test prompt"})
198+
199+
@pytest.mark.respx(base_url=base_url)
200+
def test_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
201+
"""Test run with an invalid Cog version."""
202+
# Mock an error response for an invalid Cog version
203+
respx_mock.post("/predictions").mock(return_value=httpx.Response(400, json={"detail": "Invalid Cog version"}))
204+
205+
with pytest.raises(BadRequestError):
206+
self.client.run("model-with-invalid-cog", input={"prompt": "test prompt"})
207+
208+
@pytest.mark.respx(base_url=base_url)
209+
def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
210+
"""Test run with file output iterator."""
211+
# Mock URLs for file outputs
212+
file_urls = [
213+
"https://replicate.delivery/output1.png",
214+
"https://replicate.delivery/output2.png",
215+
"https://replicate.delivery/output3.png",
216+
]
217+
218+
# Initial response with processing status and no output
219+
respx_mock.post("/predictions").mock(
220+
return_value=httpx.Response(201, json=create_mock_prediction(status="processing", output=None))
221+
)
222+
223+
# First poll returns still processing
224+
respx_mock.get("/predictions/test_prediction_id").mock(
225+
return_value=httpx.Response(200, json=create_mock_prediction(status="processing", output=None))
226+
)
227+
228+
# Second poll returns success with file URLs
229+
respx_mock.get("/predictions/test_prediction_id").mock(
230+
return_value=httpx.Response(200, json=create_mock_prediction(output=file_urls))
231+
)
232+
233+
output = self.client.run("some-model-ref", input={"prompt": "generate file iterator"})
234+
235+
assert isinstance(output, list)
236+
assert len(output) == 3
237+
assert all(isinstance(item, FileOutput) for item in output)
238+
assert [item.url for item in output] == file_urls
239+
170240

171241
class TestAsyncRun:
172242
client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
173243

174244
@pytest.mark.respx(base_url=base_url)
175245
async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
176246
"""Test basic async model run functionality."""
177-
# Mock the prediction creation
178247
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
179248

180249
output: Any = await self.client.run("some-model-ref", input={"prompt": "test prompt"})
@@ -184,7 +253,6 @@ async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
184253
@pytest.mark.respx(base_url=base_url)
185254
async def test_async_run_with_wait_true(self, respx_mock: MockRouter) -> None:
186255
"""Test async run with wait=True parameter."""
187-
# Mock the prediction creation
188256
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
189257

190258
output: Any = await self.client.run("some-model-ref", wait=True, input={"prompt": "test prompt"})
@@ -194,7 +262,6 @@ async def test_async_run_with_wait_true(self, respx_mock: MockRouter) -> None:
194262
@pytest.mark.respx(base_url=base_url)
195263
async def test_async_run_with_wait_int(self, respx_mock: MockRouter) -> None:
196264
"""Test async run with wait as an integer value."""
197-
# Mock the prediction creation
198265
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
199266

200267
output: Any = await self.client.run("some-model-ref", wait=10, input={"prompt": "test prompt"})
@@ -299,3 +366,68 @@ async def test_async_run_with_prefer_conflict(self) -> None:
299366
"""Test async run with conflicting wait and prefer parameters."""
300367
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
301368
await self.client.run("some-model-ref", wait=True, prefer="nowait", input={"prompt": "test"})
369+
370+
@pytest.mark.respx(base_url=base_url)
371+
async def test_async_run_with_iterator(self, respx_mock: MockRouter) -> None:
372+
"""Test async run with an iterator output."""
373+
# Create a mock prediction with an iterator output
374+
output_iterator = ["chunk1", "chunk2", "chunk3"]
375+
respx_mock.post("/predictions").mock(
376+
return_value=httpx.Response(201, json=create_mock_prediction(output=output_iterator))
377+
)
378+
379+
output = await self.client.run("some-model-ref", input={"prompt": "generate iterator"})
380+
381+
assert isinstance(output, list)
382+
assert len(output) == 3
383+
assert output == output_iterator
384+
385+
@pytest.mark.respx(base_url=base_url)
386+
async def test_async_run_with_invalid_identifier(self, respx_mock: MockRouter) -> None:
387+
"""Test async run with an invalid model identifier."""
388+
# Mock a 404 response for an invalid model identifier
389+
respx_mock.post("/predictions").mock(return_value=httpx.Response(404, json={"detail": "Model not found"}))
390+
391+
with pytest.raises(NotFoundError):
392+
await self.client.run("invalid-model-ref", input={"prompt": "test prompt"})
393+
394+
@pytest.mark.respx(base_url=base_url)
395+
async def test_async_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
396+
"""Test async run with an invalid Cog version."""
397+
# Mock an error response for an invalid Cog version
398+
respx_mock.post("/predictions").mock(return_value=httpx.Response(400, json={"detail": "Invalid Cog version"}))
399+
400+
with pytest.raises(BadRequestError):
401+
await self.client.run("model-with-invalid-cog", input={"prompt": "test prompt"})
402+
403+
@pytest.mark.respx(base_url=base_url)
404+
async def test_async_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
405+
"""Test async run with file output iterator."""
406+
# Mock URLs for file outputs
407+
file_urls = [
408+
"https://replicate.delivery/output1.png",
409+
"https://replicate.delivery/output2.png",
410+
"https://replicate.delivery/output3.png",
411+
]
412+
413+
# Initial response with processing status and no output
414+
respx_mock.post("/predictions").mock(
415+
return_value=httpx.Response(201, json=create_mock_prediction(status="processing", output=None))
416+
)
417+
418+
# First poll returns still processing
419+
respx_mock.get("/predictions/test_prediction_id").mock(
420+
return_value=httpx.Response(200, json=create_mock_prediction(status="processing", output=None))
421+
)
422+
423+
# Second poll returns success with file URLs
424+
respx_mock.get("/predictions/test_prediction_id").mock(
425+
return_value=httpx.Response(200, json=create_mock_prediction(output=file_urls))
426+
)
427+
428+
output = await self.client.run("some-model-ref", input={"prompt": "generate file iterator"})
429+
430+
assert isinstance(output, list)
431+
assert len(output) == 3
432+
assert all(isinstance(item, AsyncFileOutput) for item in output)
433+
assert [item.url for item in output] == file_urls

0 commit comments

Comments
 (0)