Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/replicate/lib/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ def encode_json(
if file_encoding_strategy == "base64":
return base64_encode_file(obj)
else:
# todo: support files endpoint
# return client.files.create(obj).urls["get"]
raise NotImplementedError("File upload is not supported yet")
response = client.files.create(content=obj.read())
return response.urls.get
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand Down Expand Up @@ -91,9 +90,8 @@ async def async_encode_json(
# TODO: This should ideally use an async based file reader path.
return base64_encode_file(obj)
else:
# todo: support files endpoint
# return (await client.files.async_create(obj)).urls["get"]
raise NotImplementedError("File upload is not supported yet")
response = await client.files.create(content=obj.read())
return response.urls.get
if HAS_NUMPY:
if isinstance(obj, np.integer): # type: ignore
return int(obj)
Expand Down
56 changes: 56 additions & 0 deletions tests/lib/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from respx import MockRouter

from replicate import Replicate, AsyncReplicate
from replicate._compat import model_dump
from replicate.lib._files import FileOutput, AsyncFileOutput
from replicate._exceptions import ModelError, NotFoundError, BadRequestError
from replicate.lib._models import Model, Version, ModelVersionIdentifier
from replicate.types.file_create_response import URLs, Checksums, FileCreateResponse

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
bearer_token = "My Bearer Token"
Expand Down Expand Up @@ -89,6 +91,16 @@ class TestRun:

# Common model reference format that will work with the new SDK
model_ref = "owner/name:version"
file_create_response = FileCreateResponse(
id="test_file_id",
checksums=Checksums(sha256="test_sha256"),
content_type="application/octet-stream",
created_at=datetime.datetime.now(),
expires_at=datetime.datetime.now() + datetime.timedelta(days=1),
metadata={},
size=1234,
urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"),
)

@pytest.mark.respx(base_url=base_url)
def test_run_basic(self, respx_mock: MockRouter) -> None:
Expand Down Expand Up @@ -236,6 +248,23 @@ def test_run_with_base64_file(self, respx_mock: MockRouter) -> None:

assert output == "test output"

@pytest.mark.respx(base_url=base_url)
def test_run_with_file_upload(self, respx_mock: MockRouter) -> None:
"""Test run with base64 encoded file input."""
# Create a simple file-like object
file_obj = io.BytesIO(b"test content")

# Mock the prediction response
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
# Mock the file upload endpoint
respx_mock.post("/files").mock(
return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json"))
)

output: Any = self.client.run(self.model_ref, input={"file": file_obj})

assert output == "test output"

def test_run_with_prefer_conflict(self) -> None:
"""Test run with conflicting wait and prefer parameters."""
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
Expand Down Expand Up @@ -349,6 +378,16 @@ class TestAsyncRun:

# Common model reference format that will work with the new SDK
model_ref = "owner/name:version"
file_create_response = FileCreateResponse(
id="test_file_id",
checksums=Checksums(sha256="test_sha256"),
content_type="application/octet-stream",
created_at=datetime.datetime.now(),
expires_at=datetime.datetime.now() + datetime.timedelta(days=1),
metadata={},
size=1234,
urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"),
)

@pytest.mark.respx(base_url=base_url)
async def test_async_run_basic(self, respx_mock: MockRouter) -> None:
Expand Down Expand Up @@ -501,6 +540,23 @@ async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None:

assert output == "test output"

@pytest.mark.respx(base_url=base_url)
async def test_async_run_with_file_upload(self, respx_mock: MockRouter) -> None:
"""Test async run with base64 encoded file input."""
# Create a simple file-like object
file_obj = io.BytesIO(b"test content")

# Mock the prediction response
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
# Mock the file upload endpoint
respx_mock.post("/files").mock(
return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json"))
)

output: Any = await self.client.run(self.model_ref, input={"file": file_obj})

assert output == "test output"

async def test_async_run_with_prefer_conflict(self) -> None:
"""Test async run with conflicting wait and prefer parameters."""
with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):
Expand Down