11import json
2- from typing import Any , Dict , Iterator , List , Optional
2+ from typing import Any , Dict , Iterator , List , Optional , Union
33
44import together
5+ from together .types import TogetherResponse
56from together .utils import create_post_request , get_logger , sse_client
67
78
@@ -21,7 +22,9 @@ def create(
2122 top_k : Optional [int ] = 50 ,
2223 repetition_penalty : Optional [float ] = None ,
2324 logprobs : Optional [int ] = None ,
24- ) -> Dict [str , Any ]:
25+ api_key : Optional [str ] = None ,
26+ cast : bool = False ,
27+ ) -> Union [Dict [str , Any ], TogetherResponse ]:
2528 if model == "" :
2629 model = together .default_text_model
2730
@@ -39,14 +42,18 @@ def create(
3942
4043 # send request
4144 response = create_post_request (
42- url = together .api_base_complete , json = parameter_payload
45+ url = together .api_base_complete , json = parameter_payload , api_key = api_key
4346 )
4447
4548 try :
4649 response_json = dict (response .json ())
4750
4851 except Exception as e :
4952 raise together .JSONError (e , http_status = response .status_code )
53+
54+ if cast :
55+ return TogetherResponse (** response_json )
56+
5057 return response_json
5158
5259 @classmethod
@@ -55,13 +62,15 @@ def create_streaming(
5562 prompt : str ,
5663 model : Optional [str ] = "" ,
5764 max_tokens : Optional [int ] = 128 ,
58- stop : Optional [str ] = None ,
65+ stop : Optional [List [ str ] ] = None ,
5966 temperature : Optional [float ] = 0.7 ,
6067 top_p : Optional [float ] = 0.7 ,
6168 top_k : Optional [int ] = 50 ,
6269 repetition_penalty : Optional [float ] = None ,
6370 raw : Optional [bool ] = False ,
64- ) -> Iterator [str ]:
71+ api_key : Optional [str ] = None ,
72+ cast : Optional [bool ] = False ,
73+ ) -> Union [Iterator [str ], Iterator [TogetherResponse ]]:
6574 """
6675 Prints streaming responses and returns the completed text.
6776 """
@@ -83,19 +92,25 @@ def create_streaming(
8392
8493 # send request
8594 response = create_post_request (
86- url = together .api_base_complete , json = parameter_payload , stream = True
95+ url = together .api_base_complete ,
96+ json = parameter_payload ,
97+ api_key = api_key ,
98+ stream = True ,
8799 )
88100
89101 output = ""
90102 client = sse_client (response )
91103 for event in client .events ():
92- if raw :
104+ if cast :
105+ if event .data != "[DONE]" :
106+ yield TogetherResponse (** json .loads (event .data ))
107+ elif raw :
93108 yield str (event .data )
94109 elif event .data != "[DONE]" :
95110 json_response = dict (json .loads (event .data ))
96111 if "error" in json_response .keys ():
97112 raise together .ResponseError (
98- json_response ["error" ][ "error" ] ,
113+ json_response ["error" ],
99114 request_id = json_response ["error" ]["request_id" ],
100115 )
101116 elif "choices" in json_response .keys ():
@@ -106,3 +121,50 @@ def create_streaming(
106121 raise together .ResponseError (
107122 f"Unknown error occured. Received unhandled response: { event .data } "
108123 )
124+
125+
126+ class Completion :
127+ @classmethod
128+ def create (
129+ self ,
130+ prompt : str ,
131+ model : Optional [str ] = "" ,
132+ max_tokens : Optional [int ] = 128 ,
133+ stop : Optional [List [str ]] = [],
134+ temperature : Optional [float ] = 0.7 ,
135+ top_p : Optional [float ] = 0.7 ,
136+ top_k : Optional [int ] = 50 ,
137+ repetition_penalty : Optional [float ] = None ,
138+ logprobs : Optional [int ] = None ,
139+ api_key : Optional [str ] = None ,
140+ stream : bool = False ,
141+ ) -> Union [
142+ TogetherResponse , Iterator [TogetherResponse ], Iterator [str ], Dict [str , Any ]
143+ ]:
144+ if stream :
145+ return Complete .create_streaming (
146+ prompt = prompt ,
147+ model = model ,
148+ max_tokens = max_tokens ,
149+ stop = stop ,
150+ temperature = temperature ,
151+ top_p = top_p ,
152+ top_k = top_k ,
153+ repetition_penalty = repetition_penalty ,
154+ api_key = api_key ,
155+ cast = True ,
156+ )
157+ else :
158+ return Complete .create (
159+ prompt = prompt ,
160+ model = model ,
161+ max_tokens = max_tokens ,
162+ stop = stop ,
163+ temperature = temperature ,
164+ top_p = top_p ,
165+ top_k = top_k ,
166+ repetition_penalty = repetition_penalty ,
167+ logprobs = logprobs ,
168+ api_key = api_key ,
169+ cast = True ,
170+ )
0 commit comments