Skip to content

Commit 999c38b

Browse files
authored
Add webhook_completed to prediction (#42)
Signed-off-by: Ben Firshman <[email protected]> Signed-off-by: Ben Firshman <[email protected]>
1 parent 7432e91 commit 999c38b

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

replicate/prediction.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,24 @@ def cancel(self):
5252
class PredictionCollection(Collection):
5353
model = Prediction
5454

55-
def create(self, version: Version, input: Dict[str, Any]) -> Prediction:
55+
def create(
56+
self,
57+
version: Version,
58+
input: Dict[str, Any],
59+
webhook_completed: Optional[str] = None,
60+
) -> Prediction:
5661
input = encode_json(input, upload_file=upload_file)
62+
body = {
63+
"version": version.id,
64+
"input": input,
65+
}
66+
if webhook_completed is not None:
67+
body["webhook_completed"] = webhook_completed
68+
5769
resp = self._client._request(
58-
"POST", "/v1/predictions", json={"version": version.id, "input": input}
70+
"POST",
71+
"/v1/predictions",
72+
json=body,
5973
)
6074
obj = resp.json()
6175
obj["version"] = version

tests/test_prediction.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@ def test_cancel():
1313
responses.post(
1414
"https://api.replicate.com/v1/predictions",
1515
match=[
16-
matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}})
16+
matchers.json_params_matcher(
17+
{
18+
"version": "v1",
19+
"input": {"text": "world"},
20+
"webhook_completed": "https://example.com/webhook",
21+
}
22+
),
1723
],
1824
json={
1925
"id": "p1",
@@ -33,7 +39,11 @@ def test_cancel():
3339
},
3440
)
3541

36-
prediction = client.predictions.create(version=version, input={"text": "world"})
42+
prediction = client.predictions.create(
43+
version=version,
44+
input={"text": "world"},
45+
webhook_completed="https://example.com/webhook",
46+
)
3747

3848
rsp = responses.post("https://api.replicate.com/v1/predictions/p1/cancel", json={})
3949
prediction.cancel()

0 commit comments

Comments
 (0)