Skip to content

Commit 7ebd598

Browse files
feat(api): api update
1 parent 00eab77 commit 7ebd598

File tree

14 files changed

+379
-67
lines changed

14 files changed

+379
-67
lines changed

.stats.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
configured_endpoints: 27
2-
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-37bb31ed76da599d3bded543a3765f745c8575d105c13554df7f8361c3641482.yml
3-
openapi_spec_hash: 15bdec12ca84042768bfb28cc48dfce3
2+
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
3+
openapi_spec_hash: 4423bf747e228484547b441468a9f156
44
config_hash: 64fd304bf0ff077a74053ce79a255106

api.md

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,18 @@ Methods:
7575

7676
## Versions
7777

78+
Types:
79+
80+
```python
81+
from replicate.types.models import VersionCreateTrainingResponse
82+
```
83+
7884
Methods:
7985

8086
- <code title="get /models/{model_owner}/{model_name}/versions/{version_id}">client.models.versions.<a href="./src/replicate/resources/models/versions.py">retrieve</a>(version_id, \*, model_owner, model_name) -> None</code>
8187
- <code title="get /models/{model_owner}/{model_name}/versions">client.models.versions.<a href="./src/replicate/resources/models/versions.py">list</a>(model_name, \*, model_owner) -> None</code>
8288
- <code title="delete /models/{model_owner}/{model_name}/versions/{version_id}">client.models.versions.<a href="./src/replicate/resources/models/versions.py">delete</a>(version_id, \*, model_owner, model_name) -> None</code>
83-
- <code title="post /models/{model_owner}/{model_name}/versions/{version_id}/trainings">client.models.versions.<a href="./src/replicate/resources/models/versions.py">create_training</a>(version_id, \*, model_owner, model_name, \*\*<a href="src/replicate/types/models/version_create_training_params.py">params</a>) -> None</code>
89+
- <code title="post /models/{model_owner}/{model_name}/versions/{version_id}/trainings">client.models.versions.<a href="./src/replicate/resources/models/versions.py">create_training</a>(version_id, \*, model_owner, model_name, \*\*<a href="src/replicate/types/models/version_create_training_params.py">params</a>) -> <a href="./src/replicate/types/models/version_create_training_response.py">VersionCreateTrainingResponse</a></code>
8490

8591
# Predictions
8692

@@ -99,11 +105,17 @@ Methods:
99105

100106
# Trainings
101107

108+
Types:
109+
110+
```python
111+
from replicate.types import TrainingRetrieveResponse, TrainingListResponse, TrainingCancelResponse
112+
```
113+
102114
Methods:
103115

104-
- <code title="get /trainings/{training_id}">client.trainings.<a href="./src/replicate/resources/trainings.py">retrieve</a>(training_id) -> None</code>
105-
- <code title="get /trainings">client.trainings.<a href="./src/replicate/resources/trainings.py">list</a>() -> None</code>
106-
- <code title="post /trainings/{training_id}/cancel">client.trainings.<a href="./src/replicate/resources/trainings.py">cancel</a>(training_id) -> None</code>
116+
- <code title="get /trainings/{training_id}">client.trainings.<a href="./src/replicate/resources/trainings.py">retrieve</a>(training_id) -> <a href="./src/replicate/types/training_retrieve_response.py">TrainingRetrieveResponse</a></code>
117+
- <code title="get /trainings">client.trainings.<a href="./src/replicate/resources/trainings.py">list</a>() -> <a href="./src/replicate/types/training_list_response.py">SyncCursorURLPage[TrainingListResponse]</a></code>
118+
- <code title="post /trainings/{training_id}/cancel">client.trainings.<a href="./src/replicate/resources/trainings.py">cancel</a>(training_id) -> <a href="./src/replicate/types/training_cancel_response.py">TrainingCancelResponse</a></code>
107119

108120
# Webhooks
109121

src/replicate/resources/models/versions.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from ..._base_client import make_request_options
2424
from ...types.models import version_create_training_params
25+
from ...types.models.version_create_training_response import VersionCreateTrainingResponse
2526

2627
__all__ = ["VersionsResource", "AsyncVersionsResource"]
2728

