@@ -8,7 +8,7 @@
 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from build.utils import device_sync
 
@@ -87,31 +87,39 @@ class StreamOptions:
     include_usage: bool = False
 
 
+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None
+
+
 @dataclass
 class CompletionRequest:
     """A full chat completion request.
 
     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """
 
+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented
 
 
 @dataclass
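
Most of the new fields above are accepted but ignored, so the shape of the request matters more than the knobs. A minimal sketch of how an OpenAI-style JSON body could map onto this dataclass; the `**` unpacking shown is an assumption, since the actual HTTP parsing isn't part of this diff:

```python
import json

# Hypothetical OpenAI-style request body; only model, messages, and stream
# are meaningfully consumed by the server at this point.
payload = """
{
    "model": "llama3",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": true
}
"""

# Assumed construction: unpack the parsed JSON straight into the dataclass.
# Unknown keys would raise TypeError, matching dataclass strictness.
request = CompletionRequest(**json.loads(payload))
assert request.messages[-1]["content"] == "Hello!"
assert request.temperature == 1.0  # default value; currently unimplemented
```
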
@@ -121,10 +129,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """
 
-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: Optional[str] = None
+    logprobs: Optional[List[Any]] = None
 
 
 @dataclass
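
The reordering in this hunk is not cosmetic: a dataclass field without a default may not follow one with a default, so once `finish_reason` and `logprobs` gain `None` defaults they must move below the required `index` and `message` fields. A minimal illustration of the rule:

```python
from dataclasses import dataclass
from typing import Optional

# @dataclass
# class Broken:
#     finish_reason: Optional[str] = None
#     index: int  # TypeError: non-default argument 'index' follows default argument

@dataclass
class Fixed:
    index: int                            # required fields first...
    finish_reason: Optional[str] = None   # ...defaulted fields last

print(Fixed(index=0))  # Fixed(index=0, finish_reason=None)
```
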
@@ -151,9 +159,9 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"
 
 
 @dataclass
@@ -193,8 +201,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
 
 
@@ -220,10 +228,27 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The system fingerprint is a unique identifier for the model and its configuration.
+        # Currently, it is a placeholder built from the device and precision, not a true fingerprint.
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )
 
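
For a sense of what this placeholder evaluates to: `builder_args.precision` is typically a `torch.dtype` instance, and `type()` of any dtype is just `torch.dtype`, so every precision contributes the same `"dtype"` suffix. A rough sketch with hypothetical builder values:

```python
import torch

device = "cuda"             # hypothetical builder_args.device
precision = torch.bfloat16  # hypothetical builder_args.precision

fingerprint = device + type(precision).__name__
print(fingerprint)  # "cudadtype" -- identical for float16, bfloat16, etc.
```
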
-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.
 
+        **Warning**: Not all arguments of the CompletionRequest are consumed, since the server isn't completely implemented.
+        Current treatment of parameters is described below.
+
+        - messages: The server consumes the final element of the array as the prompt.
+        - model: This has no impact on the server state, i.e. changing the model in the request
+          will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented; subject to change.
+
+        See https://github.com/pytorch/torchchat/issues/973 for more details.
+
+
         Args:
             completion_request: Request object with prompt and other parameters.
 
@@ -235,13 +260,16 @@ def completion(self, completion_request: CompletionRequest):
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
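
Note that only the final entry of `messages` is encoded; system prompts and earlier turns are silently dropped, which matches the docstring's warning that the server consumes just the last element. A small sketch, assuming messages arrive as plain dicts parsed from the request JSON:

```python
# Assumed wire format: messages as plain dicts from the parsed JSON body.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# Mirrors completion_request.messages[-1].get("content") above:
prompt = messages[-1].get("content")
print(prompt)  # "What is 2 + 2?" -- the system message never reaches the model
```
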
@@ -291,21 +319,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1
 
         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )
 
         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
342+ """Handle a chat completion request and yield a single, non-chunked response"""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )
 
     def _callback(self, x, *, buffer, done_generating):
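
Taken together, the two entry points mirror OpenAI's `stream` flag: drain `chunked_completion` for streaming, or call `sync_completion` for one aggregated response built from the same chunks. A usage sketch; `OpenAiApiGenerator` is assumed here to be the name of the enclosing class, and its constructor arguments are elided:

```python
request = CompletionRequest(
    model="llama3",
    messages=[{"role": "user", "content": "Tell me a joke."}],
    stream=True,
)

gen = OpenAiApiGenerator(...)  # builder/tokenizer arguments elided

if request.stream:
    # Streaming path: print deltas as they arrive; the final chunk
    # carries finish_reason="stop" and an empty delta.
    for chunk in gen.chunked_completion(request):
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)
else:
    # Non-streaming path: one CompletionResponse with the full message.
    response = gen.sync_completion(request)
    print(response.choices[0].message.content)
```
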