This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit a190b0f

merge with unified model construction pipeline
2 parents 43dfdc7 + a356897 commit a190b0f

File tree

9 files changed: +292 additions, -62 deletions

distributed/parallelize_llama.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def apply_tp(
     # after we apply TP to the model. Because we don't want to change model code
     # when applying TP. We need to have change to ensure KVCache has the correct
     # size as k and v.
-    model.model.config.n_local_heads = model.model.config.n_local_heads // tp_mesh.size()
+    model.text_transformer_args.n_local_heads = model.text_transformer_args.n_local_heads // tp_mesh.size()
 
     # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
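
Note: tensor parallelism shards the attention heads across the ranks of the TP mesh, so each rank's KV cache only needs its local share of heads, which is what the changed line records. A minimal sketch of the arithmetic, with hypothetical head counts (the real values come from text_transformer_args and tp_mesh.size()):

```python
# Sketch only: how the per-rank head count is derived under tensor parallelism.
# n_heads_global and tp_degree are hypothetical stand-ins.
n_heads_global = 32                     # heads in the full checkpoint
tp_degree = 4                           # number of ranks in the TP mesh

assert n_heads_global % tp_degree == 0, "heads must divide evenly across TP ranks"
n_local_heads = n_heads_global // tp_degree   # heads held (and cached) per rank
print(n_local_heads)                    # 8 -> KV cache is sized for 8 heads per rank
```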

torchchat/cli/builder.py

Lines changed: 1 addition & 1 deletion
@@ -563,7 +563,7 @@ def _initialize_model(
         model.setup_caches(
             max_batch_size=1,
             max_seq_length=max_seq_length
-            or model.model.config.max_seq_length,
+            or model.text_transformer_args.max_seq_length,
         )
 
         model.to(dtype=builder_args.precision)
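
Note: the `or` in the changed line means an explicit max_seq_length from the caller wins, and the model's own configured limit (now read from text_transformer_args) is used only as a fallback. A tiny sketch with hypothetical values:

```python
# Hypothetical values; the real fallback reads model.text_transformer_args.max_seq_length.
requested_max_seq_length = None          # caller passed nothing explicit
model_max_seq_length = 4096              # the model's configured limit

effective = requested_max_seq_length or model_max_seq_length
print(effective)                         # 4096: the model's limit is used when unset
```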

torchchat/export.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def export_for_server(
             torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
         )
 
-        seq = Dim("seq", min=1, max=model.model.config.max_seq_length)
+        seq = Dim("seq", min=1, max=model.text_transformer_args.max_seq_length)
         # Specify that the first dimension of each input is that batch size
         dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}
     else:
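
Note: torch.export's Dim is how this path marks the sequence length as dynamic up to the model's configured maximum. A minimal, self-contained sketch of the same pattern on a toy module (the module and the 128 limit are hypothetical stand-ins, not torchchat's model):

```python
import torch
from torch.export import Dim, export


class ToyModule(torch.nn.Module):       # stand-in for the exported model
    def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
        return tokens.float().cumsum(dim=1) + input_pos.float()


max_seq_length = 128                    # stands in for text_transformer_args.max_seq_length
seq = Dim("seq", min=1, max=max_seq_length)

tokens = torch.zeros(1, 5, dtype=torch.int)
input_pos = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int)

# Dim 1 of `tokens` and dim 0 of `input_pos` share the same dynamic "seq" size.
exported = export(
    ToyModule(),
    (tokens, input_pos),
    dynamic_shapes={"tokens": {1: seq}, "input_pos": {0: seq}},
)
print(exported.graph_signature)
```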

torchchat/generate.py

Lines changed: 10 additions & 5 deletions
@@ -364,6 +364,8 @@ def prefill(
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
                 # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
+        elif self.model.config.model_type == ModelType.Flamingo:
+            logits = model(x)
         else:
             # input_pos: [B, S]
             logits = model(x, input_pos)
@@ -383,11 +385,14 @@ def decode_one_token(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # input_pos: [B, 1]
         assert input_pos.shape[-1] == 1
-        if model.config.model_type == ModelType.Flamingo and batch is not None:
-            x = x.view(1, -1)
-            logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
+        x = x.view(1, -1)
+        if model.config.model_type == ModelType.Flamingo:
+            if batch is not None:
+                logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
+            else:
+                logits = model(x)
         else:
-            logits = model(x.view(1, -1), input_pos)
+            logits = model(x, input_pos)
         # print(f"x: {x},\n input_pos: {input_pos}\n")
         return self.sample(logits, need_probs=need_probs, **sampling_kwargs)
 
@@ -790,7 +795,7 @@ def chat(
 
         # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong
        # TODO: unify the max_seq_length config representation.
-        text_transformer_args = getattr(self.model.model, "config", None)
+        text_transformer_args = self.model.text_transformer_args
        max_seq_length = (
            text_transformer_args.max_seq_length if text_transformer_args else 2048
        )
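
Note: the decode_one_token change hoists the x.view(1, -1) reshape out of the Flamingo branch and adds a text-only path for Flamingo models when no image batch is supplied. A simplified, self-contained sketch of the resulting dispatch (ModelType and the model objects here are stand-ins, not torchchat's classes):

```python
from enum import Enum

import torch


class ModelType(Enum):                  # stand-in for torchchat.model.ModelType
    TextOnly = "text_only"
    Flamingo = "flamingo"


def decode_one_token_sketch(model, x: torch.Tensor, input_pos: torch.Tensor, batch=None):
    """Simplified mirror of the new control flow; the real method also samples."""
    x = x.view(1, -1)                                   # reshape once, for every path
    if model.config.model_type == ModelType.Flamingo:
        if batch is not None:                           # multimodal step: only the last
            return model(x, encoder_mask=batch["encoder_mask"][:, -1:])  # mask column
        return model(x)                                 # Flamingo with no image batch
    return model(x, input_pos)                          # standard text-transformer path
```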

torchchat/model.py

Lines changed: 26 additions & 9 deletions
@@ -291,6 +291,18 @@ def from_params(cls, params):
 
 @dataclass
 class ModelArgs:
+    """
+    A data class to describe the structure of a model.
+    Attributes:
+        model_type (ModelType): The type of the model. This attribute is used to categorize the model into different classes.
+        transformer_args (Dict[str, Dict[str, Any]]): A dictionary containing the parameters for each transformer in the model.
+            The outer dictionary has transformer names as keys and inner dictionaries as values. Each inner dictionary contains
+            the parameter names and their corresponding values for the respective transformer.
+        use_tiktoken (bool): A flag indicating whether to use TikToken as the tokenizer for the model.
+    Note:
+        It is recommended to use factory functions to create instances of this class instead of directly using the constructor.
+    """
+
     model_type: ModelType
     transformer_args: Dict[str, Dict[str, Any]]
     use_tiktoken: bool
@@ -326,7 +338,7 @@ def from_params(cls, params_path):
             # The model params is in the transformer_args format
             # set the model_type to TextOnly and reformat the params
             model_type = ModelType.TextOnly
-            transformer_args = {"text": {"config": loaded_params}}
+            transformer_args = {"text": loaded_params}
         else:
             model_type = ModelType(model_type_name)
             transformer_args = {
@@ -420,6 +432,7 @@ def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
         self.model = self.build_model()
+        self.text_transformer_args = None
 
     def build_model(self) -> nn.Module:
         """
@@ -433,7 +446,10 @@ def build_model(self) -> nn.Module:
         modules = {}
         for name, module_class in recipe.modules.items():
             config_args = self.config.transformer_args[name]
-            modules[name] = module_class(**config_args)
+            if module_class == Transformer:
+                modules[name] = module_class(TransformerArgs.from_params(config_args))
+            else:
+                modules[name] = module_class(**config_args)
 
         return recipe.fusion_class(**modules)
 
@@ -486,6 +502,10 @@ def from_gguf(cls, gguf_path: str, **kwargs):
 
 
 class TextOnlyModel(Model):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__(config)
+        self.text_transformer_args = self.model.config
+
     def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         return self.model(tokens, input_pos)
 
@@ -548,9 +568,8 @@ def setup_caches(self, max_batch_size, max_seq_length):
 
 
 class Transformer(nn.Module):
-    def __init__(self, config: Dict[str, Any]) -> None:
+    def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
-        config = TransformerArgs.from_params(config)
         self.config = config
         layers_per_stage = config.n_layers // config.n_stages
 
@@ -930,11 +949,9 @@ def __init__(self, config, path) -> None:
             super().__init__()
             self.config = config
             self.model_ = exec_lib._load_for_executorch(str(path))
-
-            # A hacky way to get the model config from the self.model, making it consistent with Model class
-            # TODO: remove the hacky way once get rid of model.model
-            self.model = type('model', (), {'config': self.config})
 
+            self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])
+
         def forward(self, x, input_pos):
             # model_.forward expects inputs to be wrapped in a tuple
             forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -948,7 +965,7 @@ def forward(self, x, input_pos):
 
         def setup_caches(self, max_batch_size, max_seq_length):
             pass
-
+
 except:
     pass
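
Note: with this change, the raw parameter dict from transformer_args is converted into a typed TransformerArgs only when the recipe module is the Transformer class; other modules still receive keyword arguments directly. A runnable sketch of that construction pattern with stand-in classes (the real recipe, fusion class, and Transformer live in torchchat/model.py):

```python
from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class TransformerArgs:                       # stand-in for torchchat's TransformerArgs
    n_layers: int = 2
    dim: int = 16

    @classmethod
    def from_params(cls, params: Dict[str, Any]) -> "TransformerArgs":
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in params.items() if k in known})


class Transformer:                           # stand-in; the real one is an nn.Module
    def __init__(self, config: TransformerArgs) -> None:
        self.config = config


def build_modules(transformer_args: Dict[str, Dict[str, Any]], recipe_modules: Dict[str, type]):
    modules = {}
    for name, module_class in recipe_modules.items():
        config_args = transformer_args[name]
        if module_class is Transformer:
            # Transformer now expects typed args built from the raw param dict.
            modules[name] = module_class(TransformerArgs.from_params(config_args))
        else:
            modules[name] = module_class(**config_args)
    return modules


text_module = build_modules({"text": {"n_layers": 4, "dim": 32}}, {"text": Transformer})["text"]
print(text_module.config)                    # TransformerArgs(n_layers=4, dim=32)
```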

torchchat/usages/browser.py

Lines changed: 106 additions & 17 deletions
@@ -1,39 +1,123 @@
+import base64
+import logging
 import time
+from pathlib import Path
+
 import streamlit as st
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 from openai import OpenAI
 
 st.title("torchchat")
 
 start_state = [
     {
         "role": "system",
-        "content": "You're an assistant. Answer questions directly, be brief, and have fun.",
+        "content": "You're a helpful assistant - have fun.",
     },
     {"role": "assistant", "content": "How can I help you?"},
 ]
 
+st.session_state.uploader_key = 0
+
+
+def reset_per_message_state():
+    # Catch all function for anything that should be reset between each message.
+    _update_uploader_key()
+
+
+def _update_uploader_key():
+    # Increment the uploader key to reset the file uploader after each message.
+    st.session_state.uploader_key = int(time.time())
+
+
 with st.sidebar:
+    # API Configuration
+    api_base_url = st.text_input(
+        label="API Base URL",
+        value="http://127.0.0.1:5000/v1",
+        help="The base URL for the OpenAI API to connect to",
+    )
+
+    st.divider()
+    temperature = st.slider(
+        "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
+    )
+
     response_max_tokens = st.slider(
         "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
     )
     if st.button("Reset Chat", type="primary"):
         st.session_state["messages"] = start_state
 
+    image_prompts = st.file_uploader(
+        "Image Prompts",
+        type=["jpeg"],
+        accept_multiple_files=True,
+        key=st.session_state.uploader_key,
+    )
+
+    for image in image_prompts:
+        st.image(image)
+
+
+client = OpenAI(
+    base_url=api_base_url,
+    api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
+)
+
 if "messages" not in st.session_state:
     st.session_state["messages"] = start_state
 
 
 for msg in st.session_state.messages:
-    st.chat_message(msg["role"]).write(msg["content"])
+    with st.chat_message(msg["role"]):
+        if type(msg["content"]) is list:
+            for content in msg["content"]:
+                if content["type"] == "image_url":
+                    extension = (
+                        content["image_url"].split(";base64")[0].split("image/")[1]
+                    )
+                    base64_repr = content["image_url"].split("base64,")[1]
+                    st.image(base64.b64decode(base64_repr))
+                else:
+                    st.write(content["text"])
+        elif type(msg["content"]) is dict:
+            if msg["content"]["type"] == "image_url":
+                st.image(msg["content"]["image_url"])
+            else:
+                st.write(msg["content"]["text"])
+        elif type(msg["content"]) is str:
+            st.write(msg["content"])
+        else:
+            st.write(f"Unhandled content type: {type(msg['content'])}")
 
 
 if prompt := st.chat_input():
-    client = OpenAI(
-        base_url="http://127.0.0.1:5000/v1",
-        api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
-    )
+    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
+
+    if image_prompts:
+        for image_prompt in image_prompts:
+            extension = Path(image_prompt.name).suffix.strip(".")
+            image_bytes = image_prompt.getvalue()
+            base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
+            user_message["content"].append(
+                {
+                    "type": "image_url",
+                    "image_url": f"data:image/{extension};base64,{base64_encoded}",
+                }
+            )
+    st.session_state.messages.append(user_message)
+
+    with st.chat_message("user"):
+        st.write(prompt)
+        for img in image_prompts:
+            st.image(img)
 
-    st.session_state.messages.append({"role": "user", "content": prompt})
-    st.chat_message("user").write(prompt)
+    image_prompts = None
+    reset_per_message_state()
 
     with st.chat_message("assistant"), st.status(
        "Generating... ", expanded=True
@@ -53,15 +137,20 @@ def get_streamed_completion(completion_generator):
             state="complete",
         )
 
-        response = st.write_stream(
-            get_streamed_completion(
-                client.chat.completions.create(
-                    model="llama3",
-                    messages=st.session_state.messages,
-                    max_tokens=response_max_tokens,
-                    stream=True,
+        try:
+            response = st.write_stream(
+                get_streamed_completion(
+                    client.chat.completions.create(
+                        model="llama3",
+                        messages=st.session_state.messages,
+                        max_tokens=response_max_tokens,
+                        temperature=temperature,
+                        stream=True,
+                    )
                 )
-            )
-        )[0]
+            )[0]
+        except Exception as e:
+            response = st.error(f"Error: {e}")
+            print(e)
 
     st.session_state.messages.append({"role": "assistant", "content": response})
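
Note: the new upload path encodes each image as a base64 data URL and appends it to the user message's content list, and the rendering branch above splits that URL back apart on ";base64". A small, self-contained sketch of that round trip (the bytes and filename are hypothetical; the app gets them from st.file_uploader):

```python
import base64
from pathlib import Path


def image_to_content_entry(image_bytes: bytes, filename: str) -> dict:
    """Build the image_url entry the chat loop appends for each upload (sketch)."""
    extension = Path(filename).suffix.strip(".")            # e.g. "jpeg"
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return {
        "type": "image_url",
        "image_url": f"data:image/{extension};base64,{encoded}",
    }


# Fake bytes stand in for st.file_uploader's getvalue(); any real JPEG works the same way.
user_message = {"role": "user", "content": [{"type": "text", "text": "Describe this image."}]}
user_message["content"].append(image_to_content_entry(b"\xff\xd8\xff\xe0fake-jpeg", "photo.jpeg"))

# The rendering loop recovers the bytes by splitting the data URL apart:
entry = user_message["content"][-1]["image_url"]
extension = entry.split(";base64")[0].split("image/")[1]    # -> "jpeg"
raw_bytes = base64.b64decode(entry.split("base64,")[1])     # -> the original bytes
```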

torchchat/usages/eval.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
     T = prompt.size(0)
     T_new = T + max_new_tokens
     if max_seq_length is None:
-        max_seq_length = min(T_new, model.model.config.block_size)
+        max_seq_length = min(T_new, model.text_transformer_args.block_size)
 
     device, dtype = prompt.device, prompt.dtype
     # create an empty tensor of the expected final shape and
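
Note: the changed line caps the padded sequence at the model's block size, now read from text_transformer_args. A tiny sketch of that sizing rule with hypothetical numbers:

```python
# Hypothetical values; the real ones come from the prompt tensor and
# model.text_transformer_args.block_size.
T = 12                                   # prompt length
max_new_tokens = 50
block_size = 2048                        # model's maximum context length

T_new = T + max_new_tokens               # tokens needed if generation runs to the end
max_seq_length = min(T_new, block_size)  # never allocate past the model's block size
print(max_seq_length)                    # 62
```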

0 commit comments
