
Commit dea8d60

OpenAI API JSON formatted (#995)

* Add warning comments referring to unimplemented functionality
* JSON formatted response using OpenAI API types for server completion requests
* Add models endpoint (#1000)

1 parent 912917f commit dea8d60

File tree: 4 files changed, +242 -74 lines
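With this change the server answers completion requests with OpenAI-API-shaped JSON, so a stock HTTP client can exercise it. A minimal sketch in Python; the URL, port, and endpoint path are assumptions for illustration and do not come from this commit:

    import requests

    # Hypothetical address: use whatever host/port/path your torchchat
    # server is actually launched with.
    url = "http://127.0.0.1:5000/chat"

    payload = {
        # Per the docstring added in this commit, "model" does not change
        # which model responds; the server's --model flag does.
        "model": "llama3",
        "messages": [{"role": "user", "content": "Tell me a joke."}],
    }

    resp = requests.post(url, json=payload)
    # The body is JSON shaped like an OpenAI chat.completion object.
    print(resp.json()["choices"][0]["message"]["content"])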

api/api.py

Lines changed: 80 additions & 28 deletions
@@ -8,7 +8,7 @@
 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from build.utils import device_sync

@@ -87,31 +87,39 @@ class StreamOptions:
     include_usage: bool = False


+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None
+
+
 @dataclass
 class CompletionRequest:
     """A full chat completion request.

     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """

+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented


 @dataclass
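After this change only messages and model are required; every field tagged # unimplemented can be left at its default. A sketch of constructing a request directly (the import path is an assumption; plain dicts stand in for _AbstractMessage, which the server reads with .get("content")):

    from api.api import CompletionRequest  # hypothetical import path

    request = CompletionRequest(
        messages=[{"role": "user", "content": "Hello!"}],
        model="llama3",  # informational only; see the chunked_completion docstring
        stream=True,     # ask for a chunked (streamed) response
    )
    # Unset fields keep their dataclass defaults, e.g. temperature=1.0, n=1.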
@@ -121,10 +129,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """

-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: str = None
+    logprobs: Optional[List[Any]] = None


 @dataclass
@@ -151,9 +159,9 @@ class CompletionResponse:
     created: int
     model: str
     system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"


 @dataclass
@@ -193,8 +201,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None

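Serialized, one streamed chunk built from these dataclasses follows the OpenAI chat.completion.chunk shape. An illustrative sketch only, with made-up values (the delta key layout follows the OpenAI spec; system_fingerprint is device name plus precision type name, per the __init__ change below):

    chunk = {
        "id": "4f0c8c9a-...",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "Hello"},
                "finish_reason": None,
            }
        ],
        "created": 1719500000,
        "model": "llama3",
        "system_fingerprint": "<device + precision type name>",
        "service_tier": None,
        "object": "chat.completion.chunk",
        "usage": None,
    }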

@@ -220,10 +228,27 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The System fingerprint is a unique identifier for the model and its configuration.
+        # Currently, this is not implemented in a
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )

-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.

+        ** Warning **: Not all arguments of the CompletionRequest are consumed, as the server isn't completely implemented.
+        Current treatment of parameters is described below.
+
+        - messages: The server consumes the final element of the array as the prompt.
+        - model: This has no impact on the server state, i.e. changing the model in the request
+          will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented - subject to change.
+
+        See https://github.com/pytorch/torchchat/issues/973 for more details.
+
+
         Args:
             completion_request: Request object with prompt and other parameters.

@@ -235,13 +260,16 @@ def completion(self, completion_request: CompletionRequest):

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
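Note what the encoding change above implies: only the content of the final message is used as the prompt, exactly as the warning in chunked_completion states. Earlier turns of a multi-turn conversation are currently dropped:

    messages = [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "First question?"},
        {"role": "assistant", "content": "First answer."},
        {"role": "user", "content": "Second question?"},
    ]

    # This is all the model sees as its prompt in this implementation:
    prompt = messages[-1].get("content")  # -> "Second question?"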
@@ -291,21 +319,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1

         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )

         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
+        )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and yield a single, non-chunked response"""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
         )

     def _callback(self, x, *, buffer, done_generating):
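The two entry points compose: sync_completion drains chunked_completion and stitches the deltas into a single AssistantMessage. A usage sketch, assuming api is an instance of the server class these methods belong to (its construction is outside this diff):

    # Streaming: print deltas as they arrive.
    for chunk in api.chunked_completion(request):
        if not chunk.choices[0].finish_reason:
            print(chunk.choices[0].delta.content, end="", flush=True)

    # Non-streaming: one CompletionResponse carrying the full message.
    response = api.sync_completion(request)
    print(response.choices[0].message.content)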

api/models.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from dataclasses import dataclass
+from pwd import getpwuid
+from typing import List, Union
+
+from download import is_model_downloaded, load_model_configs
+
+"""Helper functions for the OpenAI API Models endpoint.
+
+See https://platform.openai.com/docs/api-reference/models for the full specification and details.
+Please create an issue if anything doesn't match the specification.
+"""
+
+
+@dataclass
+class ModelInfo:
+    """The Model object per the OpenAI API specification containing information about a model.
+
+    See https://platform.openai.com/docs/api-reference/models/object for more details.
+    """
+
+    id: str
+    created: int
+    owner: str
+    object: str = "model"
+
+
+@dataclass
+class ModelInfoList:
+    """A list of ModelInfo objects."""
+
+    data: List[ModelInfo]
+    object: str = "list"
+
+
+def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
+    """Implementation of the OpenAI API Retrieve Model endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/retrieve
+
+    Inputs:
+        args: command line arguments
+        model_id: the id of the model requested
+
+    Returns:
+        ModelInfo describing the specified model if it is downloaded, None otherwise.
+    """
+    if model_config := load_model_configs().get(model_id):
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            return ModelInfo(id=model_config.name, created=created, owner=owner)
+        return None
+    return None
+
+
+def get_model_info_list(args) -> ModelInfoList:
+    """Implementation of the OpenAI API List Models endpoint.
+
+    See https://platform.openai.com/docs/api-reference/models/list
+
+    Inputs:
+        args: command line arguments
+
+    Returns:
+        ModelInfoList describing all downloaded models.
+    """
+    data = []
+    for model_id, model_config in load_model_configs().items():
+        if is_model_downloaded(model_id, args.model_directory):
+            path = args.model_directory / model_config.name
+            created = int(os.path.getctime(path))
+            owner = getpwuid(os.stat(path).st_uid).pw_name
+
+            data.append(ModelInfo(id=model_config.name, created=created, owner=owner))
+    response = ModelInfoList(data=data)
+    return response
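Because ModelInfo and ModelInfoList are plain dataclasses, producing the JSON the Models endpoint returns is a single dataclasses.asdict call away. A sketch, assuming args carries the model_directory attribute the functions above expect:

    import json
    from dataclasses import asdict

    from api.models import get_model_info_list  # hypothetical import path

    info_list = get_model_info_list(args)
    print(json.dumps(asdict(info_list), indent=2))
    # -> {"data": [{"id": "...", "created": ..., "owner": "...", "object": "model"}, ...], "object": "list"}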

generate.py

Lines changed: 4 additions & 0 deletions
@@ -452,11 +452,15 @@ def generate(
     sequential_prefill=True,
     callback=lambda x: x,
     max_seq_length: int,
+    seed: Optional[int] = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """
+    if seed:
+        torch.manual_seed(seed)
+
     is_speculative = draft_model is not None
     device, dtype = prompt.device, prompt.dtype
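The new seed parameter gives best-effort reproducibility: when it is set (and nonzero, since 0 fails the truthiness check), torch.manual_seed reseeds the global RNG before sampling, so identical calls should draw identical tokens. A self-contained illustration of the pattern:

    import torch

    def sample_demo(seed=None):
        # Mirrors the pattern added to generate(): reseed the global RNG if given.
        if seed:
            torch.manual_seed(seed)
        return torch.multinomial(torch.ones(10), num_samples=3, replacement=True)

    a = sample_demo(seed=42)
    b = sample_demo(seed=42)
    assert torch.equal(a, b)  # same seed, same draws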
