Skip to content

Commit 4fdf6b1

Browse files
authored
Add OpenAI-compatible Completion class (#60)
* Add OpenAI-compatible Completion class * Add pydantic dependency
1 parent c562a64 commit 4fdf6b1

File tree

8 files changed

+444
-115
lines changed

8 files changed

+444
-115
lines changed

poetry.lock

Lines changed: 255 additions & 104 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ requests = "^2.31.0"
2626
tqdm = "^4.66.1"
2727
sseclient-py = "^1.7.2"
2828
tabulate = "^0.9.0"
29+
pydantic = "^2.5.0"
2930

3031
[tool.poetry.group.quality]
3132
optional = true

src/together/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
min_samples = 100
3434

35-
from .complete import Complete
35+
from .complete import Complete, Completion
3636
from .embeddings import Embeddings
3737
from .error import *
3838
from .files import Files
@@ -54,6 +54,7 @@
5454
"default_embedding_model",
5555
"Models",
5656
"Complete",
57+
"Completion",
5758
"Files",
5859
"Finetune",
5960
"Image",

src/together/commands/chat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def do_say(self, arg: str) -> None:
114114
top_k=self.args.top_k,
115115
repetition_penalty=self.args.repetition_penalty,
116116
):
117+
assert isinstance(token, str)
117118
print(token, end="", flush=True)
118119
output += token
119120
except together.AuthenticationError:

src/together/commands/complete.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def _run_complete(args: argparse.Namespace) -> None:
146146
except together.AuthenticationError:
147147
logger.critical(together.MISSING_API_KEY_MESSAGE)
148148
exit(0)
149+
assert isinstance(response, dict)
149150
no_streamer(args, response)
150151
else:
151152
try:

src/together/complete.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import json
2-
from typing import Any, Dict, Iterator, List, Optional
2+
from typing import Any, Dict, Iterator, List, Optional, Union
33

44
import together
5+
from together.types import TogetherResponse
56
from together.utils import create_post_request, get_logger, sse_client
67

78

@@ -21,7 +22,9 @@ def create(
2122
top_k: Optional[int] = 50,
2223
repetition_penalty: Optional[float] = None,
2324
logprobs: Optional[int] = None,
24-
) -> Dict[str, Any]:
25+
api_key: Optional[str] = None,
26+
cast: bool = False,
27+
) -> Union[Dict[str, Any], TogetherResponse]:
2528
if model == "":
2629
model = together.default_text_model
2730

