|
10 | 10 | from respx import MockRouter |
11 | 11 |
|
12 | 12 | from replicate import Replicate, AsyncReplicate |
| 13 | +from replicate._compat import model_dump |
13 | 14 | from replicate.lib._files import FileOutput, AsyncFileOutput |
14 | 15 | from replicate._exceptions import ModelError, NotFoundError, BadRequestError |
15 | 16 | from replicate.lib._models import Model, Version, ModelVersionIdentifier |
| 17 | +from replicate.types.file_create_response import URLs, Checksums, FileCreateResponse |
16 | 18 |
|
17 | 19 | base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") |
18 | 20 | bearer_token = "My Bearer Token" |
@@ -89,6 +91,16 @@ class TestRun: |
89 | 91 |
|
90 | 92 | # Common model reference format that will work with the new SDK |
91 | 93 | model_ref = "owner/name:version" |
| 94 | + file_create_response = FileCreateResponse( |
| 95 | + id="test_file_id", |
| 96 | + checksums=Checksums(sha256="test_sha256"), |
| 97 | + content_type="application/octet-stream", |
| 98 | + created_at=datetime.datetime.now(), |
| 99 | + expires_at=datetime.datetime.now() + datetime.timedelta(days=1), |
| 100 | + metadata={}, |
| 101 | + size=1234, |
| 102 | + urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"), |
| 103 | + ) |
92 | 104 |
|
93 | 105 | @pytest.mark.respx(base_url=base_url) |
94 | 106 | def test_run_basic(self, respx_mock: MockRouter) -> None: |
@@ -236,6 +248,23 @@ def test_run_with_base64_file(self, respx_mock: MockRouter) -> None: |
236 | 248 |
|
237 | 249 | assert output == "test output" |
238 | 250 |
|
| 251 | + @pytest.mark.respx(base_url=base_url) |
| 252 | + def test_run_with_file_upload(self, respx_mock: MockRouter) -> None: |
| 253 | + """Test run with base64 encoded file input.""" |
| 254 | + # Create a simple file-like object |
| 255 | + file_obj = io.BytesIO(b"test content") |
| 256 | + |
| 257 | + # Mock the prediction response |
| 258 | + respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction())) |
| 259 | + # Mock the file upload endpoint |
| 260 | + respx_mock.post("/files").mock( |
| 261 | + return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json")) |
| 262 | + ) |
| 263 | + |
| 264 | + output: Any = self.client.run(self.model_ref, input={"file": file_obj}) |
| 265 | + |
| 266 | + assert output == "test output" |
| 267 | + |
239 | 268 | def test_run_with_prefer_conflict(self) -> None: |
240 | 269 | """Test run with conflicting wait and prefer parameters.""" |
241 | 270 | with pytest.raises(TypeError, match="cannot mix and match prefer and wait"): |
@@ -349,6 +378,16 @@ class TestAsyncRun: |
349 | 378 |
|
350 | 379 | # Common model reference format that will work with the new SDK |
351 | 380 | model_ref = "owner/name:version" |
| 381 | + file_create_response = FileCreateResponse( |
| 382 | + id="test_file_id", |
| 383 | + checksums=Checksums(sha256="test_sha256"), |
| 384 | + content_type="application/octet-stream", |
| 385 | + created_at=datetime.datetime.now(), |
| 386 | + expires_at=datetime.datetime.now() + datetime.timedelta(days=1), |
| 387 | + metadata={}, |
| 388 | + size=1234, |
| 389 | + urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"), |
| 390 | + ) |
352 | 391 |
|
353 | 392 | @pytest.mark.respx(base_url=base_url) |
354 | 393 | async def test_async_run_basic(self, respx_mock: MockRouter) -> None: |
@@ -501,6 +540,23 @@ async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None: |
501 | 540 |
|
502 | 541 | assert output == "test output" |
503 | 542 |
|
| 543 | + @pytest.mark.respx(base_url=base_url) |
| 544 | + async def test_async_run_with_file_upload(self, respx_mock: MockRouter) -> None: |
| 545 | + """Test async run with base64 encoded file input.""" |
| 546 | + # Create a simple file-like object |
| 547 | + file_obj = io.BytesIO(b"test content") |
| 548 | + |
| 549 | + # Mock the prediction response |
| 550 | + respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction())) |
| 551 | + # Mock the file upload endpoint |
| 552 | + respx_mock.post("/files").mock( |
| 553 | + return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json")) |
| 554 | + ) |
| 555 | + |
| 556 | + output: Any = await self.client.run(self.model_ref, input={"file": file_obj}) |
| 557 | + |
| 558 | + assert output == "test output" |
| 559 | + |
504 | 560 | async def test_async_run_with_prefer_conflict(self) -> None: |
505 | 561 | """Test async run with conflicting wait and prefer parameters.""" |
506 | 562 | with pytest.raises(TypeError, match="cannot mix and match prefer and wait"): |
|
0 commit comments