Skip to content

Commit f080e81

Browse files
committed
Support list output
The counterpart to replicate/cog#655 Backwards compatibility functions have been lifted from Replicate. Maybe we should make Cog the source of truth for these... Signed-off-by: Ben Firshman <[email protected]>
1 parent d63c6af commit f080e81

File tree

4 files changed

+164
-140
lines changed

4 files changed

+164
-140
lines changed

replicate/schema.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from packaging import version
2+
3+
# TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth?
4+
5+
6+
def version_has_no_array_type(cog_version):
7+
"""Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward"""
8+
return version.parse(cog_version) < version.parse("0.3.9")
9+
10+
11+
def make_schema_backwards_compatible(schema, version):
12+
"""A place to add backwards compatibility logic for our openapi schema"""
13+
# If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type
14+
if version_has_no_array_type(version):
15+
output = schema["components"]["schemas"]["Output"]
16+
if output.get("type") == "array":
17+
output["x-cog-array-type"] = "iterator"
18+
return schema

replicate/version.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from replicate.base_model import BaseModel
55
from replicate.collection import Collection
66
from replicate.exceptions import ModelError
7+
from replicate.schema import make_schema_backwards_compatible
78

89

910
class Version(BaseModel):
@@ -17,9 +18,11 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]:
1718
prediction = self._client.predictions.create(version=self, input=kwargs)
1819
# Return an iterator of the output
1920
# FIXME: might just be a list, not an iterator. I wonder if we should differentiate?
21+
schema = self.get_transformed_schema()
22+
output = schema["components"]["schemas"]["Output"]
2023
if (
21-
self.openapi_schema["components"]["schemas"]["Output"].get("type")
22-
== "array"
24+
output.get("type") == "array"
25+
and output.get("x-cog-array-type") == "iterator"
2326
):
2427
return prediction.output_iterator()
2528

@@ -28,6 +31,11 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]:
2831
raise ModelError(prediction.error)
2932
return prediction.output
3033

34+
def get_transformed_schema(self):
35+
schema = self.openapi_schema
36+
schema = make_schema_backwards_compatible(schema, self.cog_version)
37+
return schema
38+
3139

3240
class VersionCollection(Collection):
3341
model = Version

tests/factories.py

Lines changed: 30 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ def create_client():
99
return client
1010

1111

