Skip to content

Commit 9f4c4e4

Browse files
authored
Add Python docstrings for classes, attributes, and methods (#129)
* Add docstrings to methods Signed-off-by: Mattt Zmuda <[email protected]> * Add docstrings to classes and fields Signed-off-by: Mattt Zmuda <[email protected]> * Add docstring with deprecation warning Signed-off-by: Mattt Zmuda <[email protected]> * Adopt grammatical mood of PEP 257 Signed-off-by: Mattt Zmuda <[email protected]> * Format Collection class docs docstring Signed-off-by: Mattt Zmuda <[email protected]> * Remove implementation notes about Cog from documentation strings Signed-off-by: Mattt Zmuda <[email protected]> * Remove type information from method args Signed-off-by: Mattt Zmuda <[email protected]> --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 6f35c4f commit 9f4c4e4

File tree

8 files changed

+220
-20
lines changed

8 files changed

+220
-20
lines changed

replicate/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,13 @@ def trainings(self) -> TrainingCollection:
115115

116116
def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
117117
"""
118-
Run a model in the format owner/name:version.
118+
Run a model and wait for its output.
119+
120+
Args:
121+
model_version: The model version to run, in the format `owner/name:version`
122+
kwargs: The input to the model, as a dictionary
123+
Returns:
124+
The output of the model
119125
"""
120126
# Split model_version into owner, name, version in format owner/name:version
121127
m = re.match(r"^(?P<model>[^/]+/[^:]+):(?P<version>.+)$", model_version)

replicate/collection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111

1212
class Collection(abc.ABC, Generic[Model]):
1313
"""
14-
A base class for representing all objects of a particular type on the
15-
server.
14+
A base class for representing objects of a particular type on the server.
1615
"""
1716

1817
def __init__(self, client: "Client") -> None:

replicate/files.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,16 @@
99

1010
def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
1111
"""
12-
Lifted straight from cog.files
12+
Upload a file to the server.
13+
14+
Args:
15+
fh: A file handle to upload.
16+
output_file_prefix: A string to prepend to the output file name.
17+
Returns:
18+
str: A URL to the uploaded file.
1319
"""
20+
# Lifted straight from cog.files
21+
1422
fh.seek(0)
1523

1624
if output_file_prefix is not None:

replicate/json.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ def encode_json(
1515
obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401
1616
) -> Any: # noqa: ANN401
1717
"""
18-
Returns a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json.
18+
Return a JSON-compatible version of the object.
1919
"""
20+
# Effectively the same thing as cog.json.encode_json.
21+
2022
if isinstance(obj, dict):
2123
return {key: encode_json(value, upload_file) for key, value in obj.items()}
2224
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):

replicate/model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,35 @@
77

88

99
class Model(BaseModel):
10+
"""
11+
A machine learning model hosted on Replicate.
12+
"""
13+
1014
username: str
15+
"""
16+
The name of the user or organization that owns the model.
17+
"""
18+
1119
name: str
20+
"""
21+
The name of the model.
22+
"""
1223

1324
def predict(self, *args, **kwargs) -> None:
25+
"""
26+
DEPRECATED: Use `version.predict()` instead.
27+
"""
28+
1429
raise ReplicateException(
1530
"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme"
1631
)
1732

1833
@property
1934
def versions(self) -> VersionCollection:
35+
"""
36+
Get the versions of this model.
37+
"""
38+
2039
return VersionCollection(client=self._client, model=self)
2140

2241

@@ -27,6 +46,15 @@ def list(self) -> List[Model]:
2746
raise NotImplementedError()
2847

2948
def get(self, name: str) -> Model:
49+
"""
50+
Get a model by name.
51+
52+
Args:
53+
name: The name of the model, in the format `owner/model-name`.
54+
Returns:
55+
The model.
56+
"""
57+
3058
# TODO: fetch model from server
3159
# TODO: support permanent IDs
3260
username, name = name.split("/")

replicate/prediction.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,53 @@
1010

1111

