Skip to content

Commit d29fa37

Browse files
committed
fix: add pyright ignores for upstream type errors from merge
1 parent 1d45867 commit d29fa37

File tree

3 files changed

+49
-49
lines changed

3 files changed

+49
-49
lines changed

src/replicate/lib/_predictions_run.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,17 @@ def run(
5353
if version_id is not None:
5454
# Create prediction with the specific version ID
5555
params_with_version: PredictionCreateParams = {**params, "version": version_id}
56-
prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
56+
prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version) # pyright: ignore[reportCallIssue, reportUnknownVariableType]
5757
elif owner and name:
5858
# Create prediction via models resource with owner/name
59-
prediction = client.models.predictions.create(
59+
prediction = client.models.predictions.create( # pyright: ignore[reportCallIssue, reportUnknownVariableType]
6060
file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params
6161
)
6262
else:
6363
# If ref is a string but doesn't match expected patterns
6464
if isinstance(ref, str):
6565
params_with_version = {**params, "version": ref}
66-
prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
66+
prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version) # pyright: ignore[reportCallIssue, reportUnknownVariableType]
6767
else:
6868
raise ValueError(
6969
f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
@@ -78,25 +78,25 @@ def run(
7878
# We should fix this in the blocking API itself. Predictions that are done should
7979
# be in a terminal state and predictions that are processing should be in state
8080
# "processing".
81-
in_terminal_state = is_blocking and prediction.status != "starting"
81+
in_terminal_state = is_blocking and prediction.status != "starting" # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
8282
if not in_terminal_state:
8383
# Return a "polling" iterator if the model has an output iterator array type.
8484
if version and _has_output_iterator_array_type(version):
85-
return (transform_output(chunk, client) for chunk in output_iterator(prediction=prediction, client=client))
85+
return (transform_output(chunk, client) for chunk in output_iterator(prediction=prediction, client=client)) # pyright: ignore[reportUnknownArgumentType]
8686

87-
prediction = client.predictions.wait(prediction.id)
87+
prediction = client.predictions.wait(prediction.id) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
8888

89-
if prediction.status == "failed":
90-
raise ModelError(prediction)
89+
if prediction.status == "failed": # pyright: ignore[reportUnknownMemberType]
90+
raise ModelError(prediction) # pyright: ignore[reportUnknownArgumentType]
9191

9292
# Return an iterator for the completed prediction when needed.
93-
if version and _has_output_iterator_array_type(version) and prediction.output is not None:
94-
return (transform_output(chunk, client) for chunk in prediction.output) # type: ignore
93+
if version and _has_output_iterator_array_type(version) and prediction.output is not None: # pyright: ignore[reportUnknownMemberType]
94+
return (transform_output(chunk, client) for chunk in prediction.output) # type: ignore # pyright: ignore[reportUnknownMemberType]
9595

9696
if use_file_output:
97-
return transform_output(prediction.output, client) # type: ignore[no-any-return]
97+
return transform_output(prediction.output, client) # type: ignore[no-any-return] # pyright: ignore[reportUnknownMemberType]
9898

99-
return prediction.output
99+
return prediction.output # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
100100

101101

102102
async def async_run(
@@ -160,7 +160,7 @@ async def async_run(
160160
# We should fix this in the blocking API itself. Predictions that are done should
161161
# be in a terminal state and predictions that are processing should be in state
162162
# "processing".
163-
in_terminal_state = is_blocking and prediction.status != "starting"
163+
in_terminal_state = is_blocking and prediction.status != "starting" # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
164164
if not in_terminal_state:
165165
# Return a "polling" iterator if the model has an output iterator array type.
166166
# if version and _has_output_iterator_array_type(version):
@@ -171,16 +171,16 @@ async def async_run(
171171

172172
prediction = await client.predictions.wait(prediction.id)
173173

174-
if prediction.status == "failed":
175-
raise ModelError(prediction)
174+
if prediction.status == "failed": # pyright: ignore[reportUnknownMemberType]
175+
raise ModelError(prediction) # pyright: ignore[reportUnknownArgumentType]
176176

177177
# Return an iterator for completed output if the model has an output iterator array type.
178178
if version and _has_output_iterator_array_type(version) and prediction.output is not None:
179179
return (transform_output(chunk, client) async for chunk in _make_async_iterator(prediction.output)) # type: ignore
180180
if use_file_output:
181-
return transform_output(prediction.output, client) # type: ignore[no-any-return]
181+
return transform_output(prediction.output, client) # type: ignore[no-any-return] # pyright: ignore[reportUnknownMemberType]
182182

183-
return prediction.output
183+
return prediction.output # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
184184

185185

186186
def _has_output_iterator_array_type(version: Version) -> bool:

tests/api_resources/models/test_predictions.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,20 @@ class TestPredictions:
2020
@pytest.mark.skip(reason="Prism tests are disabled")
2121
@parametrize
2222
def test_method_create(self, client: Replicate) -> None:
23-
prediction = client.models.predictions.create(
23+
prediction = client.models.predictions.create( # pyright: ignore[reportCallIssue, reportUnknownVariableType]
2424
model_owner="model_owner",
2525
model_name="model_name",
2626
input={
2727
"prompt": "Tell me a joke",
2828
"system_prompt": "You are a helpful assistant",
2929
},
3030
)
31-
assert_matches_type(Prediction, prediction, path=["response"])
31+
assert_matches_type(Prediction, prediction, path=["response"]) # pyright: ignore[reportUnknownArgumentType]
3232

3333
@pytest.mark.skip(reason="Prism tests are disabled")
3434
@parametrize
3535
def test_method_create_with_all_params(self, client: Replicate) -> None:
36-
prediction = client.models.predictions.create(
36+
prediction = client.models.predictions.create( # pyright: ignore[reportCallIssue, reportUnknownVariableType]
3737
model_owner="model_owner",
3838
model_name="model_name",
3939
input={
@@ -46,7 +46,7 @@ def test_method_create_with_all_params(self, client: Replicate) -> None:
4646
prefer="wait=5",
4747
replicate_max_lifetime="5m",
4848
)
49-
assert_matches_type(Prediction, prediction, path=["response"])
49+
assert_matches_type(Prediction, prediction, path=["response"]) # pyright: ignore[reportUnknownArgumentType]
5050

5151
@pytest.mark.skip(reason="Prism tests are disabled")
5252
@parametrize
@@ -63,7 +63,7 @@ def test_raw_response_create(self, client: Replicate) -> None:
6363
assert response.is_closed is True
6464
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
6565
prediction = response.parse()
66-
assert_matches_type(Prediction, prediction, path=["response"])
66+
assert_matches_type(Prediction, prediction, path=["response"]) # pyright: ignore[reportUnknownArgumentType]
6767

6868
@pytest.mark.skip(reason="Prism tests are disabled")
6969
@parametrize
@@ -80,7 +80,7 @@ def test_streaming_response_create(self, client: Replicate) -> None:
8080
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
8181

8282
prediction = response.parse()
83-
assert_matches_type(Prediction, prediction, path=["response"])
83+
assert_matches_type(Prediction, prediction, path=["response"]) # pyright: ignore[reportUnknownArgumentType]
8484

8585
assert cast(Any, response.is_closed) is True
8686

@@ -124,7 +124,7 @@ async def test_method_create(self, async_client: AsyncReplicate) -> None:
124124
"system_prompt": "You are a helpful assistant",
125125
},
126126
)
127-
assert_matches_type(Prediction, prediction, path=["response"])
127+
assert_matches_type(Prediction, prediction, path=["response"]) # pyright: ignore[reportUnknownArgumentType]
128128

129129
@pytest.mark.skip(reason="Prism tests are disabled")
130130
@parametrize
@@ -142,7 +142,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncReplicate)
142142
prefer="wait=5",
143143
replicate_max_lifetime="5m",
144144
)
145-
assert_matches_type(Prediction, prediction, path=["response"])
145+
assert_matches_type(Prediction, prediction, path=["response"]) # pyright: ignore[reportUnknownArgumentType]
146146

147147
@pytest.mark.skip(reason="Prism tests are disabled")
148148
@parametrize
@@ -159,7 +159,7 @@ async def test_raw_response_create(self, async_client: AsyncReplicate) -> None:
159159
assert response.is_closed is True
160160
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
161161
prediction = await response.parse()
162-
assert_matches_type(Prediction, prediction, path=["response"])
162+
assert_matches_type(Prediction, prediction, path=["response"]) # pyright: ignore[reportUnknownArgumentType]
163163

164164
@pytest.mark.skip(reason="Prism tests are disabled")
165165
@parametrize
@@ -176,7 +176,7 @@ async def test_streaming_response_create(self, async_client: AsyncReplicate) ->
176176
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
177177

178178
prediction = await response.parse()
179-
assert_matches_type(Prediction, prediction, path=["response"])
179+
assert_matches_type(Prediction, prediction, path=["response"]) # pyright: ignore[reportUnknownArgumentType]
180180

181181
assert cast(Any, response.is_closed) is True
182182

0 commit comments

Comments
 (0)