12-
def create_version(client=None, openapi_schema=None):
12+
def create_version(client=None, openapi_schema=None, cog_version="0.3.0"):
1313
if client is None:
1414
client = create_client()
1515
version = Version(
1616
id="v1",
1717
created_at=datetime.datetime.now(),
18-
cog_version="0.3.0",
18+
cog_version=cog_version,
1919
openapi_schema=openapi_schema
2020
or {
2121
"info": {"title": "Cog", "version": "0.1.0"},
@@ -156,138 +156,31 @@ def create_version(client=None, openapi_schema=None):
156156

157157

158158
def create_version_with_iterator_output():
159-
return create_version(
160-
openapi_schema={
161-
"info": {"title": "Cog", "version": "0.1.0"},
162-
"paths": {
163-
"/": {
164-
"get": {
165-
"summary": "Root",
166-
"responses": {
167-
"200": {
168-
"content": {"application/json": {"schema": {}}},
169-
"description": "Successful Response",
170-
}
171-
},
172-
"operationId": "root__get",
173-
}
174-
},
175-
"/predictions": {
176-
"post": {
177-
"summary": "Predict",
178-
"responses": {
179-
"200": {
180-
"content": {
181-
"application/json": {
182-
"schema": {
183-
"$ref": "#/components/schemas/Response"
184-
}
185-
}
186-
},
187-
"description": "Successful Response",
188-
},
189-
"422": {
190-
"content": {
191-
"application/json": {
192-
"schema": {
193-
"$ref": "#/components/schemas/HTTPValidationError"
194-
}
195-
}
196-
},
197-
"description": "Validation Error",
198-
},
199-
},
200-
"description": "Run a single prediction on the model.",
201-
"operationId": "predict_predictions_post",
202-
"requestBody": {
203-
"content": {
204-
"application/json": {
205-
"schema": {"$ref": "#/components/schemas/Request"}
206-
}
207-
}
208-
},
209-
}
210-
},
211-
},
212-
"openapi": "3.0.2",
213-
"components": {
214-
"schemas": {
215-
"Input": {
216-
"type": "object",
217-
"title": "Input",
218-
"properties": {
219-
"prompts": {
220-
"type": "string",
221-
"title": "Prompts",
222-
"default": "Cairo skyline at sunset.",
223-
"x-order": 0,
224-
"description": "text prompt",
225-
},
226-
},
227-
},
228-
"Output": {
229-
"type": "array",
230-
"items": {"type": "string"},
231-
"title": "Output",
232-
},
233-
"Status": {
234-
"enum": ["processing", "success", "failed"],
235-
"type": "string",
236-
"title": "Status",
237-
"description": "An enumeration.",
238-
},
239-
"Request": {
240-
"type": "object",
241-
"title": "Request",
242-
"properties": {
243-
"input": {"$ref": "#/components/schemas/Input"},
244-
"output_file_prefix": {
245-
"type": "string",
246-
"title": "Output File Prefix",
247-
},
248-
},
249-
},
250-
"Response": {
251-
"type": "object",
252-
"title": "Response",
253-
"required": ["status"],
254-
"properties": {
255-
"error": {"type": "string", "title": "Error"},
256-
"output": {"$ref": "#/components/schemas/Output"},
257-
"status": {"$ref": "#/components/schemas/Status"},
258-
},
259-
"description": "The status of a prediction.",
260-
},
261-
"ValidationError": {
262-
"type": "object",
263-
"title": "ValidationError",
264-
"required": ["loc", "msg", "type"],
265-
"properties": {
266-
"loc": {
267-
"type": "array",
268-
"items": {
269-
"anyOf": [{"type": "string"}, {"type": "integer"}]
270-
},
271-
"title": "Location",
272-
},
273-
"msg": {"type": "string", "title": "Message"},
274-
"type": {"type": "string", "title": "Error Type"},
275-
},
276-
},
277-
"HTTPValidationError": {
278-
"type": "object",
279-
"title": "HTTPValidationError",
280-
"properties": {
281-
"detail": {
282-
"type": "array",
283-
"items": {
284-
"$ref": "#/components/schemas/ValidationError"
285-
},
286-
"title": "Detail",
287-
}
288-
},
289-
},
290-
}
291-
},
292-
}
293-
)
159+
version = create_version(cog_version="0.3.9")
160+
version.openapi_schema["components"]["schemas"]["Output"] = {
161+
"type": "array",
162+
"items": {"type": "string"},
163+
"title": "Output",
164+
"x-cog-array-type": "iterator",
165+
}
166+
return version
167+
168+
169+
def create_version_with_list_output():
170+
version = create_version(cog_version="0.3.9")
171+
version.openapi_schema["components"]["schemas"]["Output"] = {
172+
"type": "array",
173+
"items": {"type": "string"},
174+
"title": "Output",
175+
}
176+
return version
177+
178+
179+
def create_version_with_iterator_output_backwards_compatibility_0_3_8():
180+
version = create_version(cog_version="0.3.8")
181+
version.openapi_schema["components"]["schemas"]["Output"] = {
182+
"type": "array",
183+
"items": {"type": "string"},
184+
"title": "Output",
185+
}
186+
return version

tests/test_version.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
from replicate.exceptions import ModelError
66
from responses import matchers
77

8-
from .factories import create_version, create_version_with_iterator_output
8+
from .factories import (
9+
create_version,
10+
create_version_with_iterator_output,
11+
create_version_with_iterator_output_backwards_compatibility_0_3_8,
12+
create_version_with_list_output,
13+
)
914

1015

1116
@responses.activate
@@ -107,6 +112,106 @@ def test_predict_with_iterator():
107112
assert list(output) == ["hello world"]
108113

109114

115+
@responses.activate
116+
def test_predict_with_list():
117+
version = create_version_with_list_output()
118+
responses.post(
119+
"https://api.replicate.com/v1/predictions",
120+
match=[
121+
matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}})
122+
],
123+
json={
124+
"id": "p1",
125+
"version": "v1",
126+
"urls": {
127+
"get": "https://api.replicate.com/v1/predictions/p1",
128+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
129+
},
130+
"created_at": "2022-04-26T20:00:40.658234Z",
131+
"completed_at": "2022-04-26T20:02:27.648305Z",
132+
"source": "api",
133+
"status": "processing",
134+
"input": {"text": "world"},
135+
"output": None,
136+
"error": None,
137+
"logs": "",
138+
},
139+
)
140+
responses.get(
141+
"https://api.replicate.com/v1/predictions/p1",
142+
json={
143+
"id": "p1",
144+
"version": "v1",
145+
"urls": {
146+
"get": "https://api.replicate.com/v1/predictions/p1",
147+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
148+
},
149+
"created_at": "2022-04-26T20:00:40.658234Z",
150+
"completed_at": "2022-04-26T20:02:27.648305Z",
151+
"source": "api",
152+
"status": "succeeded",
153+
"input": {"text": "world"},
154+
"output": ["hello world"],
155+
"error": None,
156+
"logs": "",
157+
},
158+
)
159+
160+
output = version.predict(text="world")
161+
assert isinstance(output, list)
162+
assert output == ["hello world"]
163+
164+
165+
@responses.activate
166+
def test_predict_with_iterator_backwards_compatibility_cog_0_3_8():
167+
version = create_version_with_iterator_output_backwards_compatibility_0_3_8()
168+
responses.post(
169+
"https://api.replicate.com/v1/predictions",
170+
match=[
171+
matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}})
172+
],
173+
json={
174+
"id": "p1",
175+
"version": "v1",
176+
"urls": {
177+
"get": "https://api.replicate.com/v1/predictions/p1",
178+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
179+
},
180+
"created_at": "2022-04-26T20:00:40.658234Z",
181+
"completed_at": "2022-04-26T20:02:27.648305Z",
182+
"source": "api",
183+
"status": "processing",
184+
"input": {"text": "world"},
185+
"output": None,
186+
"error": None,
187+
"logs": "",
188+
},
189+
)
190+
responses.get(
191+
"https://api.replicate.com/v1/predictions/p1",
192+
json={
193+
"id": "p1",
194+
"version": "v1",
195+
"urls": {
196+
"get": "https://api.replicate.com/v1/predictions/p1",
197+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
198+
},
199+
"created_at": "2022-04-26T20:00:40.658234Z",
200+
"completed_at": "2022-04-26T20:02:27.648305Z",
201+
"source": "api",
202+
"status": "succeeded",
203+
"input": {"text": "world"},
204+
"output": ["hello world"],
205+
"error": None,
206+
"logs": "",
207+
},
208+
)
209+
210+
output = version.predict(text="world")
211+
assert isinstance(output, Iterable)
212+
assert list(output) == ["hello world"]
213+
214+
110215
@responses.activate
111216
def test_predict_with_iterator_with_failed_prediction():
112217
version = create_version_with_iterator_output()

0 commit comments

Comments
 (0)