1212
class Prediction(BaseModel):
13+
"""
14+
A prediction made by a model hosted on Replicate.
15+
"""
16+
1317
id: str
14-
error: Optional[str]
18+
"""The unique ID of the prediction."""
19+
20+
version: Optional[Version]
21+
"""The version of the model used to create the prediction."""
22+
23+
status: str
24+
"""The status of the prediction."""
25+
1526
input: Optional[Dict[str, Any]]
16-
logs: Optional[str]
27+
"""The input to the prediction."""
28+
1729
output: Optional[Any]
18-
status: str
19-
version: Optional[Version]
20-
started_at: Optional[str]
30+
"""The output of the prediction."""
31+
32+
logs: Optional[str]
33+
"""The logs of the prediction."""
34+
35+
error: Optional[str]
36+
"""The error encountered during the prediction, if any."""
37+
2138
created_at: Optional[str]
39+
"""When the prediction was created."""
40+
41+
started_at: Optional[str]
42+
"""When the prediction was started."""
43+
2244
completed_at: Optional[str]
45+
"""When the prediction was completed, if finished."""
46+
2347
urls: Optional[Dict[str, str]]
48+
"""
49+
URLs associated with the prediction.
50+
51+
The following keys are available:
52+
- `get`: A URL to fetch the prediction.
53+
- `cancel`: A URL to cancel the prediction.
54+
"""
2455

2556
def wait(self) -> None:
26-
"""Wait for prediction to finish."""
57+
"""
58+
Wait for prediction to finish.
59+
"""
2760
while self.status not in ["succeeded", "failed", "canceled"]:
2861
time.sleep(self._client.poll_interval)
2962
self.reload()
@@ -48,14 +81,23 @@ def output_iterator(self) -> Iterator[Any]:
4881
yield output
4982

5083
def cancel(self) -> None:
51-
"""Cancel a currently running prediction"""
84+
"""
85+
Cancels a running prediction.
86+
"""
5287
self._client._request("POST", f"/v1/predictions/{self.id}/cancel")
5388

5489

5590
class PredictionCollection(Collection):
5691
model = Prediction
5792

5893
def list(self) -> List[Prediction]:
94+
"""
95+
List your predictions.
96+
97+
Returns:
98+
A list of prediction objects.
99+
"""
100+
59101
resp = self._client._request("GET", "/v1/predictions")
60102
# TODO: paginate
61103
predictions = resp.json()["results"]
@@ -65,6 +107,15 @@ def list(self) -> List[Prediction]:
65107
return [self.prepare_model(obj) for obj in predictions]
66108

