diff --git a/src/together/resources/batch.py b/src/together/resources/batch.py index e9f3065..ab8a900 100644 --- a/src/together/resources/batch.py +++ b/src/together/resources/batch.py @@ -72,6 +72,21 @@ def list_batches(self) -> List[BatchJob]: jobs = response.data or [] return [BatchJob(**job) for job in jobs] + def cancel_batch(self, batch_job_id: str) -> BatchJob: + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + response, _, _ = requestor.request( + options=TogetherRequest( + method="POST", + url=f"batches/{batch_job_id}/cancel", + ), + stream=False, + ) + + return BatchJob(**response.data) + class AsyncBatches: def __init__(self, client: TogetherClient) -> None: @@ -133,3 +148,18 @@ async def list_batches(self) -> List[BatchJob]: assert isinstance(response, TogetherResponse) jobs = response.data or [] return [BatchJob(**job) for job in jobs] + + async def cancel_batch(self, batch_job_id: str) -> BatchJob: + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="POST", + url=f"batches/{batch_job_id}/cancel", + ), + stream=False, + ) + + return BatchJob(**response.data) diff --git a/src/together/types/batch.py b/src/together/types/batch.py index 3bd5ead..aaec0e9 100644 --- a/src/together/types/batch.py +++ b/src/together/types/batch.py @@ -20,6 +20,7 @@ class BatchJobStatus(str, Enum): FAILED = "FAILED" EXPIRED = "EXPIRED" CANCELLED = "CANCELLED" + CANCELING = "CANCELING" class BatchEndpoint(str, Enum):