diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000..389eee3f --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,8 @@ +FROM mcr.microsoft.com/devcontainers/python:3.9 + +ENV PYTHONUNBUFFERED 1 + +RUN pipx install poetry==1.8.3 + +# Install pre-commit +RUN pip install pre-commit diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..ccae33e5 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,40 @@ +{ + "name": "Together Python Development", + "build": { + "dockerfile": "Dockerfile" + }, + "features": { + "ghcr.io/devcontainers/features/git:1": {}, + "ghcr.io/devcontainers/features/node:1": {}, + "ghcr.io/devcontainers/features/java:1": { + "version": "17", + "installMaven": false, + "installGradle": false + } + }, + "customizations": { + "vscode": { + "extensions": [ + "ms-python.python", + "ms-python.vscode-pylance", + "ms-python.isort", + "charliermarsh.ruff", + "ms-python.mypy-type-checker", + "eamodio.gitlens" + ], + "settings": { + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "ruff.lineLength": 100 + } + } + }, + "postCreateCommand": "poetry install", + "remoteUser": "vscode" +} diff --git a/.github/workflows/_integration_tests.yml b/.github/workflows/_integration_tests.yml index c035debe..5b6813b9 100644 --- a/.github/workflows/_integration_tests.yml +++ b/.github/workflows/_integration_tests.yml @@ -24,7 +24,6 @@ jobs: strategy: matrix: python-version: - - "3.8" - "3.9" - "3.10" - "3.11" diff --git a/.github/workflows/_tests.yml b/.github/workflows/_tests.yml index 851320dd..62e42844 100644 --- a/.github/workflows/_tests.yml +++ b/.github/workflows/_tests.yml @@ -24,7 +24,6 @@ jobs: strategy: matrix: python-version: - - "3.8" - "3.9" - "3.10" - "3.11" diff --git a/pyproject.toml b/pyproject.toml index 0a4534cd..be4a95ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api" [tool.poetry] name = "together" -version = "1.4.0" +version = "1.4.1" authors = [ "Together AI " ] diff --git a/src/together/abstract/api_requestor.py b/src/together/abstract/api_requestor.py index e4004f3e..7e37eaf8 100644 --- a/src/together/abstract/api_requestor.py +++ b/src/together/abstract/api_requestor.py @@ -437,7 +437,7 @@ def _prepare_request_raw( [(k, v) for k, v in options.params.items() if v is not None] ) abs_url = _build_api_url(abs_url, encoded_params) - elif options.method.lower() in {"post", "put"}: + elif options.method.lower() in {"post", "put", "patch"}: if options.params and (options.files or options.override_headers): data = options.params elif options.params and not options.files: @@ -587,16 +587,14 @@ async def arequest_raw( ) headers["Content-Type"] = content_type - request_kwargs = { - "headers": headers, - "data": data, - "timeout": timeout, - "allow_redirects": options.allow_redirects, - } - try: result = await session.request( - method=options.method, url=abs_url, **request_kwargs + method=options.method, + url=abs_url, + headers=headers, + data=data, + timeout=timeout, + allow_redirects=options.allow_redirects, ) utils.log_debug( "Together API response", diff --git a/src/together/cli/api/endpoints.py b/src/together/cli/api/endpoints.py new file mode 100644 index 00000000..3d306063 --- /dev/null +++ b/src/together/cli/api/endpoints.py @@ -0,0 +1,415 
@@ +from __future__ import annotations + +import json +import sys +from functools import wraps +from typing import Any, Callable, Dict, List, Literal, TypeVar, Union + +import click + +from together import Together +from together.error import InvalidRequestError +from together.types import DedicatedEndpoint, ListEndpoint + + +def print_endpoint( + endpoint: Union[DedicatedEndpoint, ListEndpoint], +) -> None: + """Print endpoint details in a human-readable, Docker-like key/value format.""" + + # Print header info + click.echo(f"ID:\t\t{endpoint.id}") + click.echo(f"Name:\t\t{endpoint.name}") + + # Print type-specific fields + if isinstance(endpoint, DedicatedEndpoint): + click.echo(f"Display Name:\t{endpoint.display_name}") + click.echo(f"Hardware:\t{endpoint.hardware}") + click.echo( + f"Autoscaling:\tMin={endpoint.autoscaling.min_replicas}, " + f"Max={endpoint.autoscaling.max_replicas}" + ) + + click.echo(f"Model:\t\t{endpoint.model}") + click.echo(f"Type:\t\t{endpoint.type}") + click.echo(f"Owner:\t\t{endpoint.owner}") + click.echo(f"State:\t\t{endpoint.state}") + click.echo(f"Created:\t{endpoint.created_at}") + + +F = TypeVar("F", bound=Callable[..., Any]) + + +def print_api_error( + e: InvalidRequestError, +) -> None: + error_details = e.api_response.message if e.api_response else str(e) + + if error_details and ( + "credentials" in error_details.lower() + or "authentication" in error_details.lower() + ): + click.echo("Error: Invalid API key or authentication failed", err=True) + else: + click.echo(f"Error: {error_details}", err=True) + + +def handle_api_errors(f: F) -> F: + """Decorator to handle common API errors in CLI commands.""" + + @wraps(f) + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return f(*args, **kwargs) + except InvalidRequestError as e: + print_api_error(e) + sys.exit(1) + except Exception as e: + click.echo(f"Error: An unexpected error occurred - {str(e)}", err=True) + sys.exit(1) + + return wrapper # type: ignore + + +@click.group() +@click.pass_context +def endpoints(ctx: click.Context) -> None: + """Endpoints API commands""" + pass + + +@endpoints.command() +@click.option( + "--model", + required=True, + help="The model to deploy (e.g. 
mistralai/Mixtral-8x7B-Instruct-v0.1)", +) +@click.option( + "--min-replicas", + type=int, + default=1, + help="Minimum number of replicas to deploy", +) +@click.option( + "--max-replicas", + type=int, + default=1, + help="Maximum number of replicas to deploy", +) +@click.option( + "--gpu", + type=click.Choice(["h100", "a100", "l40", "l40s", "rtx-6000"]), + required=True, + help="GPU type to use for inference", +) +@click.option( + "--gpu-count", + type=int, + default=1, + help="Number of GPUs to use per replica", +) +@click.option( + "--display-name", + help="A human-readable name for the endpoint", +) +@click.option( + "--no-prompt-cache", + is_flag=True, + help="Disable the prompt cache for this endpoint", +) +@click.option( + "--no-speculative-decoding", + is_flag=True, + help="Disable speculative decoding for this endpoint", +) +@click.option( + "--no-auto-start", + is_flag=True, + help="Create the endpoint in STOPPED state instead of auto-starting it", +) +@click.option( + "--wait/--no-wait", + default=True, + show_default=True, + help="Wait for the endpoint to be ready after creation", +) +@click.pass_obj +@handle_api_errors +def create( + client: Together, + model: str, + min_replicas: int, + max_replicas: int, + gpu: str, + gpu_count: int, + display_name: str | None, + no_prompt_cache: bool, + no_speculative_decoding: bool, + no_auto_start: bool, + wait: bool, +) -> None: + """Create a new dedicated inference endpoint.""" + # Map GPU types to their full hardware ID names + gpu_map = { + "h100": "nvidia_h100_80gb_sxm", + "a100": "nvidia_a100_80gb_pcie" if gpu_count == 1 else "nvidia_a100_80gb_sxm", + "l40": "nvidia_l40", + "l40s": "nvidia_l40s", + "rtx-6000": "nvidia_rtx_6000_ada", + } + + hardware_id = f"{gpu_count}x_{gpu_map[gpu]}" + + try: + response = client.endpoints.create( + model=model, + hardware=hardware_id, + min_replicas=min_replicas, + max_replicas=max_replicas, + display_name=display_name, + disable_prompt_cache=no_prompt_cache, + disable_speculative_decoding=no_speculative_decoding, + state="STOPPED" if no_auto_start else "STARTED", + ) + except InvalidRequestError as e: + print_api_error(e) + if "check the hardware api" in str(e).lower(): + fetch_and_print_hardware_options( + client=client, model=model, print_json=False, available=True + ) + + sys.exit(1) + + # Print detailed information to stderr + click.echo("Created dedicated endpoint with:", err=True) + click.echo(f" Model: {model}", err=True) + click.echo(f" Min replicas: {min_replicas}", err=True) + click.echo(f" Max replicas: {max_replicas}", err=True) + click.echo(f" Hardware: {hardware_id}", err=True) + if display_name: + click.echo(f" Display name: {display_name}", err=True) + if no_prompt_cache: + click.echo(" Prompt cache: disabled", err=True) + if no_speculative_decoding: + click.echo(" Speculative decoding: disabled", err=True) + if no_auto_start: + click.echo(" Auto-start: disabled", err=True) + + click.echo(f"Endpoint created successfully, id: {response.id}", err=True) + + if wait: + import time + + click.echo("Waiting for endpoint to be ready...", err=True) + while client.endpoints.get(response.id).state != "STARTED": + time.sleep(1) + click.echo("Endpoint ready", err=True) + + # Print only the endpoint ID to stdout + click.echo(response.id) + + +@endpoints.command() +@click.argument("endpoint-id", required=True) +@click.option("--json", is_flag=True, help="Print output in JSON format") +@click.pass_obj +@handle_api_errors +def get(client: Together, endpoint_id: str, json: bool) -> None: + """Get a dedicated 
inference endpoint.""" + endpoint = client.endpoints.get(endpoint_id) + if json: + import json as json_lib + + click.echo(json_lib.dumps(endpoint.model_dump(), indent=2)) + else: + print_endpoint(endpoint) + + +@endpoints.command() +@click.option("--model", help="Filter hardware options by model") +@click.option("--json", is_flag=True, help="Print output in JSON format") +@click.option( + "--available", + is_flag=True, + help="Print only available hardware options (can only be used if model is passed in)", +) +@click.pass_obj +@handle_api_errors +def hardware(client: Together, model: str | None, json: bool, available: bool) -> None: + """List all available hardware options, optionally filtered by model.""" + fetch_and_print_hardware_options(client, model, json, available) + + +def fetch_and_print_hardware_options( + client: Together, model: str | None, print_json: bool, available: bool +) -> None: + """Print hardware options for a model.""" + + message = "Available hardware options:" if available else "All hardware options:" + click.echo(message, err=True) + hardware_options = client.endpoints.list_hardware(model) + if available: + hardware_options = [ + hardware + for hardware in hardware_options + if hardware.availability is not None + and hardware.availability.status == "available" + ] + + if print_json: + json_output = [hardware.model_dump() for hardware in hardware_options] + click.echo(json.dumps(json_output, indent=2)) + else: + for hardware in hardware_options: + click.echo(f" {hardware.id}", err=True) + + +@endpoints.command() +@click.argument("endpoint-id", required=True) +@click.option( + "--wait", is_flag=True, default=True, help="Wait for the endpoint to stop" +) +@click.pass_obj +@handle_api_errors +def stop(client: Together, endpoint_id: str, wait: bool) -> None: + """Stop a dedicated inference endpoint.""" + client.endpoints.update(endpoint_id, state="STOPPED") + click.echo("Successfully marked endpoint as stopping", err=True) + + if wait: + import time + + click.echo("Waiting for endpoint to stop...", err=True) + while client.endpoints.get(endpoint_id).state != "STOPPED": + time.sleep(1) + click.echo("Endpoint stopped", err=True) + + click.echo(endpoint_id) + + +@endpoints.command() +@click.argument("endpoint-id", required=True) +@click.option( + "--wait", is_flag=True, default=True, help="Wait for the endpoint to start" +) +@click.pass_obj +@handle_api_errors +def start(client: Together, endpoint_id: str, wait: bool) -> None: + """Start a dedicated inference endpoint.""" + client.endpoints.update(endpoint_id, state="STARTED") + click.echo("Successfully marked endpoint as starting", err=True) + + if wait: + import time + + click.echo("Waiting for endpoint to start...", err=True) + while client.endpoints.get(endpoint_id).state != "STARTED": + time.sleep(1) + click.echo("Endpoint started", err=True) + + click.echo(endpoint_id) + + +@endpoints.command() +@click.argument("endpoint-id", required=True) +@click.pass_obj +@handle_api_errors +def delete(client: Together, endpoint_id: str) -> None: + """Delete a dedicated inference endpoint.""" + client.endpoints.delete(endpoint_id) + click.echo("Successfully deleted endpoint", err=True) + click.echo(endpoint_id) + + +@endpoints.command() +@click.option("--json", is_flag=True, help="Print output in JSON format") +@click.option( + "--type", + type=click.Choice(["dedicated", "serverless"]), + help="Filter by endpoint type", +) +@click.pass_obj +@handle_api_errors +def list( + client: Together, json: bool, type: Literal["dedicated", 
"serverless"] | None +) -> None: + """List all inference endpoints (includes both dedicated and serverless endpoints).""" + endpoints: List[ListEndpoint] = client.endpoints.list(type=type) + + if not endpoints: + click.echo("No dedicated endpoints found", err=True) + return + + click.echo("Endpoints:", err=True) + if json: + import json as json_lib + + click.echo( + json_lib.dumps([endpoint.model_dump() for endpoint in endpoints], indent=2) + ) + else: + for endpoint in endpoints: + print_endpoint( + endpoint, + ) + click.echo() + + +@endpoints.command() +@click.argument("endpoint-id", required=True) +@click.option( + "--display-name", + help="A new human-readable name for the endpoint", +) +@click.option( + "--min-replicas", + type=int, + help="New minimum number of replicas to maintain", +) +@click.option( + "--max-replicas", + type=int, + help="New maximum number of replicas to scale up to", +) +@click.pass_obj +@handle_api_errors +def update( + client: Together, + endpoint_id: str, + display_name: str | None, + min_replicas: int | None, + max_replicas: int | None, +) -> None: + """Update a dedicated inference endpoint's configuration.""" + if not any([display_name, min_replicas, max_replicas]): + click.echo("Error: At least one update option must be specified", err=True) + sys.exit(1) + + # If only one of min/max replicas is specified, we need both for the update + if (min_replicas is None) != (max_replicas is None): + click.echo( + "Error: Both --min-replicas and --max-replicas must be specified together", + err=True, + ) + sys.exit(1) + + # Build kwargs for the update + kwargs: Dict[str, Any] = {} + if display_name is not None: + kwargs["display_name"] = display_name + if min_replicas is not None and max_replicas is not None: + kwargs["min_replicas"] = min_replicas + kwargs["max_replicas"] = max_replicas + + _response = client.endpoints.update(endpoint_id, **kwargs) + + # Print what was updated + click.echo("Updated endpoint configuration:", err=True) + if display_name: + click.echo(f" Display name: {display_name}", err=True) + if min_replicas is not None and max_replicas is not None: + click.echo(f" Min replicas: {min_replicas}", err=True) + click.echo(f" Max replicas: {max_replicas}", err=True) + + click.echo("Successfully updated endpoint", err=True) + click.echo(endpoint_id) diff --git a/src/together/cli/cli.py b/src/together/cli/cli.py index 8bfee0db..7ae35121 100644 --- a/src/together/cli/cli.py +++ b/src/together/cli/cli.py @@ -8,6 +8,7 @@ import together from together.cli.api.chat import chat, interactive from together.cli.api.completions import completions +from together.cli.api.endpoints import endpoints from together.cli.api.files import files from together.cli.api.finetune import fine_tuning from together.cli.api.images import images @@ -72,6 +73,7 @@ def main( main.add_command(files) main.add_command(fine_tuning) main.add_command(models) +main.add_command(endpoints) if __name__ == "__main__": main() diff --git a/src/together/client.py b/src/together/client.py index 6419581b..ea5359a5 100644 --- a/src/together/client.py +++ b/src/together/client.py @@ -81,6 +81,7 @@ def __init__( self.fine_tuning = resources.FineTuning(self.client) self.rerank = resources.Rerank(self.client) self.audio = resources.Audio(self.client) + self.endpoints = resources.Endpoints(self.client) class AsyncTogether: diff --git a/src/together/error.py b/src/together/error.py index b5bdfd40..e2883a2c 100644 --- a/src/together/error.py +++ b/src/together/error.py @@ -18,6 +18,9 @@ def __init__( 
request_id: str | None = None, http_status: int | None = None, ) -> None: + self.api_response = ( + message if isinstance(message, TogetherErrorResponse) else None + ) _message = ( json.dumps(message.model_dump(exclude_none=True)) if isinstance(message, TogetherErrorResponse) diff --git a/src/together/resources/__init__.py b/src/together/resources/__init__.py index cf4bf3b2..f07aeb00 100644 --- a/src/together/resources/__init__.py +++ b/src/together/resources/__init__.py @@ -1,12 +1,13 @@ +from together.resources.audio import AsyncAudio, Audio from together.resources.chat import AsyncChat, Chat from together.resources.completions import AsyncCompletions, Completions from together.resources.embeddings import AsyncEmbeddings, Embeddings +from together.resources.endpoints import AsyncEndpoints, Endpoints from together.resources.files import AsyncFiles, Files from together.resources.finetune import AsyncFineTuning, FineTuning from together.resources.images import AsyncImages, Images from together.resources.models import AsyncModels, Models from together.resources.rerank import AsyncRerank, Rerank -from together.resources.audio import AsyncAudio, Audio __all__ = [ @@ -28,4 +29,6 @@ "Rerank", "AsyncAudio", "Audio", + "AsyncEndpoints", + "Endpoints", ] diff --git a/src/together/resources/endpoints.py b/src/together/resources/endpoints.py new file mode 100644 index 00000000..176894f5 --- /dev/null +++ b/src/together/resources/endpoints.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +from typing import Dict, List, Literal, Optional, Union + +from together.abstract import api_requestor +from together.together_response import TogetherResponse +from together.types import TogetherClient, TogetherRequest +from together.types.endpoints import DedicatedEndpoint, HardwareWithStatus, ListEndpoint + + +class Endpoints: + def __init__(self, client: TogetherClient) -> None: + self._client = client + + def list( + self, type: Optional[Literal["dedicated", "serverless"]] = None + ) -> List[ListEndpoint]: + """ + List all endpoints, can be filtered by type. + + Args: + type (str, optional): Filter endpoints by type ("dedicated" or "serverless"). Defaults to None. + + Returns: + List[ListEndpoint]: List of endpoint objects + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + params = {} + if type is not None: + params["type"] = type + + response, _, _ = requestor.request( + options=TogetherRequest( + method="GET", + url="endpoints", + params=params, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + response.data = response.data["data"] + assert isinstance(response.data, list) + + return [ListEndpoint(**endpoint) for endpoint in response.data] + + def create( + self, + *, + model: str, + hardware: str, + min_replicas: int, + max_replicas: int, + display_name: Optional[str] = None, + disable_prompt_cache: bool = False, + disable_speculative_decoding: bool = False, + state: Literal["STARTED", "STOPPED"] = "STARTED", + ) -> DedicatedEndpoint: + """ + Create a new dedicated endpoint. + + Args: + model (str): The model to deploy on this endpoint + hardware (str): The hardware configuration to use for this endpoint + min_replicas (int): The minimum number of replicas to maintain + max_replicas (int): The maximum number of replicas to scale up to + display_name (str, optional): A human-readable name for the endpoint + disable_prompt_cache (bool, optional): Whether to disable the prompt cache. Defaults to False. 
+ disable_speculative_decoding (bool, optional): Whether to disable speculative decoding. Defaults to False. + state (str, optional): The desired state of the endpoint. Defaults to "STARTED". + + Returns: + DedicatedEndpoint: Object containing endpoint information + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + data: Dict[str, Union[str, bool, Dict[str, int]]] = { + "model": model, + "hardware": hardware, + "autoscaling": { + "min_replicas": min_replicas, + "max_replicas": max_replicas, + }, + "disable_prompt_cache": disable_prompt_cache, + "disable_speculative_decoding": disable_speculative_decoding, + "state": state, + } + + if display_name is not None: + data["display_name"] = display_name + + response, _, _ = requestor.request( + options=TogetherRequest( + method="POST", + url="endpoints", + params=data, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + + return DedicatedEndpoint(**response.data) + + def get(self, endpoint_id: str) -> DedicatedEndpoint: + """ + Get details of a specific endpoint. + + Args: + endpoint_id (str): ID of the endpoint to retrieve + + Returns: + DedicatedEndpoint: Object containing endpoint information + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + response, _, _ = requestor.request( + options=TogetherRequest( + method="GET", + url=f"endpoints/{endpoint_id}", + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + + return DedicatedEndpoint(**response.data) + + def delete(self, endpoint_id: str) -> None: + """ + Delete a specific endpoint. + + Args: + endpoint_id (str): ID of the endpoint to delete + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + requestor.request( + options=TogetherRequest( + method="DELETE", + url=f"endpoints/{endpoint_id}", + ), + stream=False, + ) + + def update( + self, + endpoint_id: str, + *, + min_replicas: Optional[int] = None, + max_replicas: Optional[int] = None, + state: Optional[Literal["STARTED", "STOPPED"]] = None, + display_name: Optional[str] = None, + ) -> DedicatedEndpoint: + """ + Update an endpoint's configuration. 
+ + Args: + endpoint_id (str): ID of the endpoint to update + min_replicas (int, optional): The minimum number of replicas to maintain + max_replicas (int, optional): The maximum number of replicas to scale up to + state (str, optional): The desired state of the endpoint ("STARTED" or "STOPPED") + display_name (str, optional): A human-readable name for the endpoint + + Returns: + DedicatedEndpoint: Object containing endpoint information + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + data: Dict[str, Union[str, Dict[str, int]]] = {} + + if min_replicas is not None or max_replicas is not None: + current_min = min_replicas + current_max = max_replicas + if current_min is None or current_max is None: + # Get current values if only one is specified + current = self.get(endpoint_id=endpoint_id) + current_min = current_min or current.autoscaling.min_replicas + current_max = current_max or current.autoscaling.max_replicas + data["autoscaling"] = { + "min_replicas": current_min, + "max_replicas": current_max, + } + + if state is not None: + data["state"] = state + + if display_name is not None: + data["display_name"] = display_name + + response, _, _ = requestor.request( + options=TogetherRequest( + method="PATCH", + url=f"endpoints/{endpoint_id}", + params=data, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + + return DedicatedEndpoint(**response.data) + + def list_hardware(self, model: Optional[str] = None) -> List[HardwareWithStatus]: + """ + List available hardware configurations. + + Args: + model (str, optional): Filter hardware configurations by model compatibility. When provided, + the response includes availability status for each compatible configuration. + + Returns: + List[HardwareWithStatus]: List of hardware configurations with their status + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + params = {} + if model is not None: + params["model"] = model + + response, _, _ = requestor.request( + options=TogetherRequest( + method="GET", + url="hardware", + params=params, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + assert isinstance(response.data, dict) + assert isinstance(response.data["data"], list) + + return [HardwareWithStatus(**item) for item in response.data["data"]] + + +class AsyncEndpoints: + def __init__(self, client: TogetherClient) -> None: + self._client = client + + async def list( + self, type: Optional[Literal["dedicated", "serverless"]] = None + ) -> List[ListEndpoint]: + """ + List all endpoints, can be filtered by type. + + Args: + type (str, optional): Filter endpoints by type ("dedicated" or "serverless"). Defaults to None. 
+ + Returns: + List[ListEndpoint]: List of endpoint objects + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + params = {} + if type is not None: + params["type"] = type + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="GET", + url="endpoints", + params=params, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + response.data = response.data["data"] + assert isinstance(response.data, list) + + return [ListEndpoint(**endpoint) for endpoint in response.data] + + async def create( + self, + *, + model: str, + hardware: str, + min_replicas: int, + max_replicas: int, + display_name: Optional[str] = None, + disable_prompt_cache: bool = False, + disable_speculative_decoding: bool = False, + state: Literal["STARTED", "STOPPED"] = "STARTED", + ) -> DedicatedEndpoint: + """ + Create a new dedicated endpoint. + + Args: + model (str): The model to deploy on this endpoint + hardware (str): The hardware configuration to use for this endpoint + min_replicas (int): The minimum number of replicas to maintain + max_replicas (int): The maximum number of replicas to scale up to + display_name (str, optional): A human-readable name for the endpoint + disable_prompt_cache (bool, optional): Whether to disable the prompt cache. Defaults to False. + disable_speculative_decoding (bool, optional): Whether to disable speculative decoding. Defaults to False. + state (str, optional): The desired state of the endpoint. Defaults to "STARTED". + + Returns: + DedicatedEndpoint: Object containing endpoint information + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + data: Dict[str, Union[str, bool, Dict[str, int]]] = { + "model": model, + "hardware": hardware, + "autoscaling": { + "min_replicas": min_replicas, + "max_replicas": max_replicas, + }, + "disable_prompt_cache": disable_prompt_cache, + "disable_speculative_decoding": disable_speculative_decoding, + "state": state, + } + + if display_name is not None: + data["display_name"] = display_name + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="POST", + url="endpoints", + params=data, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + + return DedicatedEndpoint(**response.data) + + async def get(self, endpoint_id: str) -> DedicatedEndpoint: + """ + Get details of a specific endpoint. + + Args: + endpoint_id (str): ID of the endpoint to retrieve + + Returns: + DedicatedEndpoint: Object containing endpoint information + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="GET", + url=f"endpoints/{endpoint_id}", + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + + return DedicatedEndpoint(**response.data) + + async def delete(self, endpoint_id: str) -> None: + """ + Delete a specific endpoint. + + Args: + endpoint_id (str): ID of the endpoint to delete + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + await requestor.arequest( + options=TogetherRequest( + method="DELETE", + url=f"endpoints/{endpoint_id}", + ), + stream=False, + ) + + async def update( + self, + endpoint_id: str, + *, + min_replicas: Optional[int] = None, + max_replicas: Optional[int] = None, + state: Optional[Literal["STARTED", "STOPPED"]] = None, + display_name: Optional[str] = None, + ) -> DedicatedEndpoint: + """ + Update an endpoint's configuration. 
+ + Args: + endpoint_id (str): ID of the endpoint to update + min_replicas (int, optional): The minimum number of replicas to maintain + max_replicas (int, optional): The maximum number of replicas to scale up to + state (str, optional): The desired state of the endpoint ("STARTED" or "STOPPED") + display_name (str, optional): A human-readable name for the endpoint + + Returns: + DedicatedEndpoint: Object containing endpoint information + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + data: Dict[str, Union[str, Dict[str, int]]] = {} + + if min_replicas is not None or max_replicas is not None: + current_min = min_replicas + current_max = max_replicas + if current_min is None or current_max is None: + # Get current values if only one is specified + current = await self.get(endpoint_id=endpoint_id) + current_min = current_min or current.autoscaling.min_replicas + current_max = current_max or current.autoscaling.max_replicas + data["autoscaling"] = { + "min_replicas": current_min, + "max_replicas": current_max, + } + + if state is not None: + data["state"] = state + + if display_name is not None: + data["display_name"] = display_name + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="PATCH", + url=f"endpoints/{endpoint_id}", + params=data, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + + return DedicatedEndpoint(**response.data) + + async def list_hardware( + self, model: Optional[str] = None + ) -> List[HardwareWithStatus]: + """ + List available hardware configurations. + + Args: + model (str, optional): Filter hardware configurations by model compatibility. When provided, + the response includes availability status for each compatible configuration. + + Returns: + List[HardwareWithStatus]: List of hardware configurations with their status + """ + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + params = {} + if model is not None: + params["model"] = model + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="GET", + url="hardware", + params=params, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + assert isinstance(response.data, dict) + assert isinstance(response.data["data"], list) + + return [HardwareWithStatus(**item) for item in response.data["data"]] diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py index 5768d8de..c3100cd1 100644 --- a/src/together/types/__init__.py +++ b/src/together/types/__init__.py @@ -1,4 +1,13 @@ from together.types.abstract import TogetherClient +from together.types.audio_speech import ( + AudioLanguage, + AudioResponseEncoding, + AudioResponseFormat, + AudioSpeechRequest, + AudioSpeechStreamChunk, + AudioSpeechStreamEvent, + AudioSpeechStreamResponse, +) from together.types.chat_completions import ( ChatCompletionChunk, ChatCompletionRequest, @@ -11,6 +20,7 @@ CompletionResponse, ) from together.types.embeddings import EmbeddingRequest, EmbeddingResponse +from together.types.endpoints import Autoscaling, DedicatedEndpoint, ListEndpoint from together.types.files import ( FileDeleteResponse, FileList, @@ -22,35 +32,21 @@ ) from together.types.finetune import ( FinetuneDownloadResult, + FinetuneLinearLRSchedulerArgs, FinetuneList, FinetuneListEvents, + FinetuneLRScheduler, FinetuneRequest, FinetuneResponse, + FinetuneTrainingLimits, FullTrainingType, LoRATrainingType, TrainingType, - FinetuneTrainingLimits, - FinetuneLRScheduler, - FinetuneLinearLRSchedulerArgs, -) 
-from together.types.images import ( - ImageRequest, - ImageResponse, ) +from together.types.images import ImageRequest, ImageResponse from together.types.models import ModelObject -from together.types.rerank import ( - RerankRequest, - RerankResponse, -) -from together.types.audio_speech import ( - AudioSpeechRequest, - AudioResponseFormat, - AudioLanguage, - AudioResponseEncoding, - AudioSpeechStreamChunk, - AudioSpeechStreamEvent, - AudioSpeechStreamResponse, -) +from together.types.rerank import RerankRequest, RerankResponse + __all__ = [ "TogetherClient", @@ -93,4 +89,7 @@ "AudioSpeechStreamChunk", "AudioSpeechStreamEvent", "AudioSpeechStreamResponse", + "DedicatedEndpoint", + "ListEndpoint", + "Autoscaling", ] diff --git a/src/together/types/endpoints.py b/src/together/types/endpoints.py new file mode 100644 index 00000000..22ab1934 --- /dev/null +++ b/src/together/types/endpoints.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, Literal, Optional, Union + +from pydantic import BaseModel, Field + + +class TogetherJSONModel(BaseModel): + """Base model with JSON serialization support.""" + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + exclude_none = kwargs.pop("exclude_none", True) + data = super().model_dump(exclude_none=exclude_none, **kwargs) + + # Convert datetime objects to ISO format strings + for key, value in data.items(): + if isinstance(value, datetime): + data[key] = value.isoformat() + + return data + + +class Autoscaling(TogetherJSONModel): + """Configuration for automatic scaling of replicas based on demand.""" + + min_replicas: int = Field( + description="The minimum number of replicas to maintain, even when there is no load" + ) + max_replicas: int = Field( + description="The maximum number of replicas to scale up to under load" + ) + + +class EndpointPricing(TogetherJSONModel): + """Pricing details for using an endpoint.""" + + cents_per_minute: float = Field( + description="Cost per minute of endpoint uptime in cents" + ) + + +class HardwareSpec(TogetherJSONModel): + """Detailed specifications of a hardware configuration.""" + + gpu_type: str = Field(description="The type/model of GPU") + gpu_link: str = Field(description="The GPU interconnect technology") + gpu_memory: Union[float, int] = Field(description="Amount of GPU memory in GB") + gpu_count: int = Field(description="Number of GPUs in this configuration") + + +class HardwareAvailability(TogetherJSONModel): + """Indicates the current availability status of a hardware configuration.""" + + status: Literal["available", "unavailable", "insufficient"] = Field( + description="The availability status of the hardware configuration" + ) + + +class HardwareWithStatus(TogetherJSONModel): + """Hardware configuration details with optional availability status.""" + + object: Literal["hardware"] = Field(description="The type of object") + id: str = Field(description="Unique identifier for the hardware configuration") + pricing: EndpointPricing = Field( + description="Pricing details for this hardware configuration" + ) + specs: HardwareSpec = Field(description="Detailed specifications of this hardware") + availability: Optional[HardwareAvailability] = Field( + default=None, + description="Current availability status of this hardware configuration", + ) + updated_at: datetime = Field( + description="Timestamp of when the hardware status was last updated" + ) + + +class BaseEndpoint(TogetherJSONModel): + """Base class for endpoint models with 
common fields.""" + + object: Literal["endpoint"] = Field(description="The type of object") + id: Optional[str] = Field( + default=None, description="Unique identifier for the endpoint" + ) + name: str = Field(description="System name for the endpoint") + model: str = Field(description="The model deployed on this endpoint") + type: str = Field(description="The type of endpoint") + owner: str = Field(description="The owner of this endpoint") + state: Literal["PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "ERROR"] = ( + Field(description="Current state of the endpoint") + ) + created_at: datetime = Field(description="Timestamp when the endpoint was created") + + +class ListEndpoint(BaseEndpoint): + """Details about an endpoint when listed via the list endpoint.""" + + type: Literal["dedicated", "serverless"] = Field(description="The type of endpoint") + + +class DedicatedEndpoint(BaseEndpoint): + """Details about a dedicated endpoint deployment.""" + + id: str = Field(description="Unique identifier for the endpoint") + type: Literal["dedicated"] = Field(description="The type of endpoint") + display_name: str = Field(description="Human-readable name for the endpoint") + hardware: str = Field( + description="The hardware configuration used for this endpoint" + ) + autoscaling: Autoscaling = Field( + description="Configuration for automatic scaling of the endpoint" + ) + + +__all__ = [ + "DedicatedEndpoint", + "ListEndpoint", + "Autoscaling", + "EndpointPricing", + "HardwareSpec", + "HardwareAvailability", + "HardwareWithStatus", +]