Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/together/cli/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ def endpoints(ctx: click.Context) -> None:
is_flag=True,
help="Create the endpoint in STOPPED state instead of auto-starting it",
)
@click.option(
"--inactive-timeout",
type=int,
help="Number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable.",
)
@click.option(
"--wait",
is_flag=True,
Expand All @@ -146,6 +151,7 @@ def create(
no_prompt_cache: bool,
no_speculative_decoding: bool,
no_auto_start: bool,
inactive_timeout: int | None,
wait: bool,
) -> None:
"""Create a new dedicated inference endpoint."""
Expand All @@ -170,6 +176,7 @@ def create(
disable_prompt_cache=no_prompt_cache,
disable_speculative_decoding=no_speculative_decoding,
state="STOPPED" if no_auto_start else "STARTED",
inactive_timeout=inactive_timeout,
)
except InvalidRequestError as e:
print_api_error(e)
Expand All @@ -194,6 +201,8 @@ def create(
click.echo(" Speculative decoding: disabled", err=True)
if no_auto_start:
click.echo(" Auto-start: disabled", err=True)
if inactive_timeout is not None:
click.echo(f" Inactive timeout: {inactive_timeout} minutes", err=True)

click.echo(f"Endpoint created successfully, id: {response.id}", err=True)

Expand Down Expand Up @@ -371,6 +380,11 @@ def list(
type=int,
help="New maximum number of replicas to scale up to",
)
@click.option(
"--inactive-timeout",
type=int,
help="Number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable.",
)
@click.pass_obj
@handle_api_errors
def update(
Expand All @@ -379,9 +393,10 @@ def update(
display_name: str | None,
min_replicas: int | None,
max_replicas: int | None,
inactive_timeout: int | None,
) -> None:
"""Update a dedicated inference endpoint's configuration."""
if not any([display_name, min_replicas, max_replicas]):
if not any([display_name, min_replicas, max_replicas, inactive_timeout]):
click.echo("Error: At least one update option must be specified", err=True)
sys.exit(1)

Expand All @@ -400,6 +415,8 @@ def update(
if min_replicas is not None and max_replicas is not None:
kwargs["min_replicas"] = min_replicas
kwargs["max_replicas"] = max_replicas
if inactive_timeout is not None:
kwargs["inactive_timeout"] = inactive_timeout

_response = client.endpoints.update(endpoint_id, **kwargs)

Expand All @@ -410,6 +427,8 @@ def update(
if min_replicas is not None and max_replicas is not None:
click.echo(f" Min replicas: {min_replicas}", err=True)
click.echo(f" Max replicas: {max_replicas}", err=True)
if inactive_timeout is not None:
click.echo(f" Inactive timeout: {inactive_timeout} minutes", err=True)

click.echo("Successfully updated endpoint", err=True)
click.echo(endpoint_id)
27 changes: 20 additions & 7 deletions src/together/cli/api/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from textwrap import wrap
import json as json_lib

import click
from tabulate import tabulate
Expand All @@ -15,28 +15,41 @@ def models(ctx: click.Context) -> None:


@models.command()
@click.option(
"--type",
type=click.Choice(["dedicated"]),
help="Filter models by type (dedicated: models that support autoscaling)",
)
@click.option(
"--json",
is_flag=True,
help="Output in JSON format",
)
@click.pass_context
def list(ctx: click.Context) -> None:
def list(ctx: click.Context, type: str | None, json: bool) -> None:
"""List models"""
client: Together = ctx.obj

response = client.models.list()
response = client.models.list(dedicated=(type == "dedicated"))

display_list = []

model: ModelObject
for model in response:
display_list.append(
{
"ID": "\n".join(wrap(model.id or "", width=30)),
"Name": "\n".join(wrap(model.display_name or "", width=30)),
"ID": model.id,
"Name": model.display_name,
"Organization": model.organization,
"Type": model.type,
"Context Length": model.context_length,
"License": "\n".join(wrap(model.license or "", width=30)),
"License": model.license,
"Input per 1M token": model.pricing.input,
"Output per 1M token": model.pricing.output,
}
)

click.echo(tabulate(display_list, headers="keys", tablefmt="grid"))
if json:
click.echo(json_lib.dumps(display_list, indent=2))
else:
click.echo(tabulate(display_list, headers="keys", tablefmt="plain"))
28 changes: 24 additions & 4 deletions src/together/resources/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def create(
disable_prompt_cache: bool = False,
disable_speculative_decoding: bool = False,
state: Literal["STARTED", "STOPPED"] = "STARTED",
inactive_timeout: Optional[int] = None,
) -> DedicatedEndpoint:
"""
Create a new dedicated endpoint.
Expand All @@ -72,6 +73,7 @@ def create(
disable_prompt_cache (bool, optional): Whether to disable the prompt cache. Defaults to False.
disable_speculative_decoding (bool, optional): Whether to disable speculative decoding. Defaults to False.
state (str, optional): The desired state of the endpoint. Defaults to "STARTED".
inactive_timeout (int, optional): The number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable automatic timeout.

Returns:
DedicatedEndpoint: Object containing endpoint information
Expand All @@ -80,7 +82,7 @@ def create(
client=self._client,
)

data: Dict[str, Union[str, bool, Dict[str, int]]] = {
data: Dict[str, Union[str, bool, Dict[str, int], int]] = {
"model": model,
"hardware": hardware,
"autoscaling": {
Expand All @@ -95,6 +97,9 @@ def create(
if display_name is not None:
data["display_name"] = display_name

if inactive_timeout is not None:
data["inactive_timeout"] = inactive_timeout

response, _, _ = requestor.request(
options=TogetherRequest(
method="POST",
Expand Down Expand Up @@ -161,6 +166,7 @@ def update(
max_replicas: Optional[int] = None,
state: Optional[Literal["STARTED", "STOPPED"]] = None,
display_name: Optional[str] = None,
inactive_timeout: Optional[int] = None,
) -> DedicatedEndpoint:
"""
Update an endpoint's configuration.
Expand All @@ -171,6 +177,7 @@ def update(
max_replicas (int, optional): The maximum number of replicas to scale up to
state (str, optional): The desired state of the endpoint ("STARTED" or "STOPPED")
display_name (str, optional): A human-readable name for the endpoint
inactive_timeout (int, optional): The number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable automatic timeout.

Returns:
DedicatedEndpoint: Object containing endpoint information
Expand All @@ -179,7 +186,7 @@ def update(
client=self._client,
)

data: Dict[str, Union[str, Dict[str, int]]] = {}
data: Dict[str, Union[str, Dict[str, int], int]] = {}

if min_replicas is not None or max_replicas is not None:
current_min = min_replicas
Expand All @@ -200,6 +207,9 @@ def update(
if display_name is not None:
data["display_name"] = display_name

if inactive_timeout is not None:
data["inactive_timeout"] = inactive_timeout

response, _, _ = requestor.request(
options=TogetherRequest(
method="PATCH",
Expand Down Expand Up @@ -297,6 +307,7 @@ async def create(
disable_prompt_cache: bool = False,
disable_speculative_decoding: bool = False,
state: Literal["STARTED", "STOPPED"] = "STARTED",
inactive_timeout: Optional[int] = None,
) -> DedicatedEndpoint:
"""
Create a new dedicated endpoint.
Expand All @@ -310,6 +321,7 @@ async def create(
disable_prompt_cache (bool, optional): Whether to disable the prompt cache. Defaults to False.
disable_speculative_decoding (bool, optional): Whether to disable speculative decoding. Defaults to False.
state (str, optional): The desired state of the endpoint. Defaults to "STARTED".
inactive_timeout (int, optional): The number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable automatic timeout.

Returns:
DedicatedEndpoint: Object containing endpoint information
Expand All @@ -318,7 +330,7 @@ async def create(
client=self._client,
)

data: Dict[str, Union[str, bool, Dict[str, int]]] = {
data: Dict[str, Union[str, bool, Dict[str, int], int]] = {
"model": model,
"hardware": hardware,
"autoscaling": {
Expand All @@ -333,6 +345,9 @@ async def create(
if display_name is not None:
data["display_name"] = display_name

if inactive_timeout is not None:
data["inactive_timeout"] = inactive_timeout

response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="POST",
Expand Down Expand Up @@ -399,6 +414,7 @@ async def update(
max_replicas: Optional[int] = None,
state: Optional[Literal["STARTED", "STOPPED"]] = None,
display_name: Optional[str] = None,
inactive_timeout: Optional[int] = None,
) -> DedicatedEndpoint:
"""
Update an endpoint's configuration.
Expand All @@ -409,6 +425,7 @@ async def update(
max_replicas (int, optional): The maximum number of replicas to scale up to
state (str, optional): The desired state of the endpoint ("STARTED" or "STOPPED")
display_name (str, optional): A human-readable name for the endpoint
inactive_timeout (int, optional): The number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable automatic timeout.

Returns:
DedicatedEndpoint: Object containing endpoint information
Expand All @@ -417,7 +434,7 @@ async def update(
client=self._client,
)

data: Dict[str, Union[str, Dict[str, int]]] = {}
data: Dict[str, Union[str, Dict[str, int], int]] = {}

if min_replicas is not None or max_replicas is not None:
current_min = min_replicas
Expand All @@ -438,6 +455,9 @@ async def update(
if display_name is not None:
data["display_name"] = display_name

if inactive_timeout is not None:
data["inactive_timeout"] = inactive_timeout

response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="PATCH",
Expand Down
75 changes: 67 additions & 8 deletions src/together/resources/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,47 @@
)


class Models:
class ModelsBase:
    """Common base for the sync and async model resources.

    Holds the API client reference and the shared logic for narrowing a
    model listing down to dedicated (autoscale-capable) models.
    """

    def __init__(self, client: TogetherClient) -> None:
        # Subclasses issue requests through this client.
        self._client = client

    def _filter_dedicated_models(
        self, models: List[ModelObject], dedicated_response: TogetherResponse
    ) -> List[ModelObject]:
        """
        Keep only the models present in the autoscale ("dedicated") response.

        Args:
            models (List[ModelObject]): Full list of models to filter.
            dedicated_response (TogetherResponse): Response from the autoscale
                models endpoint; each entry is a mapping with a "name" field.

        Returns:
            List[ModelObject]: The subset of ``models`` whose id matches a
            dedicated model name, preserving the original order.
        """
        assert isinstance(dedicated_response.data, list)

        # model.id values correspond to the "name" field in the autoscale
        # payload, so a set of those names gives O(1) membership checks.
        allowed_names = {entry["name"] for entry in dedicated_response.data}

        return [candidate for candidate in models if candidate.id in allowed_names]


class Models(ModelsBase):
def list(
self,
dedicated: bool = False,
) -> List[ModelObject]:
"""
Method to return list of models on the API

Args:
dedicated (bool, optional): If True, returns only dedicated models. Defaults to False.

Returns:
List[ModelObject]: List of model objects
"""

requestor = api_requestor.APIRequestor(
client=self._client,
)
Expand All @@ -40,23 +67,39 @@ def list(
assert isinstance(response, TogetherResponse)
assert isinstance(response.data, list)

return [ModelObject(**model) for model in response.data]
models = [ModelObject(**model) for model in response.data]

if dedicated:
# Get dedicated models
dedicated_response, _, _ = requestor.request(
options=TogetherRequest(
method="GET",
url="autoscale/models",
),
stream=False,
)

class AsyncModels:
def __init__(self, client: TogetherClient) -> None:
self._client = client
models = self._filter_dedicated_models(models, dedicated_response)

models.sort(key=lambda x: x.id.lower())

return models


class AsyncModels(ModelsBase):
async def list(
self,
dedicated: bool = False,
) -> List[ModelObject]:
"""
Async method to return list of models on API

Args:
dedicated (bool, optional): If True, returns only dedicated models. Defaults to False.

Returns:
List[ModelObject]: List of model objects
"""

requestor = api_requestor.APIRequestor(
client=self._client,
)
Expand All @@ -72,4 +115,20 @@ async def list(
assert isinstance(response, TogetherResponse)
assert isinstance(response.data, list)

return [ModelObject(**model) for model in response.data]
models = [ModelObject(**model) for model in response.data]

if dedicated:
# Get dedicated models
dedicated_response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="GET",
url="autoscale/models",
),
stream=False,
)

models = self._filter_dedicated_models(models, dedicated_response)

models.sort(key=lambda x: x.id.lower())

return models