Skip to content

Commit 7b1a1cc

Browse files
committed
forgot the file
1 parent b969e71 commit 7b1a1cc

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

src/replicate/lib/_models.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from __future__ import annotations
2+
3+
from typing import Tuple, Union, Optional
4+
from typing_extensions import TypedDict
5+
6+
7+
class Model:
8+
"""A Replicate model."""
9+
10+
def __init__(self, owner: str, name: str):
11+
self.owner = owner
12+
self.name = name
13+
14+
15+
class Version:
16+
"""A specific version of a Replicate model."""
17+
18+
def __init__(self, id: str):
19+
self.id = id
20+
21+
22+
class ModelVersionIdentifier(TypedDict, total=False):
23+
"""A structure to identify a model version."""
24+
25+
owner: str
26+
name: str
27+
version: str
28+
29+
30+
def resolve_reference(
31+
ref: Union[Model, Version, ModelVersionIdentifier, str],
32+
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
33+
"""
34+
Resolve a reference to a model or version to its components.
35+
36+
Returns a tuple of (version, owner, name, version_id).
37+
"""
38+
version = None
39+
owner = None
40+
name = None
41+
version_id = None
42+
43+
if isinstance(ref, Model):
44+
owner = ref.owner
45+
name = ref.name
46+
elif isinstance(ref, Version):
47+
version_id = ref.id
48+
elif isinstance(ref, dict):
49+
owner = ref.get("owner")
50+
name = ref.get("name")
51+
version_id = ref.get("version")
52+
else:
53+
# Check if the string is a version ID (assumed to be a hash-like string)
54+
if "/" not in ref and len(ref) >= 32:
55+
version_id = ref
56+
else:
57+
# Handle owner/name or owner/name/version format
58+
parts = ref.split("/")
59+
if len(parts) >= 2:
60+
owner = parts[0]
61+
name = parts[1]
62+
if len(parts) >= 3:
63+
version_id = parts[2]
64+
65+
return version, owner, name, version_id

0 commit comments

Comments
 (0)