Skip to content

Commit cb2f7ea

Browse files
authored
Support auto shutdown and model list (#266)
* feat: support auto-shutdown
* feat: add support for --type dedicated
* fix: make list readable by programs
* add --json output
* fix description
1 parent 74593cb commit cb2f7ea

File tree

4 files changed

+131
-20
lines changed

4 files changed

+131
-20
lines changed

src/together/cli/api/endpoints.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def endpoints(ctx: click.Context) -> None:
127127
is_flag=True,
128128
help="Create the endpoint in STOPPED state instead of auto-starting it",
129129
)
130+
@click.option(
131+
"--inactive-timeout",
132+
type=int,
133+
help="Number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable.",
134+
)
130135
@click.option(
131136
"--wait",
132137
is_flag=True,
@@ -146,6 +151,7 @@ def create(
146151
no_prompt_cache: bool,
147152
no_speculative_decoding: bool,
148153
no_auto_start: bool,
154+
inactive_timeout: int | None,
149155
wait: bool,
150156
) -> None:
151157
"""Create a new dedicated inference endpoint."""
@@ -170,6 +176,7 @@ def create(
170176
disable_prompt_cache=no_prompt_cache,
171177
disable_speculative_decoding=no_speculative_decoding,
172178
state="STOPPED" if no_auto_start else "STARTED",
179+
inactive_timeout=inactive_timeout,
173180
)
174181
except InvalidRequestError as e:
175182
print_api_error(e)
@@ -194,6 +201,8 @@ def create(
194201
click.echo(" Speculative decoding: disabled", err=True)
195202
if no_auto_start:
196203
click.echo(" Auto-start: disabled", err=True)
204+
if inactive_timeout is not None:
205+
click.echo(f" Inactive timeout: {inactive_timeout} minutes", err=True)
197206

198207
click.echo(f"Endpoint created successfully, id: {response.id}", err=True)
199208

@@ -371,6 +380,11 @@ def list(
371380
type=int,
372381
help="New maximum number of replicas to scale up to",
373382
)
383+
@click.option(
384+
"--inactive-timeout",
385+
type=int,
386+
help="Number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable.",
387+
)
374388
@click.pass_obj
375389
@handle_api_errors
376390
def update(
@@ -379,9 +393,10 @@ def update(
379393
display_name: str | None,
380394
min_replicas: int | None,
381395
max_replicas: int | None,
396+
inactive_timeout: int | None,
382397
) -> None:
383398
"""Update a dedicated inference endpoint's configuration."""
384-
if not any([display_name, min_replicas, max_replicas]):
399+
if not any([display_name, min_replicas, max_replicas, inactive_timeout]):
385400
click.echo("Error: At least one update option must be specified", err=True)
386401
sys.exit(1)
387402

@@ -400,6 +415,8 @@ def update(
400415
if min_replicas is not None and max_replicas is not None:
401416
kwargs["min_replicas"] = min_replicas
402417
kwargs["max_replicas"] = max_replicas
418+
if inactive_timeout is not None:
419+
kwargs["inactive_timeout"] = inactive_timeout
403420

404421
_response = client.endpoints.update(endpoint_id, **kwargs)
405422

@@ -410,6 +427,8 @@ def update(
410427
if min_replicas is not None and max_replicas is not None:
411428
click.echo(f" Min replicas: {min_replicas}", err=True)
412429
click.echo(f" Max replicas: {max_replicas}", err=True)
430+
if inactive_timeout is not None:
431+
click.echo(f" Inactive timeout: {inactive_timeout} minutes", err=True)
413432

414433
click.echo("Successfully updated endpoint", err=True)
415434
click.echo(endpoint_id)

src/together/cli/api/models.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from textwrap import wrap
1+
import json as json_lib
22

33
import click
44
from tabulate import tabulate
@@ -15,28 +15,41 @@ def models(ctx: click.Context) -> None:
1515

1616

1717
@models.command()
18+
@click.option(
19+
"--type",
20+
type=click.Choice(["dedicated"]),
21+
help="Filter models by type (dedicated: models that can be deployed as dedicated endpoints)",
22+
)
23+
@click.option(
24+
"--json",
25+
is_flag=True,
26+
help="Output in JSON format",
27+
)
1828
@click.pass_context
19-
def list(ctx: click.Context) -> None:
29+
def list(ctx: click.Context, type: str | None, json: bool) -> None:
2030
"""List models"""
2131
client: Together = ctx.obj
2232

23-
response = client.models.list()
33+
response = client.models.list(dedicated=(type == "dedicated"))
2434

2535
display_list = []
2636

2737
model: ModelObject
2838
for model in response:
2939
display_list.append(
3040
{
31-
"ID": "\n".join(wrap(model.id or "", width=30)),
32-
"Name": "\n".join(wrap(model.display_name or "", width=30)),
41+
"ID": model.id,
42+
"Name": model.display_name,
3343
"Organization": model.organization,
3444
"Type": model.type,
3545
"Context Length": model.context_length,
36-
"License": "\n".join(wrap(model.license or "", width=30)),
46+
"License": model.license,
3747
"Input per 1M token": model.pricing.input,
3848
"Output per 1M token": model.pricing.output,
3949
}
4050
)
4151

42-
click.echo(tabulate(display_list, headers="keys", tablefmt="grid"))
52+
if json:
53+
click.echo(json_lib.dumps(display_list, indent=2))
54+
else:
55+
click.echo(tabulate(display_list, headers="keys", tablefmt="plain"))

src/together/resources/endpoints.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def create(
5959
disable_prompt_cache: bool = False,
6060
disable_speculative_decoding: bool = False,
6161
state: Literal["STARTED", "STOPPED"] = "STARTED",
62+
inactive_timeout: Optional[int] = None,
6263
) -> DedicatedEndpoint:
6364
"""
6465
Create a new dedicated endpoint.
@@ -72,6 +73,7 @@ def create(
7273
disable_prompt_cache (bool, optional): Whether to disable the prompt cache. Defaults to False.
7374
disable_speculative_decoding (bool, optional): Whether to disable speculative decoding. Defaults to False.
7475
state (str, optional): The desired state of the endpoint. Defaults to "STARTED".
76+
inactive_timeout (int, optional): The number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable automatic timeout.
7577
7678
Returns:
7779
DedicatedEndpoint: Object containing endpoint information
@@ -80,7 +82,7 @@ def create(
8082
client=self._client,
8183
)
8284

83-
data: Dict[str, Union[str, bool, Dict[str, int]]] = {
85+
data: Dict[str, Union[str, bool, Dict[str, int], int]] = {
8486
"model": model,
8587
"hardware": hardware,
8688
"autoscaling": {
@@ -95,6 +97,9 @@ def create(
9597
if display_name is not None:
9698
data["display_name"] = display_name
9799

100+
if inactive_timeout is not None:
101+
data["inactive_timeout"] = inactive_timeout
102+
98103
response, _, _ = requestor.request(
99104
options=TogetherRequest(
100105
method="POST",
@@ -161,6 +166,7 @@ def update(
161166
max_replicas: Optional[int] = None,
162167
state: Optional[Literal["STARTED", "STOPPED"]] = None,
163168
display_name: Optional[str] = None,
169+
inactive_timeout: Optional[int] = None,
164170
) -> DedicatedEndpoint:
165171
"""
166172
Update an endpoint's configuration.
@@ -171,6 +177,7 @@ def update(
171177
max_replicas (int, optional): The maximum number of replicas to scale up to
172178
state (str, optional): The desired state of the endpoint ("STARTED" or "STOPPED")
173179
display_name (str, optional): A human-readable name for the endpoint
180+
inactive_timeout (int, optional): The number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable automatic timeout.
174181
175182
Returns:
176183
DedicatedEndpoint: Object containing endpoint information
@@ -179,7 +186,7 @@ def update(
179186
client=self._client,
180187
)
181188

182-
data: Dict[str, Union[str, Dict[str, int]]] = {}
189+
data: Dict[str, Union[str, Dict[str, int], int]] = {}
183190

184191
if min_replicas is not None or max_replicas is not None:
185192
current_min = min_replicas
@@ -200,6 +207,9 @@ def update(
200207
if display_name is not None:
201208
data["display_name"] = display_name
202209

210+
if inactive_timeout is not None:
211+
data["inactive_timeout"] = inactive_timeout
212+
203213
response, _, _ = requestor.request(
204214
options=TogetherRequest(
205215
method="PATCH",
@@ -297,6 +307,7 @@ async def create(
297307
disable_prompt_cache: bool = False,
298308
disable_speculative_decoding: bool = False,
299309
state: Literal["STARTED", "STOPPED"] = "STARTED",
310+
inactive_timeout: Optional[int] = None,
300311
) -> DedicatedEndpoint:
301312
"""
302313
Create a new dedicated endpoint.
@@ -310,6 +321,7 @@ async def create(
310321
disable_prompt_cache (bool, optional): Whether to disable the prompt cache. Defaults to False.
311322
disable_speculative_decoding (bool, optional): Whether to disable speculative decoding. Defaults to False.
312323
state (str, optional): The desired state of the endpoint. Defaults to "STARTED".
324+
inactive_timeout (int, optional): The number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable automatic timeout.
313325
314326
Returns:
315327
DedicatedEndpoint: Object containing endpoint information
@@ -318,7 +330,7 @@ async def create(
318330
client=self._client,
319331
)
320332

321-
data: Dict[str, Union[str, bool, Dict[str, int]]] = {
333+
data: Dict[str, Union[str, bool, Dict[str, int], int]] = {
322334
"model": model,
323335
"hardware": hardware,
324336
"autoscaling": {
@@ -333,6 +345,9 @@ async def create(
333345
if display_name is not None:
334346
data["display_name"] = display_name
335347

348+
if inactive_timeout is not None:
349+
data["inactive_timeout"] = inactive_timeout
350+
336351
response, _, _ = await requestor.arequest(
337352
options=TogetherRequest(
338353
method="POST",
@@ -399,6 +414,7 @@ async def update(
399414
max_replicas: Optional[int] = None,
400415
state: Optional[Literal["STARTED", "STOPPED"]] = None,
401416
display_name: Optional[str] = None,
417+
inactive_timeout: Optional[int] = None,
402418
) -> DedicatedEndpoint:
403419
"""
404420
Update an endpoint's configuration.
@@ -409,6 +425,7 @@ async def update(
409425
max_replicas (int, optional): The maximum number of replicas to scale up to
410426
state (str, optional): The desired state of the endpoint ("STARTED" or "STOPPED")
411427
display_name (str, optional): A human-readable name for the endpoint
428+
inactive_timeout (int, optional): The number of minutes of inactivity after which the endpoint will be automatically stopped. Set to 0 to disable automatic timeout.
412429
413430
Returns:
414431
DedicatedEndpoint: Object containing endpoint information
@@ -417,7 +434,7 @@ async def update(
417434
client=self._client,
418435
)
419436

420-
data: Dict[str, Union[str, Dict[str, int]]] = {}
437+
data: Dict[str, Union[str, Dict[str, int], int]] = {}
421438

422439
if min_replicas is not None or max_replicas is not None:
423440
current_min = min_replicas
@@ -438,6 +455,9 @@ async def update(
438455
if display_name is not None:
439456
data["display_name"] = display_name
440457

458+
if inactive_timeout is not None:
459+
data["inactive_timeout"] = inactive_timeout
460+
441461
response, _, _ = await requestor.arequest(
442462
options=TogetherRequest(
443463
method="PATCH",

src/together/resources/models.py

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,47 @@
1111
)
1212

1313

14-
class Models:
14+
class ModelsBase:
1515
def __init__(self, client: TogetherClient) -> None:
1616
self._client = client
1717

18+
def _filter_dedicated_models(
19+
self, models: List[ModelObject], dedicated_response: TogetherResponse
20+
) -> List[ModelObject]:
21+
"""
22+
Filter models based on dedicated model response.
23+
24+
Args:
25+
models (List[ModelObject]): List of all models
26+
dedicated_response (TogetherResponse): Response from autoscale models endpoint
27+
28+
Returns:
29+
List[ModelObject]: Filtered list of models
30+
"""
31+
assert isinstance(dedicated_response.data, list)
32+
33+
# Create a set of dedicated model names for efficient lookup
34+
dedicated_model_names = {model["name"] for model in dedicated_response.data}
35+
36+
# Filter models to only include those in dedicated_model_names
37+
# Note: The model.id from ModelObject matches the name field in the autoscale response
38+
return [model for model in models if model.id in dedicated_model_names]
39+
40+
41+
class Models(ModelsBase):
1842
def list(
1943
self,
44+
dedicated: bool = False,
2045
) -> List[ModelObject]:
2146
"""
2247
Method to return list of models on the API
2348
49+
Args:
50+
dedicated (bool, optional): If True, returns only dedicated models. Defaults to False.
51+
2452
Returns:
2553
List[ModelObject]: List of model objects
2654
"""
27-
2855
requestor = api_requestor.APIRequestor(
2956
client=self._client,
3057
)
@@ -40,23 +67,39 @@ def list(
4067
assert isinstance(response, TogetherResponse)
4168
assert isinstance(response.data, list)
4269

43-
return [ModelObject(**model) for model in response.data]
70+
models = [ModelObject(**model) for model in response.data]
4471

72+
if dedicated:
73+
# Get dedicated models
74+
dedicated_response, _, _ = requestor.request(
75+
options=TogetherRequest(
76+
method="GET",
77+
url="autoscale/models",
78+
),
79+
stream=False,
80+
)
4581

46-
class AsyncModels:
47-
def __init__(self, client: TogetherClient) -> None:
48-
self._client = client
82+
models = self._filter_dedicated_models(models, dedicated_response)
83+
84+
models.sort(key=lambda x: x.id.lower())
4985

86+
return models
87+
88+
89+
class AsyncModels(ModelsBase):
5090
async def list(
5191
self,
92+
dedicated: bool = False,
5293
) -> List[ModelObject]:
5394
"""
5495
Async method to return list of models on API
5596
97+
Args:
98+
dedicated (bool, optional): If True, returns only dedicated models. Defaults to False.
99+
56100
Returns:
57101
List[ModelObject]: List of model objects
58102
"""
59-
60103
requestor = api_requestor.APIRequestor(
61104
client=self._client,
62105
)
@@ -72,4 +115,20 @@ async def list(
72115
assert isinstance(response, TogetherResponse)
73116
assert isinstance(response.data, list)
74117

75-
return [ModelObject(**model) for model in response.data]
118+
models = [ModelObject(**model) for model in response.data]
119+
120+
if dedicated:
121+
# Get dedicated models
122+
dedicated_response, _, _ = await requestor.arequest(
123+
options=TogetherRequest(
124+
method="GET",
125+
url="autoscale/models",
126+
),
127+
stream=False,
128+
)
129+
130+
models = self._filter_dedicated_models(models, dedicated_response)
131+
132+
models.sort(key=lambda x: x.id.lower())
133+
134+
return models

0 commit comments

Comments (0)