Skip to content

Commit 0ec2897

Browse files
committed
fix files uploading
1 parent aef2230 commit 0ec2897

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

src/replicate/lib/_files.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ def encode_json(
5151
if file_encoding_strategy == "base64":
5252
return base64_encode_file(obj)
5353
else:
54-
# todo: support files endpoint
55-
# return client.files.create(obj).urls["get"]
56-
raise NotImplementedError("File upload is not supported yet")
54+
response = client.files.create(content=obj.read())
55+
return response.urls.get
5756
if HAS_NUMPY:
5857
if isinstance(obj, np.integer): # type: ignore
5958
return int(obj)
@@ -91,9 +90,8 @@ async def async_encode_json(
9190
# TODO: This should ideally use an async based file reader path.
9291
return base64_encode_file(obj)
9392
else:
94-
# todo: support files endpoint
95-
# return (await client.files.async_create(obj)).urls["get"]
96-
raise NotImplementedError("File upload is not supported yet")
93+
response = await client.files.create(content=obj.read())
94+
return response.urls.get
9795
if HAS_NUMPY:
9896
if isinstance(obj, np.integer): # type: ignore
9997
return int(obj)

src/replicate/lib/_predictions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,6 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]:
215215
if prediction.status == "failed":
216216
raise ModelError(prediction=prediction)
217217

218-
output: list[Any] = prediction.output or [] # type: ignore[union-attr]
218+
output = prediction.output or [] # type: ignore
219219
new_output = output[len(previous_output) :]
220220
yield from new_output

src/replicate/resources/files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def create(
5858
self,
5959
*,
6060
content: FileTypes,
61-
filename: str,
61+
filename: str | NotGiven = NOT_GIVEN,
6262
metadata: object | NotGiven = NOT_GIVEN,
6363
type: str | NotGiven = NOT_GIVEN,
6464
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -336,7 +336,7 @@ async def create(
336336
self,
337337
*,
338338
content: FileTypes,
339-
filename: str,
339+
filename: str | NotGiven = NOT_GIVEN,
340340
metadata: object | NotGiven = NOT_GIVEN,
341341
type: str | NotGiven = NOT_GIVEN,
342342
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.

0 commit comments

Comments
 (0)