@@ -281,7 +282,7 @@ def create_training(
281282
extra_query: Query | None = None,
282283
extra_body: Body | None = None,
283284
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
284-
) -> None:
285+
) -> VersionCreateTrainingResponse:
285286
"""
286287
Start a new training of the model version you specify.
287288
@@ -398,7 +399,6 @@ def create_training(
398399
raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
399400
if not version_id:
400401
raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
401-
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
402402
return self._post(
403403
f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
404404
body=maybe_transform(
@@ -413,7 +413,7 @@ def create_training(
413413
options=make_request_options(
414414
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
415415
),
416-
cast_to=NoneType,
416+
cast_to=VersionCreateTrainingResponse,
417417
)
418418

419419

@@ -672,7 +672,7 @@ async def create_training(
672672
extra_query: Query | None = None,
673673
extra_body: Body | None = None,
674674
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
675-
) -> None:
675+
) -> VersionCreateTrainingResponse:
676676
"""
677677
Start a new training of the model version you specify.
678678
@@ -789,7 +789,6 @@ async def create_training(
789789
raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
790790
if not version_id:
791791
raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
792-
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
793792
return await self._post(
794793
f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
795794
body=await async_maybe_transform(
@@ -804,7 +803,7 @@ async def create_training(
804803
options=make_request_options(
805804
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
806805
),
807-
cast_to=NoneType,
806+
cast_to=VersionCreateTrainingResponse,
808807
)
809808

810809

src/replicate/resources/trainings.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import httpx
66

7-
from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
7+
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
88
from .._compat import cached_property
99
from .._resource import SyncAPIResource, AsyncAPIResource
1010
from .._response import (
@@ -13,7 +13,11 @@
1313
async_to_raw_response_wrapper,
1414
async_to_streamed_response_wrapper,
1515
)
16-
from .._base_client import make_request_options
16+
from ..pagination import SyncCursorURLPage, AsyncCursorURLPage
17+
from .._base_client import AsyncPaginator, make_request_options
18+
from ..types.training_list_response import TrainingListResponse
19+
from ..types.training_cancel_response import TrainingCancelResponse
20+
from ..types.training_retrieve_response import TrainingRetrieveResponse
1721

1822
__all__ = ["TrainingsResource", "AsyncTrainingsResource"]
1923

@@ -48,7 +52,7 @@ def retrieve(
4852
extra_query: Query | None = None,
4953
extra_body: Body | None = None,
5054
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
51-
) -> None:
55+
) -> TrainingRetrieveResponse:
5256
"""
5357
Get the current state of a training.
5458
@@ -123,13 +127,12 @@ def retrieve(
123127
"""
124128
if not training_id:
125129
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
126-
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
127130
return self._get(
128131
f"/trainings/{training_id}",
129132
options=make_request_options(
130133
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
131134
),
132-
cast_to=NoneType,
135+
cast_to=TrainingRetrieveResponse,
133136
)
134137

135138
def list(
@@ -141,7 +144,7 @@ def list(
141144
extra_query: Query | None = None,
142145
extra_body: Body | None = None,
143146
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
144-
) -> None:
147+
) -> SyncCursorURLPage[TrainingListResponse]:
145148
"""
146149
Get a paginated list of all trainings created by the user or organization
147150
associated with the provided API token.
@@ -207,13 +210,13 @@ def list(
207210
208211
`version` will be the unique ID of model version used to create the training.
209212
"""
210-
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
211-
return self._get(
213+
return self._get_api_list(
212214
"/trainings",
215+
page=SyncCursorURLPage[TrainingListResponse],
213216
options=make_request_options(
214217
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
215218
),
216-
cast_to=NoneType,
219+
model=TrainingListResponse,
217220
)
218221

219222
def cancel(
@@ -226,7 +229,7 @@ def cancel(
226229
extra_query: Query | None = None,
227230
extra_body: Body | None = None,
228231
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
229-
) -> None:
232+
) -> TrainingCancelResponse:
230233
"""
231234
Cancel a training
232235
@@ -241,13 +244,12 @@ def cancel(
241244
"""
242245
if not training_id:
243246
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
244-
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
245247
return self._post(
246248
f"/trainings/{training_id}/cancel",
247249
options=make_request_options(
248250
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
249251
),
250-
cast_to=NoneType,
252+
cast_to=TrainingCancelResponse,
251253
)
252254

253255

@@ -281,7 +283,7 @@ async def retrieve(
281283
extra_query: Query | None = None,
282284
extra_body: Body | None = None,
283285
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
284-
) -> None:
286+
) -> TrainingRetrieveResponse:
285287
"""
286288
Get the current state of a training.
287289
@@ -356,16 +358,15 @@ async def retrieve(
356358
"""
357359
if not training_id:
358360
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
359-
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
360361
return await self._get(
361362
f"/trainings/{training_id}",
362363
options=make_request_options(
363364
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
364365
),
365-
cast_to=NoneType,
366+
cast_to=TrainingRetrieveResponse,
366367
)
367368

368-
async def list(
369+
def list(
369370
self,
370371
*,
371372
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -374,7 +375,7 @@ async def list(
374375
extra_query: Query | None = None,
375376
extra_body: Body | None = None,
376377
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
377-
) -> None:
378+
) -> AsyncPaginator[TrainingListResponse, AsyncCursorURLPage[TrainingListResponse]]:
378379
"""
379380
Get a paginated list of all trainings created by the user or organization
380381
associated with the provided API token.
@@ -440,13 +441,13 @@ async def list(
440441
441442
`version` will be the unique ID of model version used to create the training.
442443
"""
443-
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
444-
return await self._get(
444+
return self._get_api_list(
445445
"/trainings",
446+
page=AsyncCursorURLPage[TrainingListResponse],
446447
options=make_request_options(
447448
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
448449
),
449-
cast_to=NoneType,
450+
model=TrainingListResponse,
450451
)
451452

452453
async def cancel(
@@ -459,7 +460,7 @@ async def cancel(
459460
extra_query: Query | None = None,
460461
extra_body: Body | None = None,
461462
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
462-
) -> None:
463+
) -> TrainingCancelResponse:
463464
"""
464465
Cancel a training
465466
@@ -474,13 +475,12 @@ async def cancel(
474475
"""
475476
if not training_id:
476477
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
477-
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
478478
return await self._post(
479479
f"/trainings/{training_id}/cancel",
480480
options=make_request_options(
481481
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
482482
),
483-
cast_to=NoneType,
483+
cast_to=TrainingCancelResponse,
484484
)
485485

486486

src/replicate/types/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
from .account_list_response import AccountListResponse as AccountListResponse
1010
from .hardware_list_response import HardwareListResponse as HardwareListResponse
1111
from .prediction_list_params import PredictionListParams as PredictionListParams
12+
from .training_list_response import TrainingListResponse as TrainingListResponse
1213
from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams
1314
from .deployment_list_response import DeploymentListResponse as DeploymentListResponse
1415
from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams
1516
from .prediction_create_params import PredictionCreateParams as PredictionCreateParams
17+
from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse
1618
from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse
1719
from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse
20+
from .training_retrieve_response import TrainingRetrieveResponse as TrainingRetrieveResponse
1821
from .deployment_retrieve_response import DeploymentRetrieveResponse as DeploymentRetrieveResponse
1922
from .model_create_prediction_params import ModelCreatePredictionParams as ModelCreatePredictionParams

src/replicate/types/deployment_list_response.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,7 @@ class CurrentRelease(BaseModel):
4949
"""The model identifier string in the format of `{model_owner}/{model_name}`."""
5050

5151
number: Optional[int] = None
52-
"""The release number.
53-
54-
This is an auto-incrementing integer that starts at 1, and is set automatically
55-
when a deployment is created.
56-
"""
52+
"""The release number."""
5753

5854
version: Optional[str] = None
5955
"""The ID of the model version used in the release."""

src/replicate/types/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from __future__ import annotations
44

55
from .version_create_training_params import VersionCreateTrainingParams as VersionCreateTrainingParams
6+
from .version_create_training_response import VersionCreateTrainingResponse as VersionCreateTrainingResponse
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from typing import Dict, Optional
4+
from datetime import datetime
5+
from typing_extensions import Literal
6+
7+
from ..._models import BaseModel
8+
9+
__all__ = ["VersionCreateTrainingResponse", "Metrics", "Output", "URLs"]
10+
11+
12+
class Metrics(BaseModel):
13+
predict_time: Optional[float] = None
14+
"""The amount of CPU or GPU time, in seconds, that the training used while running"""
15+
16+
17+
class Output(BaseModel):
18+
version: Optional[str] = None
19+
"""The version of the model created by the training"""
20+
21+
weights: Optional[str] = None
22+
"""The weights of the trained model"""
23+
24+
25+
class URLs(BaseModel):
26+
cancel: Optional[str] = None
27+
"""URL to cancel the training"""
28+
29+
get: Optional[str] = None
30+
"""URL to get the training details"""
31+
32+
33+
class VersionCreateTrainingResponse(BaseModel):
34+
id: Optional[str] = None
35+
"""The unique ID of the training"""
36+
37+
completed_at: Optional[datetime] = None
38+
"""The time when the training completed"""
39+
40+
created_at: Optional[datetime] = None
41+
"""The time when the training was created"""
42+
43+
error: Optional[str] = None
44+
"""Error message if the training failed"""
45+
46+
input: Optional[Dict[str, object]] = None
47+
"""The input parameters used for the training"""
48+
49+
logs: Optional[str] = None
50+
"""The logs from the training process"""
51+
52+
metrics: Optional[Metrics] = None
53+
"""Metrics about the training process"""
54+
55+
model: Optional[str] = None
56+
"""The name of the model in the format owner/name"""
57+
58+
output: Optional[Output] = None
59+
"""The output of the training process"""
60+
61+
source: Optional[Literal["web", "api"]] = None
62+
"""How the training was created"""
63+
64+
started_at: Optional[datetime] = None
65+
"""The time when the training started"""
66+
67+
status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
68+
"""The current status of the training"""
69+
70+
urls: Optional[URLs] = None
71+
"""URLs for interacting with the training"""
72+
73+
version: Optional[str] = None
74+
"""The ID of the model version used for training"""

src/replicate/types/prediction_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
__all__ = ["PredictionOutput"]
77

88
PredictionOutput: TypeAlias = Union[
9-
Optional[Dict[str, object]], Optional[List[object]], Optional[str], Optional[float], Optional[bool]
9+
Optional[Dict[str, object]], Optional[List[Dict[str, object]]], Optional[str], Optional[float], Optional[bool]
1010
]

0 commit comments

Comments
 (0)