67109
def get(self, id: str) -> Prediction:
110+
"""
111+
Get a prediction by ID.
112+
113+
Args:
114+
id: The ID of the prediction.
115+
Returns:
116+
Prediction: The prediction object.
117+
"""
118+
68119
resp = self._client._request("GET", f"/v1/predictions/{id}")
69120
obj = resp.json()
70121
# HACK: resolve this? make it lazy somehow?
@@ -80,6 +131,21 @@ def create( # type: ignore
80131
webhook_events_filter: Optional[List[str]] = None,
81132
**kwargs,
82133
) -> Prediction:
134+
"""
135+
Create a new prediction for the specified model version.
136+
137+
Args:
138+
version: The model version to use for the prediction.
139+
input: The input data for the prediction.
140+
webhook: The URL to receive a POST request with prediction updates.
141+
webhook_completed: The URL to receive a POST request when the prediction is completed.
142+
webhook_events_filter: List of events to trigger webhooks.
143+
stream: Set to True to enable streaming of prediction output.
144+
145+
Returns:
146+
Prediction: The created prediction object.
147+
"""
148+
83149
input = encode_json(input, upload_file=upload_file)
84150
body = {
85151
"version": version.id,

replicate/training.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,51 @@
1010

1111

1212
class Training(BaseModel):
13-
completed_at: Optional[str]
14-
created_at: Optional[str]
15-
destination: Optional[str]
16-
error: Optional[str]
13+
"""
14+
A training made for a model hosted on Replicate.
15+
"""
16+
1717
id: str
18+
"""The unique ID of the training."""
19+
20+
version: Optional[Version]
21+
"""The version of the model used to create the training."""
22+
23+
destination: Optional[str]
24+
"""The model destination of the training."""
25+
26+
status: str
27+
"""The status of the training."""
28+
1829
input: Optional[Dict[str, Any]]
19-
logs: Optional[str]
30+
"""The input to the training."""
31+
2032
output: Optional[Any]
33+
"""The output of the training."""
34+
35+
logs: Optional[str]
36+
"""The logs of the training."""
37+
38+
error: Optional[str]
39+
"""The error encountered during the training, if any."""
40+
41+
created_at: Optional[str]
42+
"""When the training was created."""
43+
2144
started_at: Optional[str]
22-
status: str
23-
version: Optional[Version]
45+
"""When the training was started."""
46+
47+
completed_at: Optional[str]
48+
"""When the training was completed, if finished."""
49+
50+
urls: Optional[Dict[str, str]]
51+
"""
52+
URLs associated with the training.
53+
54+
The following keys are available:
55+
- `get`: A URL to fetch the training.
56+
- `cancel`: A URL to cancel the training.
57+
"""
2458

2559
def cancel(self) -> None:
2660
"""Cancel a running training"""
@@ -31,6 +65,13 @@ class TrainingCollection(Collection):
3165
model = Training
3266

3367
def list(self) -> List[Training]:
68+
"""
69+
List your trainings.
70+
71+
Returns:
72+
List[Training]: A list of training objects.
73+
"""
74+
3475
resp = self._client._request("GET", "/v1/trainings")
3576
# TODO: paginate
3677
trainings = resp.json()["results"]
@@ -40,6 +81,15 @@ def list(self) -> List[Training]:
4081
return [self.prepare_model(obj) for obj in trainings]
4182

4283
def get(self, id: str) -> Training:
84+
"""
85+
Get a training by ID.
86+
87+
Args:
88+
id: The ID of the training.
89+
Returns:
90+
Training: The training object.
91+
"""
92+
4393
resp = self._client._request(
4494
"GET",
4595
f"/v1/trainings/{id}",
@@ -58,6 +108,19 @@ def create( # type: ignore
58108
webhook_events_filter: Optional[List[str]] = None,
59109
**kwargs,
60110
) -> Training:
111+
"""
112+
Create a new training using the specified model version as a base.
113+
114+
Args:
115+
version: The ID of the base model version that you're using to train a new model version.
116+
input: The input to the training.
117+
destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request.
118+
webhook: The URL to send a POST request to when the training is completed. Defaults to None.
119+
webhook_events_filter: The events to send to the webhook. Defaults to None.
120+
Returns:
121+
The training object.
122+
"""
123+
61124
input = encode_json(input, upload_file=upload_file)
62125
body = {
63126
"input": input,

replicate/version.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,32 @@
1414

1515

1616
class Version(BaseModel):
17+
"""
18+
A version of a model.
19+
"""
20+
1721
id: str
22+
"""The unique ID of the version."""
23+
1824
created_at: datetime.datetime
25+
"""When the version was created."""
26+
1927
cog_version: str
28+
"""The version of the Cog used to create the version."""
29+
2030
openapi_schema: dict
31+
"""An OpenAPI description of the model inputs and outputs."""
2132

2233
def predict(self, **kwargs) -> Union[Any, Iterator[Any]]:
34+
"""
35+
Create a prediction using this model version.
36+
37+
Args:
38+
kwargs: The input to the model.
39+
Returns:
40+
The output of the model.
41+
"""
42+
2343
warnings.warn(
2444
"version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.",
2545
DeprecationWarning,
@@ -57,7 +77,12 @@ def __init__(self, client: "Client", model: "Model") -> None:
5777
# doesn't exist yet
5878
def get(self, id: str) -> Version:
5979
"""
60-
Get a specific version.
80+
Get a specific model version.
81+
82+
Args:
83+
id: The version ID.
84+
Returns:
85+
The model version.
6186
"""
6287
resp = self._client._request(
6388
"GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}"
@@ -70,6 +95,9 @@ def create(self, **kwargs) -> Version:
7095
def list(self) -> List[Version]:
7196
"""
7297
Return a list of all versions for a model.
98+
99+
Returns:
100+
List[Version]: A list of version objects.
73101
"""
74102
resp = self._client._request(
75103
"GET", f"/v1/models/{self._model.username}/{self._model.name}/versions"

0 commit comments

Comments
 (0)