Skip to content

Commit 2143731

Browse files
committed
add test for file upload
1 parent 0ec2897 commit 2143731

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

tests/lib/test_run.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from respx import MockRouter
1111

1212
from replicate import Replicate, AsyncReplicate
13+
from replicate._compat import model_dump
1314
from replicate.lib._files import FileOutput, AsyncFileOutput
1415
from replicate._exceptions import ModelError, NotFoundError, BadRequestError
1516
from replicate.lib._models import Model, Version, ModelVersionIdentifier
17+
from replicate.types.file_create_response import URLs, Checksums, FileCreateResponse
1618

1719
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
1820
bearer_token = "My Bearer Token"
@@ -89,6 +91,16 @@ class TestRun:
8991

9092
# Common model reference format that will work with the new SDK
9193
model_ref = "owner/name:version"
94+
file_create_response = FileCreateResponse(
95+
id="test_file_id",
96+
checksums=Checksums(sha256="test_sha256"),
97+
content_type="application/octet-stream",
98+
created_at=datetime.datetime.now(),
99+
expires_at=datetime.datetime.now() + datetime.timedelta(days=1),
100+
metadata={},
101+
size=1234,
102+
urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"),
103+
)
92104

93105
@pytest.mark.respx(base_url=base_url)
94106
def test_run_basic(self, respx_mock: MockRouter) -> None:
@@ -236,6 +248,23 @@ def test_run_with_base64_file(self, respx_mock: MockRouter) -> None:
236248

237249
assert output == "test output"
238250

251+
@pytest.mark.respx(base_url=base_url)
252+
def test_run_with_file_upload(self, respx_mock: MockRouter) -> None:
253+
"""Test run with base64 encoded file input."""
254+
# Create a simple file-like object
255+
file_obj = io.BytesIO(b"test content")
256+
257+
# Mock the prediction response
258+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
259+
# Mock the file upload endpoint
260+
respx_mock.post("/files").mock(
261+
return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json"))
262+
)
263+
264+
output: Any = self.client.run(self.model_ref, input={"file": file_obj})
265+
266+
assert output == "test output"
267+
239268
def test_run_with_prefer_conflict(self) -> None:
240269
"""Test run with conflicting wait and prefer parameters."""
241270
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
@@ -349,6 +378,16 @@ class TestAsyncRun:
349378

350379
# Common model reference format that will work with the new SDK
351380
model_ref = "owner/name:version"
381+
file_create_response = FileCreateResponse(
382+
id="test_file_id",
383+
checksums=Checksums(sha256="test_sha256"),
384+
content_type="application/octet-stream",
385+
created_at=datetime.datetime.now(),
386+
expires_at=datetime.datetime.now() + datetime.timedelta(days=1),
387+
metadata={},
388+
size=1234,
389+
urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"),
390+
)
352391

353392
@pytest.mark.respx(base_url=base_url)
354393
async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
@@ -501,6 +540,23 @@ async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None:
501540

502541
assert output == "test output"
503542

543+
@pytest.mark.respx(base_url=base_url)
544+
async def test_async_run_with_file_upload(self, respx_mock: MockRouter) -> None:
545+
"""Test async run with base64 encoded file input."""
546+
# Create a simple file-like object
547+
file_obj = io.BytesIO(b"test content")
548+
549+
# Mock the prediction response
550+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
551+
# Mock the file upload endpoint
552+
respx_mock.post("/files").mock(
553+
return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json"))
554+
)
555+
556+
output: Any = await self.client.run(self.model_ref, input={"file": file_obj})
557+
558+
assert output == "test output"
559+
504560
async def test_async_run_with_prefer_conflict(self) -> None:
505561
"""Test async run with conflicting wait and prefer parameters."""
506562
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):

0 commit comments

Comments
 (0)