Skip to content

Commit 3a04605

Browse files
committed
adding preprompt concept
1 parent b7a6d8d commit 3a04605

File tree

7 files changed

+123
-58
lines changed

7 files changed

+123
-58
lines changed

src/together/lib/cli/__init__.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
import httpx
1010
from cyclopts import App, MissingArgumentError, Parameter
1111

12+
from together.lib.cli.logger.config import CLIConfig
1213
from together import AsyncTogether
1314
from together._exceptions import APIError
1415
from together._version import __version__
1516
from together._utils._logs import setup_logging
16-
from together.lib.cli.logger.prompt import PromptParameter
17+
from together.lib.cli.logger.prompt import PromptParameter, console
1718

1819
app = App(
1920
name="together",
@@ -25,16 +26,6 @@
2526
app['--version'].group = "Parameters"
2627
app['--help'].group = "Parameters"
2728

28-
class Config:
29-
client: AsyncTogether
30-
non_interactive: bool
31-
json: bool
32-
33-
def __init__(self, client: AsyncTogether, non_interactive: bool, json: bool):
34-
self.client = client
35-
self.non_interactive = non_interactive
36-
self.json = json
37-
3829
def _create_client(
3930
api_key: Optional[str],
4031
base_url: Optional[str],
@@ -86,7 +77,7 @@ async def _launcher(
8677
os.environ.setdefault("TOGETHER_LOG", "debug")
8778
setup_logging()
8879
client = _create_client(api_key, base_url, timeout, max_retries)
89-
config = Config(
80+
config = CLIConfig(
9081
client=client,
9182
non_interactive=non_interactive or False,
9283
json=json or False,
@@ -104,7 +95,6 @@ async def run_command():
10495
remaining.append(value)
10596

10697
kwargs = dict(bound.kwargs)
107-
kwargs["config"] = config
10898
if "config" in extra:
10999
kwargs["config"] = config
110100
result = command(*bound.args, **kwargs)
@@ -130,12 +120,17 @@ async def run_command():
130120

131121
value: str | None = None
132122
if prompt is not None:
123+
await prompt.preprompt(config)
133124
value = await prompt.prompt(e.argument.name)
134125
print("") # Push a blank line for nicer output
135126
remaining.append(e.argument.name)
136127
remaining.append(value)
137128
await run_command()
138-
except (KeyboardInterrupt, SystemExit):
129+
else:
130+
# TODO: Better design this
131+
print("Missing required argument", e.argument.name)
132+
sys.exit(1)
133+
except KeyboardInterrupt:
139134
pass
140135
except APIError as e:
141136
error_msg = ""
@@ -146,12 +141,16 @@ async def run_command():
146141
print(f"Failed", file=sys.stderr)
147142
print(f"{error_msg}", file=sys.stderr)
148143
sys.exit(1)
149-
except Exception as e:
150-
print(f"Failed", file=sys.stderr)
151-
print(f"An unexpected error occurred - {e!s}", file=sys.stderr)
152-
sys.exit(1)
144+
# except Exception as e:
145+
# print(f"Failed", file=sys.stderr)
146+
# print(f"An unexpected error occurred - {e!s}", file=sys.stderr)
147+
# sys.exit(1)
153148
try:
154149
await run_command()
150+
# except Exception as e:
151+
# print(f"Failed", file=sys.stderr)
152+
# print(f"An unexpected error occurred - {e!s}", file=sys.stderr)
153+
# sys.exit(1)
155154
finally:
156155
await client.close()
157156

src/together/lib/cli/api/endpoints/create.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,12 @@
66

77
from cyclopts import Parameter
88

9-
from together import APIError, AsyncTogether, omit
9+
from together import APIError, omit
10+
from together.lib.cli.logger.config import CLIConfig
1011

11-
from together.lib.cli.api.endpoints._utils import handle_endpoint_api_errors
1212

1313
from .hardware import hardware as list_hardware
1414

15-
16-
@handle_endpoint_api_errors("Endpoints")
1715
async def create(
1816
model: str,
1917
min_replicas: int = 1,
@@ -26,9 +24,8 @@ async def create(
2624
inactive_timeout: Optional[int] = None,
2725
availability_zone: Optional[str] = None,
2826
wait: bool = False,
29-
json_output: bool = False,
3027
*,
31-
client: Annotated[AsyncTogether, Parameter(parse=False)],
28+
config: Annotated[CLIConfig, Parameter(parse=False)],
3229
) -> None:
3330
"""Create a new dedicated inference endpoint."""
3431
if min_replicas > max_replicas:
@@ -40,7 +37,7 @@ async def create(
4037

4138
if availability_zone:
4239
try:
43-
valid_zones = await client.endpoints.list_avzones()
40+
valid_zones = await config.client.endpoints.list_avzones()
4441
if availability_zone not in valid_zones.avzones:
4542
print(f"Error: Invalid availability zone '{availability_zone}'", file=sys.stderr)
4643
if valid_zones.avzones:
@@ -51,19 +48,19 @@ async def create(
5148
except Exception:
5249
pass
5350

54-
if json_output and wait:
51+
if config.json and wait:
5552
print("Error: --json and --wait cannot be used together.", file=sys.stderr)
5653
return
5754

58-
if no_prompt_cache is not None and not json_output:
55+
if no_prompt_cache is not None and not config.json:
5956
print("Warning: --no-prompt-cache is deprecated and no longer has any effect.", file=sys.stderr)
6057

6158
if hardware is None:
6259
print("Error: --hardware is required", file=sys.stderr)
6360
sys.exit(1)
6461

6562
try:
66-
response = await client.endpoints.create(
63+
response = await config.client.endpoints.create(
6764
model=model,
6865
hardware=hardware,
6966
autoscaling={"min_replicas": min_replicas, "max_replicas": max_replicas},
@@ -74,7 +71,7 @@ async def create(
7471
extra_query={"availability_zone": availability_zone or omit},
7572
)
7673
except APIError as e:
77-
if json_output:
74+
if config.json:
7875
raise e
7976
error_msg = str(e.args[0]).lower() if e.args else ""
8077
if (
@@ -85,7 +82,7 @@ async def create(
8582
):
8683
print("Invalid hardware selected.", file=sys.stderr)
8784
print("\nAvailable hardware options:", file=sys.stderr)
88-
await list_hardware(model=model, json_output=False, available=True, client=client)
85+
await list_hardware(model=model, json_output=False, available=True, client=config.client)
8986
sys.exit(1)
9087
elif "model" in error_msg and (
9188
"not found" in error_msg
@@ -99,7 +96,7 @@ async def create(
9996
sys.exit(1)
10097
raise e
10198

102-
if json_output:
99+
if config.json:
103100
print(response.model_dump_json(indent=2))
104101
return
105102

@@ -122,7 +119,7 @@ async def create(
122119

123120
if wait:
124121
print("Waiting for endpoint to be ready...", file=sys.stderr)
125-
while (await client.endpoints.retrieve(response.id)).state != "STARTED":
122+
while (await config.client.endpoints.retrieve(response.id)).state != "STARTED":
126123
await asyncio.sleep(1)
127124
print("Endpoint ready", file=sys.stderr)
128125
print(response.id)

src/together/lib/cli/api/endpoints/start.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,54 @@
33
import json as json_lib
44
import sys
55
from typing import Annotated
6+
from typing_extensions import override
67

78
from cyclopts import Parameter
89

910
import asyncio
11+
import questionary
1012

11-
from together import AsyncTogether
12-
13-
from together.lib.cli.api.endpoints._utils import handle_endpoint_api_errors
13+
from together.lib.cli.logger.config import CLIConfig
14+
from together.lib.cli.logger.prompt import PromptParameter, console
1415
from together.lib.utils.serializer import datetime_serializer
1516

17+
class LoadEndpointsPrompt(PromptParameter):
18+
@override
19+
async def preprompt(self, config: CLIConfig):
20+
with console.status(
21+
"[progress.description]Loading endpoints...[/progress.description]",
22+
spinner="dots",
23+
spinner_style="bar.pulse",
24+
):
25+
endpoints = await config.client.endpoints.list()
26+
self.choices = []
27+
for endpoint in endpoints.data:
28+
# This shouldn't happen.. but does happen sometimes...
29+
if endpoint.id is None: # type: ignore
30+
continue
31+
32+
if endpoint.state != "STARTED":
33+
self.choices.append(questionary.Choice(title=[("", endpoint.name), ("class:disabled", " ({})".format(endpoint.id))], value=endpoint.id))
34+
1635

17-
@handle_endpoint_api_errors("Endpoints")
1836
async def start(
19-
endpoint_id: str,
37+
endpoint_id: Annotated[str, Parameter(required=True, help="The ID of the endpoint to start"), LoadEndpointsPrompt(message="Enter the endpoint ID")],
2038
wait: bool = False,
21-
json_output: bool = False,
2239
*,
23-
client: Annotated[AsyncTogether, Parameter(parse=False)],
40+
config: Annotated[CLIConfig, Parameter(parse=False)],
2441
) -> None:
2542
"""Start a dedicated inference endpoint."""
26-
response = await client.endpoints.update(endpoint_id, state="STARTED")
43+
response = await config.client.endpoints.update(endpoint_id, state="STARTED")
44+
2745

28-
if json_output:
46+
if config.json:
2947
print(json_lib.dumps(response.model_dump(), default=datetime_serializer, indent=2))
3048
return
3149

3250
print("Successfully marked endpoint as starting", file=sys.stderr)
3351
if wait:
3452
print("Waiting for endpoint to start...", file=sys.stderr)
35-
while (await client.endpoints.retrieve(endpoint_id)).state != "STARTED":
53+
while (await config.client.endpoints.retrieve(endpoint_id)).state != "STARTED":
3654
await asyncio.sleep(1)
3755
print("Endpoint started", file=sys.stderr)
3856
print(endpoint_id)

src/together/lib/cli/api/models/list.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,23 @@
55
from tabulate import tabulate
66

77
from cyclopts import Parameter
8+
from rich import print_json
89

9-
from together import AsyncTogether, omit
10-
from together.lib.utils.serializer import datetime_serializer
10+
11+
from together import omit
12+
from together.lib.cli.logger.config import CLIConfig
13+
from together._utils._json import openapi_dumps
1114

1215
async def list_(
1316
type_: Annotated[Optional[Literal["dedicated"]], Parameter(name="type")] = None,
14-
json_output: bool = False,
1517
*,
16-
client: Annotated[AsyncTogether, Parameter(parse=False)],
18+
config: Annotated[CLIConfig, Parameter(parse=False)],
1719
) -> None:
1820
"""List models."""
19-
models_list = await client.models.list(dedicated=type_ == "dedicated" if type_ else omit)
20-
21-
if json_output:
22-
import json as json_lib
21+
models_list = await config.client.models.list(dedicated=type_ == "dedicated" if type_ else omit)
2322

24-
items = [model.model_dump() for model in models_list]
25-
print(json_lib.dumps(items, indent=2, default=datetime_serializer))
23+
if config.json:
24+
print_json(openapi_dumps(models_list).decode())
2625
return
2726

2827
display_list: List[Dict[str, Any]] = []

src/together/lib/cli/api/models/upload.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from together import TogetherError, omit
1010

11-
from together.lib.cli import Config
11+
from together.lib.cli.logger.config import CLIConfig
1212
from together.lib.cli.logger.prompt import PromptParameter
1313
from together.types.model_upload_response import ModelUploadResponse
1414
from together._utils._json import openapi_dumps
@@ -56,7 +56,7 @@ async def upload(
5656
base_model: Optional[str] = None,
5757
lora_model: Optional[str] = None,
5858
*,
59-
config: Annotated[Config, Parameter(parse=False)],
59+
config: Annotated[CLIConfig, Parameter(parse=False)],
6060
) -> None:
6161
"""Upload a custom model or adapter from Hugging Face or S3."""
6262
response: ModelUploadResponse = await config.client.models.upload(
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from together import AsyncTogether
2+
3+
class CLIConfig:
4+
client: AsyncTogether
5+
non_interactive: bool
6+
json: bool
7+
8+
def __init__(self, client: AsyncTogether, non_interactive: bool, json: bool):
9+
self.client = client
10+
self.non_interactive = non_interactive
11+
self.json = json

src/together/lib/cli/logger/prompt.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from rich import print
33

44
import questionary
5+
from together.lib.cli.logger.config import CLIConfig
6+
from rich.theme import Theme
7+
from rich.console import Console
58

69
custom_style_fancy = questionary.Style([
710
('qmark', 'fg:#caaef5 bold'), # token in front of the question
@@ -16,6 +19,41 @@
1619
('disabled', 'fg:#858585 italic') # disabled choices for select and checkbox prompts
1720
])
1821

22+
custom_theme = Theme({
23+
# Text styles
24+
"primary": "#caaef5", # Purple 300 ⭐ (lighter when bold)
25+
"secondary": "dim #caaef5", # Purple 500 ⭐ (mid-tone without bold)
26+
"accent": "#ff68d4", # Pink 500 ⭐
27+
"muted": "#98a0b3", # Grey 400 ⭐
28+
"dim": "dim #626b84", # Grey 600 ⭐
29+
30+
# Semantic styles
31+
"success": "bold #0dce74", # Green 400 ⭐
32+
"info": "#64afff", # Blue 500 ⭐
33+
"warning": "bold #ff815d", # Red 500 ⭐
34+
"error": "bold #c63800", # Red 700 ⭐
35+
36+
# UI elements
37+
"prompt": "#ba92ff", # Purple 500 ⭐ (no bold)
38+
"prompt.choices": "#caaef5", # Purple 300 ⭐
39+
"prompt.default": "dim #98a0b3", # Grey 400 ⭐
40+
41+
# Table styles
42+
"table.header": "#414858", # Purple 300 ⭐ (lighter when bold)
43+
"table.border": "#626b84", # Grey 600 ⭐
44+
"table.row": "#c4c9d4", # Grey 300 ⭐
45+
46+
# Progress/Loading
47+
"progress.description": "#caaef5", # Purple 300 ⭐
48+
"progress.percentage": "bold #caaef5", # Purple 300 ⭐ (lighter when bold)
49+
"bar.complete": "#ba92ff", # Purple 500 ⭐ (no bold)
50+
"bar.finished": "#0dce74", # Green 400 ⭐
51+
"bar.pulse": "#ff68d4", # Pink 500 ⭐
52+
})
53+
54+
console = Console(theme=custom_theme)
55+
56+
1957
class NameValidator(questionary.Validator):
2058
def validate(self, document):
2159
if document.text.count(" ") > 0:
@@ -27,18 +65,21 @@ def validate(self, document):
2765
class PromptParameter:
2866
message: str | None = None
2967
instructions: str | None = None
30-
choices: list[str] | None = None
68+
choices: list[str | questionary.Choice] | None = None
3169

32-
def __init__(self, message: str | None = None, instructions: str | None = None, choices: list[str] | None = None):
70+
def __init__(self, message: str | None = None, instructions: str | None = None, choices: list[str | questionary.Choice] | None = None):
3371
self.message = message
3472
self.instructions = instructions
3573
self.choices = choices
3674

75+
async def preprompt(self, _config: CLIConfig):
76+
pass
77+
3778
async def prompt(self, field: str) -> str:
3879
if self.instructions is not None:
3980
print(f"[dim]{self.instructions}[/dim]")
40-
81+
4182
if self.choices is not None:
42-
return await questionary.select(self.message or field, choices=self.choices, style=custom_style_fancy).unsafe_ask_async()
83+
return await questionary.select(self.message or field, choices=self.choices, style=custom_style_fancy, show_selected=True).unsafe_ask_async()
4384

4485
return await questionary.text(self.message or field, instruction="\n→", style=custom_style_fancy, validate=NameValidator).unsafe_ask_async()

0 commit comments

Comments
 (0)