@@ -39,14 +42,18 @@ def create(
3942

4043
# send request
4144
response = create_post_request(
42-
url=together.api_base_complete, json=parameter_payload
45+
url=together.api_base_complete, json=parameter_payload, api_key=api_key
4346
)
4447

4548
try:
4649
response_json = dict(response.json())
4750

4851
except Exception as e:
4952
raise together.JSONError(e, http_status=response.status_code)
53+
54+
if cast:
55+
return TogetherResponse(**response_json)
56+
5057
return response_json
5158

5259
@classmethod
@@ -55,13 +62,15 @@ def create_streaming(
5562
prompt: str,
5663
model: Optional[str] = "",
5764
max_tokens: Optional[int] = 128,
58-
stop: Optional[str] = None,
65+
stop: Optional[List[str]] = None,
5966
temperature: Optional[float] = 0.7,
6067
top_p: Optional[float] = 0.7,
6168
top_k: Optional[int] = 50,
6269
repetition_penalty: Optional[float] = None,
6370
raw: Optional[bool] = False,
64-
) -> Iterator[str]:
71+
api_key: Optional[str] = None,
72+
cast: Optional[bool] = False,
73+
) -> Union[Iterator[str], Iterator[TogetherResponse]]:
6574
"""
6675
Prints streaming responses and returns the completed text.
6776
"""
@@ -83,19 +92,25 @@ def create_streaming(
8392

8493
# send request
8594
response = create_post_request(
86-
url=together.api_base_complete, json=parameter_payload, stream=True
95+
url=together.api_base_complete,
96+
json=parameter_payload,
97+
api_key=api_key,
98+
stream=True,
8799
)
88100

89101
output = ""
90102
client = sse_client(response)
91103
for event in client.events():
92-
if raw:
104+
if cast:
105+
if event.data != "[DONE]":
106+
yield TogetherResponse(**json.loads(event.data))
107+
elif raw:
93108
yield str(event.data)
94109
elif event.data != "[DONE]":
95110
json_response = dict(json.loads(event.data))
96111
if "error" in json_response.keys():
97112
raise together.ResponseError(
98-
json_response["error"]["error"],
113+
json_response["error"],
99114
request_id=json_response["error"]["request_id"],
100115
)
101116
elif "choices" in json_response.keys():
@@ -106,3 +121,50 @@ def create_streaming(
106121
raise together.ResponseError(
107122
f"Unknown error occured. Received unhandled response: {event.data}"
108123
)
124+
125+
126+
class Completion:
    @classmethod
    def create(
        cls,
        prompt: str,
        model: Optional[str] = "",
        max_tokens: Optional[int] = 128,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = 0.7,
        top_p: Optional[float] = 0.7,
        top_k: Optional[int] = 50,
        repetition_penalty: Optional[float] = None,
        logprobs: Optional[int] = None,
        api_key: Optional[str] = None,
        stream: bool = False,
    ) -> Union[
        TogetherResponse, Iterator[TogetherResponse], Iterator[str], Dict[str, Any]
    ]:
        """OpenAI-compatible completion entry point.

        Dispatches to ``Complete.create_streaming`` when ``stream`` is True
        (yielding ``TogetherResponse`` chunks) and to ``Complete.create``
        otherwise (returning a single ``TogetherResponse``, since ``cast=True``
        is passed in both cases).

        Args:
            prompt: Text the model should complete.
            model: Model name; empty string selects the library default.
            max_tokens: Maximum number of tokens to generate.
            stop: Stop sequences. Defaults to None, normalized to an empty
                list below (avoids the mutable-default-argument pitfall of
                the previous ``= []`` while keeping the same payload).
            temperature, top_p, top_k, repetition_penalty: Sampling knobs,
                forwarded unchanged.
            logprobs: Number of logprobs to return (non-streaming only).
            api_key: Per-call API key override.
            stream: If True, return an iterator of streamed responses.
        """
        # Normalize here instead of using a shared mutable default argument.
        if stop is None:
            stop = []

        if stream:
            # NOTE(review): `logprobs` is silently ignored on this path —
            # Complete.create_streaming does not accept it. Confirm whether
            # the streaming API supports logprobs.
            return Complete.create_streaming(
                prompt=prompt,
                model=model,
                max_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                api_key=api_key,
                cast=True,
            )
        else:
            return Complete.create(
                prompt=prompt,
                model=model,
                max_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                logprobs=logprobs,
                api_key=api_key,
                cast=True,
            )

src/together/types.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import typing
2+
from enum import Enum
3+
from typing import Any, Dict, List, Optional
4+
5+
from pydantic import BaseModel
6+
7+
8+
# Decoder input tokens
class InputToken(BaseModel):
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    # Optional since the logprob of the first token cannot be computed.
    # Explicit `= None` is required: under pydantic v2 (pyproject pins
    # ^2.5.0) an Optional annotation without a default is a *required*
    # field, unlike pydantic v1.
    logprob: Optional[float] = None
17+
18+
19+
# Generated tokens
class Token(BaseModel):
    # Token ID
    id: int
    # Logprob (may be None when it cannot be computed)
    # Explicit `= None`: pydantic v2 treats Optional-annotated fields
    # without a default as required, unlike v1.
    logprob: Optional[float] = None
    # Is the token a special token
    # Can be used to ignore tokens when concatenating
    special: bool
28+
29+
30+
# Generation finish reason.
# Subclasses str so members compare equal to their raw string values,
# which lets them round-trip through JSON payloads directly.
class FinishReason(str, Enum):
    # number of generated tokens == `max_new_tokens`
    Length = "length"
    # the model generated its end of sequence token
    EndOfSequenceToken = "eos_token"
    # the model generated a text included in `stop_sequences`
    StopSequence = "stop_sequence"
38+
39+
40+
# Additional sequences when using the `best_of` parameter
class BestOfSequence(BaseModel):
    # Generated text
    generated_text: str
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated; None otherwise.
    # Explicit `= None`: pydantic v2 makes Optional fields without a
    # default required, unlike v1.
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
54+
55+
56+
# `generate` details
class Details(BaseModel):
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated; None otherwise.
    # Explicit `= None`: pydantic v2 makes Optional fields without a
    # default required, unlike v1.
    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
    prefill: List[InputToken]
    # Generated tokens
    tokens: List[Token]
    # Additional sequences when using the `best_of` parameter;
    # None when `best_of` was not used (see pydantic-v2 note on `seed`).
    best_of_sequences: Optional[List[BestOfSequence]] = None
70+
71+
72+
# `generate` return value (non-streaming).
# Both fields are required; a response missing either fails validation.
class Response(BaseModel):
    # Generated text
    generated_text: str
    # Generation details
    details: Details
78+
79+
80+
# A single completion choice within an API response.
class Choice(BaseModel):
    # Generated text
    text: str
    # Why generation stopped; None while a stream is still in progress
    finish_reason: Optional[FinishReason] = None
    # Per-token log-probabilities — presumably populated only when
    # `logprobs` was requested; confirm against the API response shape.
    logprobs: Optional[List[float]] = None
85+
86+
87+
# `generate_stream` details
class StreamDetails(BaseModel):
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated; None otherwise.
    # Explicit `= None`: pydantic v2 makes Optional fields without a
    # default required, unlike v1.
    seed: Optional[int] = None
93+
94+
95+
# `generate_stream` return value — also used for cast non-streaming
# responses. Every field is optional because streamed chunks, finished
# generations, and error payloads each populate a different subset.
class TogetherResponse(BaseModel):
    # Completion choices (non-streaming responses and stream chunks)
    choices: Optional[List[Choice]] = None
    # Request identifier
    id: Optional[str] = None
    # Single generated token (streaming responses)
    token: Optional[Token] = None
    # Error message when the request failed
    error: Optional[str] = None
    # Machine-readable error category
    error_type: Optional[str] = None
    # Full generated text
    generated_text: Optional[str] = None
    # Generation details
    # Only available when the generation is finished
    details: Optional[StreamDetails] = None

    def __init__(self, **kwargs: Any) -> None:
        """Normalize raw API payloads before pydantic validation.

        Non-streaming responses nest choices under an "output" key; hoist
        them to the top level so the `choices` field validates. The stray
        "output" key is left in kwargs — pydantic ignores unknown fields
        by default, matching the previous behavior.
        """
        output = kwargs.get("output")
        if output:
            # .get avoids the KeyError the old code raised when an
            # "output" payload had no "choices" entry.
            kwargs["choices"] = output.get("choices")
        super().__init__(**kwargs)

src/together/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,14 @@ def create_post_request(
8181
json: Optional[Dict[Any, Any]] = None,
8282
stream: Optional[bool] = False,
8383
check_auth: Optional[bool] = True,
84+
api_key: Optional[str] = None,
8485
) -> requests.Response:
85-
if check_auth:
86+
if check_auth and api_key is None:
8687
verify_api_key()
8788

8889
if not headers:
8990
headers = {
90-
"Authorization": f"Bearer {together.api_key}",
91+
"Authorization": f"Bearer {api_key or together.api_key}",
9192
"Content-Type": "application/json",
9293
"User-Agent": together.user_agent,
9394
}

0 commit comments

Comments
 (0)