Skip to content

Commit 24d30f4

Browse files
committed
updates to support async example
1 parent cf383c6 commit 24d30f4

File tree

8 files changed

+357
-25
lines changed

8 files changed

+357
-25
lines changed
File renamed without changes.

examples/run_async.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import asyncio
2+
3+
from replicate import AsyncReplicate
4+
5+
client = AsyncReplicate()
6+
7+
# https://replicate.com/stability-ai/sdxl
8+
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
9+
prompts = [f"A chariot pulled by a team of {count} rainbow unicorns" for count in ["two", "four", "six", "eight"]]
10+
11+
12+
async def main() -> None:
    """Request one prediction per prompt, await them together, and print the outputs."""
    # Build every awaitable first so asyncio.gather can await them as a group
    pending = []
    for prompt in prompts:
        pending.append(client.run(model_version, input={"prompt": prompt}))

    outputs = await asyncio.gather(*pending)
    print(outputs)
18+
19+
20+
asyncio.run(main())

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"anyio>=3.5.0, <5",
1515
"distro>=1.7.0, <2",
1616
"sniffio",
17+
# NOTE(review): asyncio is part of the standard library on every interpreter
# allowed by `requires-python = ">= 3.8"` below. The PyPI distribution named
# "asyncio" is an old standalone backport for Python 3.3, not the stdlib
# module - confirm this entry is intentional; it most likely should be removed.
"asyncio>=3.4.3",
1718
]
1819
requires-python = ">= 3.8"
1920
classifiers = [

src/replicate/lib/_models.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, Tuple, Union, Optional
4-
from typing_extensions import TypedDict
3+
import re
4+
from typing import Any, Dict, Tuple, Union, Optional, NamedTuple
55

66

77
class Model:
@@ -35,12 +35,26 @@ class Version(BaseModel):
3535
"""An OpenAPI description of the model inputs and outputs."""
3636

3737

38-
class ModelVersionIdentifier(TypedDict, total=False):
39-
"""A structure to identify a model version."""
38+
class ModelVersionIdentifier(NamedTuple):
    """
    Identifies a model and, optionally, one specific version of it.

    The string form is owner/name or owner/name:version.
    """

    owner: str
    name: str
    # None when the reference does not pin a version
    version: Optional[str] = None

    @classmethod
    def parse(cls, ref: str) -> "ModelVersionIdentifier":
        """
        Build an identifier from a string such as owner/name or owner/name:version.

        Raises:
            ValueError: if ``ref`` does not match the expected pattern.
        """

        parsed = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^/:]+)(:(?P<version>.+))?$", ref)
        if parsed is None:
            raise ValueError(f"Invalid reference to model version: {ref}. Expected format: owner/name:version")

        # groupdict() yields None for the optional version group when it is absent
        fields = parsed.groupdict()
        return cls(owner=fields["owner"], name=fields["name"], version=fields["version"])
4458

4559

4660
def resolve_reference(
@@ -57,25 +71,13 @@ def resolve_reference(
5771
version_id = None
5872

5973
if isinstance(ref, Model):
60-
owner = ref.owner
61-
name = ref.name
74+
owner, name = ref.owner, ref.name
6275
elif isinstance(ref, Version):
76+
version = ref
6377
version_id = ref.id
64-
elif isinstance(ref, dict):
65-
owner = ref.get("owner")
66-
name = ref.get("name")
67-
version_id = ref.get("version")
78+
elif isinstance(ref, ModelVersionIdentifier):
79+
owner, name, version_id = ref
6880
else:
69-
# Check if the string is a version ID (assumed to be a hash-like string)
70-
if "/" not in ref and len(ref) >= 32:
71-
version_id = ref
72-
else:
73-
# Handle owner/name or owner/name/version format
74-
parts = ref.split("/")
75-
if len(parts) >= 2:
76-
owner = parts[0]
77-
name = parts[1]
78-
if len(parts) >= 3:
79-
version_id = parts[2]
81+
owner, name, version_id = ModelVersionIdentifier.parse(ref)
8082

8183
return version, owner, name, version_id

src/replicate/lib/_predictions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ async def async_run(
242242
def _has_output_iterator_array_type(version: Version) -> bool:
    """Tell whether the version's Output schema is an array marked as an iterator.

    The schema is first passed through make_schema_backwards_compatible together
    with the version's cog_version; missing schema sections are treated as empty.
    """
    compat_schema = make_schema_backwards_compatible(version.openapi_schema, version.cog_version)
    components = compat_schema.get("components", {})
    output = components.get("schemas", {}).get("Output", {})
    is_array = output.get("type") == "array"
    return is_array and output.get("x-cog-array-type") == "iterator"  # type: ignore[no-any-return]
246246

247247

248248
async def _make_async_iterator(list: List[Any]) -> AsyncIterator[Any]:

src/replicate/lib/_schema.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Any, Dict, Optional
2+
3+
from packaging import version
4+
5+
# TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth?
6+
7+
8+
def version_has_no_array_type(cog_version: str) -> Optional[bool]:
    """Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward.

    Returns True when ``cog_version`` is older than 0.3.9, False when it is not,
    and None when ``cog_version`` cannot be parsed as a version.
    """
    try:
        parsed = version.parse(cog_version)
    except version.InvalidVersion:
        return None
    return parsed < version.parse("0.3.9")
14+
15+
16+
def make_schema_backwards_compatible(
    schema: Dict[str, Any],
    cog_version: str,
) -> Dict[str, Any]:
    """A place to add backwards compatibility logic for our openapi schema.

    The schema is patched in place and the same object is returned.

    Args:
        schema: OpenAPI schema dictionary to adjust.
        cog_version: Version string passed to version_has_no_array_type.

    Returns:
        ``schema`` itself, with ``x-cog-array-type`` set to ``"iterator"`` on its
        ``Output`` component when the version check is truthy and that
        component has type ``"array"``.
    """

    # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type
    if version_has_no_array_type(cog_version):
        # Look the component up tolerantly: a schema without
        # components/schemas/Output is returned untouched instead of raising KeyError.
        output = schema.get("components", {}).get("schemas", {}).get("Output", {})
        if output.get("type") == "array":
            output["x-cog-array-type"] = "iterator"
    return schema

tests/lib/test_run.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,9 @@ def test_run_with_model_version_identifier(self, respx_mock: MockRouter) -> None
274274
# Case where version ID is provided
275275
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
276276

277-
identifier: ModelVersionIdentifier = {"owner": "test-owner", "name": "test-model", "version": "test-version-id"}
277+
identifier = ModelVersionIdentifier(
278+
owner="test-owner", name="test-model", version="test-version-id"
279+
)
278280
output = self.client.run(identifier, input={"prompt": "test prompt"})
279281

280282
assert output == "test output"
@@ -506,7 +508,9 @@ async def test_async_run_with_model_version_identifier(self, respx_mock: MockRou
506508
# Case where version ID is provided
507509
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
508510

509-
identifier: ModelVersionIdentifier = {"owner": "test-owner", "name": "test-model", "version": "test-version-id"}
511+
identifier = ModelVersionIdentifier(
512+
owner="test-owner", name="test-model", version="test-version-id"
513+
)
510514
output = await self.client.run(identifier, input={"prompt": "test prompt"})
511515

512516
assert output == "test output"

0 commit comments

Comments
 (0)