This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 3888de3

Merge branch 'main' into benchmarking_script
2 parents eab209f + fcadb14

File tree

17 files changed: +219 −225 lines


README.md

Lines changed: 13 additions & 10 deletions
@@ -181,16 +181,6 @@ This mode generates text based on an input prompt.
 python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy and his bear"
 ```

-### Browser
-This mode allows you to chat with the model using a UI in your browser
-Running the command automatically open a tab in your browser.
-
-[skip default]: begin
-
-```
-streamlit run torchchat.py -- browser llama3.1
-```
-
 [skip default]: end

 ### Server
@@ -252,6 +242,19 @@ curl http://127.0.0.1:5000/v1/chat \

 </details>

+### Browser
+This command opens a basic browser interface for local chat by querying a local server.
+
+First, follow the steps in the Server section above to start a local server. Then, in another terminal, launch the interface. Running the following will open a tab in your browser.
+
+[skip default]: begin
+
+```
+streamlit run browser/browser.py
+```
+
+Use the "Max Response Tokens" slider to limit the maximum number of tokens generated by the model for each response. Click the "Reset Chat" button to remove the message history and start a fresh chat.
+

 ## Desktop/Server Execution
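
For readers trying the workflow this hunk documents: the browser UI added above is a thin client over the local OpenAI-compatible server described in the Server section. The snippet below is a minimal sketch of talking to that endpoint directly with the `openai` Python client instead of Streamlit; the base URL and the `llama3` model tag are taken from `browser/browser.py` in this commit, while the API key and prompt are arbitrary placeholders.

```python
# Minimal sketch: stream a chat completion from a locally running torchchat
# server. Assumes the server from the Server section is already listening on
# 127.0.0.1:5000 and that the `openai` package is installed.
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:5000/v1",
    api_key="not-used",  # any non-empty string; the local server does not check it
)

stream = client.chat.completions.create(
    model="llama3",  # model tag used by browser/browser.py in this commit
    messages=[{"role": "user", "content": "Write me a story about a boy and his bear."}],
    max_tokens=250,
    stream=True,
)

for chunk in stream:
    # Each streamed chunk carries an incremental delta of the response text.
    print(chunk.choices[0].delta.content or "", end="", flush=True)
print()
```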

api/api.py

Lines changed: 36 additions & 26 deletions
@@ -10,6 +10,8 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union

+import torch
+
 from build.utils import device_sync

 from generate import Generator, GeneratorArgs
@@ -123,6 +125,9 @@ class CompletionRequest:
     parallel_tool_calls: Optional[bool] = None  # unimplemented - Assistant features
     user: Optional[str] = None  # unimplemented

+    def __post_init__(self):
+        self.stream = bool(self.stream)
+

 @dataclass
 class CompletionChoice:
@@ -202,7 +207,7 @@ class CompletionResponseChunk:
     choices: List[CompletionChoiceChunk]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: Optional[str] = None
     service_tier: Optional[str] = None
     object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
@@ -222,7 +227,6 @@ def __init__(self, *args, **kwargs):
         """

         super().__init__(*args, **kwargs)
-        self.start_pos = 0
         self.max_seq_length = (
             self.model.config.max_seq_length
             + self.speculative_builder_args.speculate_k
@@ -257,20 +261,25 @@ def chunked_completion(self, completion_request: CompletionRequest):
         CompletionResponseChunk objects in response to completion_request as tokens are generated.

         """
-        device_sync(device=self.builder_args.device)

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())

         idx = 0
-        buffer = []
-        encoded = self.encode_tokens(
-            completion_request.messages[-1].get("content"),
-            bos=True,
-            device=self.builder_args.device,
+        tokens = self.chat_formatter.encode_dialog_prompt(
+            dialog=[
+                {"role": message["role"], "content": message["content"]}
+                for message in completion_request.messages
+            ]
         )
+
+        encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
+        print(self.tokenizer.decode(tokens))
+
+        start_pos = 0
+
         generator_args = GeneratorArgs(
-            completion_request.messages[-1].get("content"),
+            None,
             max_new_tokens=(
                 int(completion_request.max_tokens)
                 if completion_request.max_tokens
@@ -279,33 +288,39 @@ def chunked_completion(self, completion_request: CompletionRequest):
             encoded_prompt=encoded,
             temperature=float(completion_request.temperature),
             chat_mode=False,
+            sequential_prefill=True,
         )

         def callback(x, *, done_generating=False):
             return self._callback(
                 x,
-                buffer=buffer,
+                buffer=None,
                 done_generating=done_generating,
             )

+        device_sync(device=self.builder_args.device)
+
         # Process each token, metrics tuple yielded by Generator.generate.
         for y, _ in self.generate(
-            self.model,
-            encoded,
-            generator_args.max_new_tokens,
+            model=self.model,
+            prompt=encoded,
+            max_new_tokens=generator_args.max_new_tokens,
             draft_model=self.draft_model,
             speculate_k=generator_args.speculate_k,
             chat_mode=generator_args.chat_mode,
             callback=callback,
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
             sequential_prefill=generator_args.sequential_prefill,
-            start_pos=self.start_pos,
+            start_pos=start_pos,
             max_seq_length=self.max_seq_length,
-            seed=int(completion_request.seed),
+            seed=int(completion_request.seed or 0),
         ):
             if y is None:
                 continue
+            elif y.item() == self.tokenizer.eos_id:
+                # Stop generation if the EOS token is generated.
+                break

             # Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
             content = "".join(
@@ -321,16 +336,17 @@ def callback(x, *, done_generating=False):
             choice_chunk = CompletionChoiceChunk(
                 delta=chunk_delta,
                 index=idx,
+                finish_reason=None,
             )
             chunk_response = CompletionResponseChunk(
-                id=str(id),
+                id="chatcmpl-" + str(id),
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
                 system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
-            self.start_pos += y.size(0)
+            start_pos += y.size(0)
             idx += 1

         # Yield an ending chunk indicating the generation has completed.
@@ -339,7 +355,7 @@ def callback(x, *, done_generating=False):
         )

         yield CompletionResponseChunk(
-            id=str(id),
+            id="chatcmpl-" + str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
@@ -355,7 +371,7 @@ def sync_completion(self, request: CompletionRequest):

         message = AssistantMessage(content=output)
         return CompletionResponse(
-            id=str(uuid.uuid4()),
+            id="chatcmpl-" + str(uuid.uuid4()),
             choices=[
                 CompletionChoice(
                     finish_reason="stop",
@@ -369,10 +385,4 @@ def sync_completion(self, request: CompletionRequest):
         )

     def _callback(self, x, *, buffer, done_generating):
-        period_id = self.tokenizer.encode(".")[0]
-        buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
-        if (
-            self.is_llama3_model
-            and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
-        ):
-            buffer = buffer[:-1]  # drop the eot_id from the output buffer
+        pass
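
To summarize the streaming contract that `chunked_completion` now implements: each yielded `CompletionResponseChunk` wraps a `CompletionChoiceChunk` whose `delta` carries incremental content and whose `finish_reason` stays `None` until generation ends, with a final chunk emitted after the loop. The snippet below is an illustrative, self-contained consumer of that shape; the stand-in dataclasses are hypothetical stubs so the example runs without the server, and only their field names mirror the classes in this diff.

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Hypothetical stubs that mirror only the fields used below; the real
# CompletionResponseChunk / CompletionChoiceChunk live in api/api.py.
@dataclass
class Delta:
    content: Optional[str] = None

@dataclass
class ChoiceChunk:
    delta: Delta
    index: int = 0
    finish_reason: Optional[str] = None

@dataclass
class ResponseChunk:
    id: str
    choices: List[ChoiceChunk] = field(default_factory=list)


def collect_stream(chunks) -> str:
    """Join delta contents until a chunk reports a finish_reason."""
    parts = []
    for chunk in chunks:
        choice = chunk.choices[0]
        if choice.delta.content:
            parts.append(choice.delta.content)
        if choice.finish_reason is not None:  # final chunk, e.g. "stop"
            break
    return "".join(parts)


# A fake stream with the same shape the generator yields ("chatcmpl-" ids).
fake_stream = [
    ResponseChunk("chatcmpl-1", [ChoiceChunk(Delta("Hello"))]),
    ResponseChunk("chatcmpl-1", [ChoiceChunk(Delta(", world"))]),
    ResponseChunk("chatcmpl-1", [ChoiceChunk(Delta(), finish_reason="stop")]),
]
print(collect_stream(fake_stream))  # -> Hello, world
```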

browser/browser.py

Lines changed: 61 additions & 85 deletions
@@ -1,91 +1,67 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
 import time
-
 import streamlit as st
-from api.api import CompletionRequest, OpenAiApiGenerator
-
-from build.builder import BuilderArgs, TokenizerArgs
-
-from generate import GeneratorArgs
-
-
-def main(args):
-    builder_args = BuilderArgs.from_args(args)
-    speculative_builder_args = BuilderArgs.from_speculative_args(args)
-    tokenizer_args = TokenizerArgs.from_args(args)
-    generator_args = GeneratorArgs.from_args(args)
-    generator_args.chat_mode = False
-
-    @st.cache_resource
-    def initialize_generator() -> OpenAiApiGenerator:
-        return OpenAiApiGenerator(
-            builder_args,
-            speculative_builder_args,
-            tokenizer_args,
-            generator_args,
-            args.profile,
-            args.quantize,
-            args.draft_quantize,
-        )
-
-    gen = initialize_generator()
-
-    st.title("torchchat")
-
-    # Initialize chat history
-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-
-    # Display chat messages from history on app rerun
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
-
-    # Accept user input
-    if prompt := st.chat_input("What is up?"):
-        # Add user message to chat history
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        # Display user message in chat message container
-        with st.chat_message("user"):
-            st.markdown(prompt)
-
-        # Display assistant response in chat message container
-        with st.chat_message("assistant"), st.status(
-            "Generating... ", expanded=True
-        ) as status:
-
-            req = CompletionRequest(
-                model=gen.builder_args.checkpoint_path,
-                prompt=prompt,
-                temperature=generator_args.temperature,
-                messages=[],
+from openai import OpenAI
+
+st.title("torchchat")
+
+start_state = [
+    {
+        "role": "system",
+        "content": "You're an assistant. Answer questions directly, be brief, and have fun.",
+    },
+    {"role": "assistant", "content": "How can I help you?"},
+]
+
+with st.sidebar:
+    response_max_tokens = st.slider(
+        "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
+    )
+    if st.button("Reset Chat", type="primary"):
+        st.session_state["messages"] = start_state
+
+if "messages" not in st.session_state:
+    st.session_state["messages"] = start_state
+
+
+for msg in st.session_state.messages:
+    st.chat_message(msg["role"]).write(msg["content"])
+
+if prompt := st.chat_input():
+    client = OpenAI(
+        base_url="http://127.0.0.1:5000/v1",
+        api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
+    )
+
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    st.chat_message("user").write(prompt)
+
+    with st.chat_message("assistant"), st.status(
+        "Generating... ", expanded=True
+    ) as status:
+
+        def get_streamed_completion(completion_generator):
+            start = time.time()
+            tokcount = 0
+            for chunk in completion_generator:
+                tokcount += 1
+                yield chunk.choices[0].delta.content
+
+            status.update(
+                label="Done, averaged {:.2f} tokens/second".format(
+                    tokcount / (time.time() - start)
+                ),
+                state="complete",
             )

-            def unwrap(completion_generator):
-                start = time.time()
-                tokcount = 0
-                for chunk_response in completion_generator:
-                    content = chunk_response.choices[0].delta.content
-                    if not gen.is_llama3_model or content not in set(
-                        gen.tokenizer.special_tokens.keys()
-                    ):
-                        yield content
-                    if content == gen.tokenizer.eos_id():
-                        yield "."
-                    tokcount += 1
-                status.update(
-                    label="Done, averaged {:.2f} tokens/second".format(
-                        tokcount / (time.time() - start)
-                    ),
-                    state="complete",
+        response = st.write_stream(
+            get_streamed_completion(
+                client.chat.completions.create(
+                    model="llama3",
+                    messages=st.session_state.messages,
+                    max_tokens=response_max_tokens,
+                    stream=True,
                 )
+            )
+        )[0]

-            response = st.write_stream(unwrap(gen.completion(req)))
-
-            # Add assistant response to chat history
-            st.session_state.messages.append({"role": "assistant", "content": response})
+    st.session_state.messages.append({"role": "assistant", "content": response})

build/builder.py

Lines changed: 9 additions & 2 deletions
@@ -400,7 +400,7 @@ def _maybe_parellelize_model(
     if the user specifies using distributed inference. If not, this is a no-op.

     Args:
-        module (:class:`nn.Module`):
+        model (:class:`nn.Module`):
             Module to be parallelized.
         builder_args (:class:`BuilderArgs`):
             Command args for model building.
@@ -440,6 +440,7 @@ def _initialize_model(
     quantize,
     tokenizer=None,
     max_seq_length=None,
+    support_tensor_subclass: bool = True,
 ):
     print("Loading model...")

@@ -510,7 +511,13 @@ def _initialize_model(
     if quantize:
         print(f"Quantizing the model with: {quantize}")
         with measure_time("Time to quantize model: {time:.02f} seconds"):
-            quantize_model(model, builder_args.device, quantize, tokenizer)
+            quantize_model(
+                model,
+                builder_args.device,
+                quantize,
+                tokenizer,
+                support_tensor_subclass,
+            )
         device_sync(device=builder_args.device)

     if builder_args.setup_caches:
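
The builder change above threads a new `support_tensor_subclass` flag (default `True`) from `_initialize_model` into `quantize_model`. The sketch below is a self-contained illustration of that forwarding pattern using stub functions; the stub bodies and example argument values are assumptions for illustration only.

```python
# Illustrative stubs only: the real quantize_model and _initialize_model live
# in the torchchat codebase; these mirror just the parameter threading shown
# in the diff above.
def quantize_model(model, device, quantize, tokenizer, support_tensor_subclass):
    # Placeholder body: report whether tensor-subclass paths would be used.
    print(f"quantize={quantize!r} on {device}, "
          f"tensor subclass support={support_tensor_subclass}")
    return model


def initialize_model(model, device="cpu", quantize=None, tokenizer=None,
                     support_tensor_subclass=True):
    # Existing callers keep the default behaviour; new callers can opt out.
    if quantize:
        model = quantize_model(model, device, quantize, tokenizer,
                               support_tensor_subclass)
    return model


# Example: default behaviour alongside an explicit opt-out.
initialize_model(object(), quantize="int8")
initialize_model(object(), quantize="int8", support_tensor_subclass=False)
```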
