Skip to content

Commit c6f76bb

Browse files
committed
Add run() tests
1 parent ffdaeda commit c6f76bb

File tree

3 files changed

+301
-1
lines changed

3 files changed

+301
-1
lines changed

src/replicate/_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from . import _exceptions
1414
from ._qs import Querystring
15-
from .types import PredictionCreateParams
1615
from ._types import (
1716
NOT_GIVEN,
1817
Omit,

tests/lib/__init__.py

Whitespace-only changes.

tests/lib/test_run.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
from __future__ import annotations
2+
3+
import io
4+
import os
5+
from typing import Any, Dict, Optional
6+
7+
import httpx
8+
import pytest
9+
from respx import MockRouter
10+
11+
from replicate import ReplicateClient, AsyncReplicateClient
12+
from replicate.lib._files import FileOutput, AsyncFileOutput
13+
from replicate._exceptions import ModelError
14+
15+
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
16+
bearer_token = "My Bearer Token"
17+
18+
19+
# Mock prediction data for testing
20+
def create_mock_prediction(
21+
status: str = "succeeded", output: Any = "test output", error: Optional[str] = None
22+
) -> Dict[str, Any]:
23+
return {
24+
"id": "test_prediction_id",
25+
"version": "test_version",
26+
"status": status,
27+
"input": {"prompt": "test prompt"},
28+
"output": output,
29+
"error": error,
30+
"created_at": "2023-01-01T00:00:00Z",
31+
"started_at": "2023-01-01T00:00:01Z",
32+
"completed_at": "2023-01-01T00:00:02Z" if status in ["succeeded", "failed"] else None,
33+
"urls": {
34+
"get": "https://api.replicate.com/v1/predictions/test_prediction_id",
35+
"cancel": "https://api.replicate.com/v1/predictions/test_prediction_id/cancel",
36+
},
37+
"model": "test-model",
38+
"data_removed": False,
39+
}
40+
41+
42+
class TestRun:
43+
client = ReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
44+
45+
@pytest.mark.respx(base_url=base_url)
46+
def test_run_basic(self, respx_mock: MockRouter) -> None:
47+
"""Test basic model run functionality."""
48+
# Mock the prediction creation
49+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
50+
51+
output: Any = self.client.run("some-model-ref", input={"prompt": "test prompt"})
52+
53+
assert output == "test output"
54+
55+
@pytest.mark.respx(base_url=base_url)
56+
def test_run_with_wait_true(self, respx_mock: MockRouter) -> None:
57+
"""Test run with wait=True parameter."""
58+
# Mock the prediction creation
59+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
60+
61+
output: Any = self.client.run("some-model-ref", wait=True, input={"prompt": "test prompt"})
62+
63+
assert output == "test output"
64+
65+
@pytest.mark.respx(base_url=base_url)
66+
def test_run_with_wait_int(self, respx_mock: MockRouter) -> None:
67+
"""Test run with wait as an integer value."""
68+
# Mock the prediction creation
69+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
70+
71+
output: Any = self.client.run("some-model-ref", wait=10, input={"prompt": "test prompt"})
72+
73+
assert output == "test output"
74+
75+
@pytest.mark.respx(base_url=base_url)
76+
def test_run_without_wait(self, respx_mock: MockRouter) -> None:
77+
"""Test run with wait=False parameter."""
78+
# Initial prediction state is "processing"
79+
respx_mock.post("/predictions").mock(
80+
return_value=httpx.Response(201, json=create_mock_prediction(status="processing"))
81+
)
82+
83+
# When we wait for it, it becomes "succeeded"
84+
respx_mock.get("/predictions/test_prediction_id").mock(
85+
return_value=httpx.Response(200, json=create_mock_prediction(status="succeeded"))
86+
)
87+
88+
output: Any = self.client.run("some-model-ref", wait=False, input={"prompt": "test prompt"})
89+
90+
assert output == "test output"
91+
92+
@pytest.mark.respx(base_url=base_url, assert_all_mocked=False)
93+
def test_run_with_file_output(self, respx_mock: MockRouter) -> None:
94+
"""Test run with file output."""
95+
# Mock prediction with file URL output
96+
file_url = "https://replicate.delivery/output.png"
97+
respx_mock.post("/predictions").mock(
98+
return_value=httpx.Response(201, json=create_mock_prediction(output=file_url))
99+
)
100+
101+
output: Any = self.client.run("some-model-ref", input={"prompt": "generate image"})
102+
103+
assert isinstance(output, FileOutput)
104+
assert output.url == file_url
105+
106+
@pytest.mark.respx(base_url=base_url)
107+
def test_run_with_file_list_output(self, respx_mock: MockRouter) -> None:
108+
"""Test run with list of file outputs."""
109+
# Create a mock prediction response with a list of file URLs
110+
file_urls = ["https://replicate.delivery/output1.png", "https://replicate.delivery/output2.png"]
111+
mock_prediction = create_mock_prediction()
112+
mock_prediction["output"] = file_urls
113+
114+
# Mock the endpoint
115+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=mock_prediction))
116+
117+
output: list[FileOutput] = self.client.run("some-model-ref", input={"prompt": "generate multiple images"}) # type: ignore
118+
119+
assert isinstance(output, list)
120+
assert len(output) == 2
121+
assert all(isinstance(item, FileOutput) for item in output)
122+
123+
@pytest.mark.respx(base_url=base_url)
124+
def test_run_with_dict_file_output(self, respx_mock: MockRouter) -> None:
125+
"""Test run with dictionary of file outputs."""
126+
# Mock prediction with dict of file URLs
127+
file_urls = {
128+
"image1": "https://replicate.delivery/output1.png",
129+
"image2": "https://replicate.delivery/output2.png",
130+
}
131+
respx_mock.post("/predictions").mock(
132+
return_value=httpx.Response(201, json=create_mock_prediction(output=file_urls))
133+
)
134+
135+
output: Dict[str, FileOutput] = self.client.run("some-model-ref", input={"prompt": "structured output"}) # type: ignore
136+
137+
assert isinstance(output, dict)
138+
assert len(output) == 2
139+
assert all(isinstance(item, FileOutput) for item in output.values())
140+
141+
@pytest.mark.respx(base_url=base_url)
142+
def test_run_with_error(self, respx_mock: MockRouter) -> None:
143+
"""Test run with model error."""
144+
# Mock prediction with error
145+
respx_mock.post("/predictions").mock(
146+
return_value=httpx.Response(201, json=create_mock_prediction(status="failed", error="Model error occurred"))
147+
)
148+
149+
with pytest.raises(ModelError):
150+
self.client.run("error-model-ref", input={"prompt": "trigger error"})
151+
152+
@pytest.mark.respx(base_url=base_url)
153+
def test_run_with_base64_file(self, respx_mock: MockRouter) -> None:
154+
"""Test run with base64 encoded file input."""
155+
# Create a simple file-like object
156+
file_obj = io.BytesIO(b"test content")
157+
158+
# Mock the prediction response
159+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
160+
161+
output: Any = self.client.run("some-model-ref", input={"file": file_obj})
162+
163+
assert output == "test output"
164+
165+
def test_run_with_prefer_conflict(self) -> None:
166+
"""Test run with conflicting wait and prefer parameters."""
167+
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
168+
self.client.run("some-model-ref", wait=True, prefer="nowait", input={"prompt": "test"})
169+
170+
171+
class TestAsyncRun:
172+
client = AsyncReplicateClient(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
173+
174+
@pytest.mark.respx(base_url=base_url)
175+
async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
176+
"""Test basic async model run functionality."""
177+
# Mock the prediction creation
178+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
179+
180+
output: Any = await self.client.run("some-model-ref", input={"prompt": "test prompt"})
181+
182+
assert output == "test output"
183+
184+
@pytest.mark.respx(base_url=base_url)
185+
async def test_async_run_with_wait_true(self, respx_mock: MockRouter) -> None:
186+
"""Test async run with wait=True parameter."""
187+
# Mock the prediction creation
188+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
189+
190+
output: Any = await self.client.run("some-model-ref", wait=True, input={"prompt": "test prompt"})
191+
192+
assert output == "test output"
193+
194+
@pytest.mark.respx(base_url=base_url)
195+
async def test_async_run_with_wait_int(self, respx_mock: MockRouter) -> None:
196+
"""Test async run with wait as an integer value."""
197+
# Mock the prediction creation
198+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
199+
200+
output: Any = await self.client.run("some-model-ref", wait=10, input={"prompt": "test prompt"})
201+
202+
assert output == "test output"
203+
204+
@pytest.mark.respx(base_url=base_url)
205+
async def test_async_run_without_wait(self, respx_mock: MockRouter) -> None:
206+
"""Test async run with wait=False parameter."""
207+
# Initial prediction state is "processing"
208+
respx_mock.post("/predictions").mock(
209+
return_value=httpx.Response(201, json=create_mock_prediction(status="processing"))
210+
)
211+
212+
# When we wait for it, it becomes "succeeded"
213+
respx_mock.get("/predictions/test_prediction_id").mock(
214+
return_value=httpx.Response(200, json=create_mock_prediction(status="succeeded"))
215+
)
216+
217+
output: Any = await self.client.run("some-model-ref", wait=False, input={"prompt": "test prompt"})
218+
219+
assert output == "test output"
220+
221+
@pytest.mark.respx(base_url=base_url, assert_all_mocked=False)
222+
async def test_async_run_with_file_output(self, respx_mock: MockRouter) -> None:
223+
"""Test async run with file output."""
224+
# Mock prediction with file URL output
225+
file_url = "https://replicate.delivery/output.png"
226+
respx_mock.post("/predictions").mock(
227+
return_value=httpx.Response(201, json=create_mock_prediction(output=file_url))
228+
)
229+
230+
output: Any = await self.client.run("some-model-ref", input={"prompt": "generate image"})
231+
232+
assert isinstance(output, AsyncFileOutput)
233+
assert output.url == file_url
234+
235+
@pytest.mark.respx(base_url=base_url)
236+
async def test_async_run_with_file_list_output(self, respx_mock: MockRouter) -> None:
237+
"""Test async run with list of file outputs."""
238+
# Create a mock prediction response with a list of file URLs
239+
file_urls = ["https://replicate.delivery/output1.png", "https://replicate.delivery/output2.png"]
240+
mock_prediction = create_mock_prediction()
241+
mock_prediction["output"] = file_urls
242+
243+
# Mock the endpoint
244+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=mock_prediction))
245+
246+
output: list[AsyncFileOutput] = await self.client.run(
247+
"some-model-ref", input={"prompt": "generate multiple images"}
248+
) # type: ignore
249+
250+
assert isinstance(output, list)
251+
assert len(output) == 2
252+
assert all(isinstance(item, AsyncFileOutput) for item in output)
253+
254+
@pytest.mark.respx(base_url=base_url)
255+
async def test_async_run_with_dict_file_output(self, respx_mock: MockRouter) -> None:
256+
"""Test async run with dictionary of file outputs."""
257+
# Mock prediction with dict of file URLs
258+
file_urls = {
259+
"image1": "https://replicate.delivery/output1.png",
260+
"image2": "https://replicate.delivery/output2.png",
261+
}
262+
respx_mock.post("/predictions").mock(
263+
return_value=httpx.Response(201, json=create_mock_prediction(output=file_urls))
264+
)
265+
266+
output: Dict[str, AsyncFileOutput] = await self.client.run(
267+
"some-model-ref", input={"prompt": "structured output"}
268+
) # type: ignore
269+
270+
assert isinstance(output, dict)
271+
assert len(output) == 2
272+
assert all(isinstance(item, AsyncFileOutput) for item in output.values())
273+
274+
@pytest.mark.respx(base_url=base_url)
275+
async def test_async_run_with_error(self, respx_mock: MockRouter) -> None:
276+
"""Test async run with model error."""
277+
# Mock prediction with error
278+
respx_mock.post("/predictions").mock(
279+
return_value=httpx.Response(201, json=create_mock_prediction(status="failed", error="Model error occurred"))
280+
)
281+
282+
with pytest.raises(ModelError):
283+
await self.client.run("error-model-ref", input={"prompt": "trigger error"})
284+
285+
@pytest.mark.respx(base_url=base_url)
286+
async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None:
287+
"""Test async run with base64 encoded file input."""
288+
# Create a simple file-like object
289+
file_obj = io.BytesIO(b"test content")
290+
291+
# Mock the prediction response
292+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
293+
294+
output: Any = await self.client.run("some-model-ref", input={"file": file_obj})
295+
296+
assert output == "test output"
297+
298+
async def test_async_run_with_prefer_conflict(self) -> None:
299+
"""Test async run with conflicting wait and prefer parameters."""
300+
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
301+
await self.client.run("some-model-ref", wait=True, prefer="nowait", input={"prompt": "test"})

0 commit comments

Comments
 (0)