
Commit fff8647
unify model construction ppl
1 parent: 7708646

File tree (5 files changed: +43 −57 lines)

torchchat/cli/builder.py
torchchat/generate.py
torchchat/model.py
torchchat/model_params/Meta-Llama-3.1-70B-Tune.json
torchchat/model_params/Meta-Llama-3.1-8B-Tune.json

torchchat/cli/builder.py
Lines changed: 1 addition & 6 deletions

@@ -240,12 +240,7 @@ def validate_model(
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
-        text_args = model.config.transformer_args.get("text")
-        if text_args is None:
-            # TODO: Will be refactored: Currently, the only model that doesn't have text in transfomer_args is Flamingo
-            use_tiktoken = model.config.model_type == ModelType.Flamingo
-        else:
-            use_tiktoken = text_args.use_tiktoken
+        use_tiktoken = model.config.use_tiktoken
 
        if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
            raise RuntimeError(
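For reference, a minimal runnable sketch of what the simplified check now does: the tokenizer kind is validated against the model-level use_tiktoken flag instead of digging into transformer_args["text"]. The _DummyConfig dataclass and check_tokenizer_matches helper below are illustrative stand-ins, not torchchat APIs.

```python
# Sketch of the tokenizer/model consistency check after this change.
# The real logic lives in validate_model in torchchat/cli/builder.py;
# _DummyConfig and check_tokenizer_matches are illustrative stand-ins.
from dataclasses import dataclass


@dataclass
class _DummyConfig:
    use_tiktoken: bool


def check_tokenizer_matches(
    is_tiktoken: bool, is_sentencepiece: bool, config: _DummyConfig
) -> None:
    # The model-level flag replaces the old per-transformer lookup
    # (model.config.transformer_args["text"].use_tiktoken).
    use_tiktoken = config.use_tiktoken
    if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
        raise RuntimeError(
            f"model/tokenizer mismatch: model expects tiktoken={use_tiktoken}, "
            f"tokenizer has tiktoken={is_tiktoken}, sentencepiece={is_sentencepiece}"
        )


# A tiktoken tokenizer paired with a tiktoken model passes silently.
check_tokenizer_matches(True, False, _DummyConfig(use_tiktoken=True))
```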

torchchat/generate.py
Lines changed: 14 additions & 24 deletions

@@ -27,12 +27,6 @@
 
 from PIL import Image
 
-# torchtune model definition dependencies
-from torchtune.data import Message
-from torchtune.generation._generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-from torchtune.training import set_default_dtype
-
 from torchchat.cli.builder import (
     _initialize_model,
     _initialize_tokenizer,
@@ -43,6 +37,12 @@
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
 
+# torchtune model definition dependencies
+from torchtune.data import Message
+from torchtune.generation._generation import sample as tune_sample
+from torchtune.models.llama3 import llama3_tokenizer
+from torchtune.training import set_default_dtype
+
 
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
@@ -790,16 +790,12 @@ def chat(
 
         # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong
         # TODO: unify the max_seq_length config representation.
-        if generator_args.is_torchtune_model:
-            max_seq_length = self.model.config.transformer_args.get("text", {}).get(
-                "max_seq_len", 2048
-            )
-        elif generator_args.chat_mode:
-            if (
-                max_seq_length := self.model.config.transformer_args.get("text", None)
-                is None
-            ):
-                max_seq_length = 2048
+        text_transformer_args = getattr(self.model.model, "config", None)
+        max_seq_length = (
+            text_transformer_args.max_seq_length if text_transformer_args else 2048
+        )
+
+        if generator_args.chat_mode:
             print(
                 f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
             )
@@ -809,15 +805,9 @@ def chat(
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        else:
-            text_transformer_args = self.model.config.transformer_args.get("text", None)
+        elif not generator_args.is_torchtune_model:
             max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
+                encoded.size(0) + generator_args.max_new_tokens, max_seq_length
             )
 
         max_seq_length = (
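A small sketch of the unified max_seq_length lookup introduced above, assuming the wrapped text transformer (self.model.model) may or may not expose a config carrying max_seq_length; resolve_max_seq_length and the SimpleNamespace stand-ins are illustrative, not generate.py code.

```python
# Sketch of the unified max_seq_length fallback: read it from the inner
# text transformer's config when present, otherwise default to 2048.
from types import SimpleNamespace


def resolve_max_seq_length(inner_model) -> int:
    text_transformer_args = getattr(inner_model, "config", None)
    return text_transformer_args.max_seq_length if text_transformer_args else 2048


# Illustrative stand-ins for a text transformer with and without a config.
with_config = SimpleNamespace(config=SimpleNamespace(max_seq_length=8192))
without_config = SimpleNamespace()

assert resolve_max_seq_length(with_config) == 8192
assert resolve_max_seq_length(without_config) == 2048
```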

torchchat/model.py
Lines changed: 26 additions & 27 deletions

@@ -164,49 +164,49 @@ def from_params(cls, params):
 @dataclass
 class ModelArgs:
     model_type: ModelType
-    transformer_args: Dict[str, Union[Dict, TransformerArgs]]
+    transformer_args: Dict[str, Dict[str, Any]]
+    use_tiktoken: bool
 
     def __init__(
         self,
-        transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
+        transformer_args: Dict[str, Dict[str, Any]],
         model_type: ModelType = ModelType.TextOnly,
+        use_tiktoken: bool = False,
     ) -> None:
         self._sanity_check(transformer_args, model_type)
 
         self.model_type = model_type
-        if isinstance(transformer_args, TransformerArgs):
-            assert model_type == ModelType.TextOnly
-            self.transformer_args = {"text": transformer_args}
-        else:
-            self.transformer_args = transformer_args
+        self.transformer_args = transformer_args
+
+        # Model-level attributes
+        self.use_tiktoken = use_tiktoken
 
     def _sanity_check(
         self,
-        transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
+        transformer_args: Dict[str, Dict[str, Any]],
         model_type: ModelType,
     ) -> None:
-        assert isinstance(model_type, ModelType)
-        assert isinstance(transformer_args, (TransformerArgs, dict))
+        assert isinstance(model_type, ModelType), model_type
+        assert isinstance(transformer_args, dict)
 
     @classmethod
     def from_params(cls, params_path):
         with open(params_path, "r") as f:
             loaded_params = json.loads(f.read())
-
-        try:
-            # try to interpret as a single transformer config
-            transformer_args: Dict[str, TransformerArgs] = {}
-            transformer_args["text"] = TransformerArgs.from_params(loaded_params)
-            if (model_type := loaded_params.get("model_type", None)) is None:
-                model_type = ModelType.TextOnly
-
-        except TypeError:
-            # try to interpret as a dict of transformer configs
-            model_type = ModelType(loaded_params["model_type"])
+
+        if (model_type_name := loaded_params.get("model_type", None)) is None:
+            # The model params is in the transformer_args format
+            # set the model_type to TextOnly and reformat the params
+            model_type = ModelType.TextOnly
+            transformer_args = {"text": {"config": loaded_params}}
+        else:
+            model_type = ModelType(model_type_name)
             transformer_args = {
                 k: v for k, v in loaded_params.items() if k != "model_type"
             }
-        return cls(transformer_args, model_type)
+
+        use_tiktoken = loaded_params.get("use_tiktoken", False)
+        return cls(transformer_args, model_type, use_tiktoken)
 
     @classmethod
     def from_table(cls, name: str):
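To make the two accepted params layouts concrete, here is a hedged sketch of the branching in the rewritten from_params; split_params and both example dicts are made up for illustration and are not torchchat code.

```python
# Sketch of how the rewritten ModelArgs.from_params interprets a params dict.
# Both example dicts are invented for illustration.

# Legacy flat layout: no "model_type" key, so the whole dict is treated as a
# single text transformer config and wrapped as {"text": {"config": ...}}.
legacy_params = {"vocab_size": 32000, "num_layers": 32}

# New layout: explicit "model_type", one entry per transformer, plus the
# model-level "use_tiktoken" flag added to the Tune JSON files below.
new_params = {
    "model_type": "llama3_1",
    "use_tiktoken": True,
    "text": {"vocab_size": 128256, "num_layers": 32},
}


def split_params(loaded_params: dict):
    if (model_type_name := loaded_params.get("model_type", None)) is None:
        model_type = "TextOnly"  # stands in for ModelType.TextOnly
        transformer_args = {"text": {"config": loaded_params}}
    else:
        model_type = model_type_name
        transformer_args = {
            k: v for k, v in loaded_params.items() if k != "model_type"
        }
    use_tiktoken = loaded_params.get("use_tiktoken", False)
    return model_type, transformer_args, use_tiktoken


print(split_params(legacy_params))
print(split_params(new_params))
```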
@@ -304,10 +304,8 @@ def build_model(self) -> nn.Module:
         recipe = ModelRecipe.get_recipe(self.config.model_type)
         modules = {}
         for name, module_class in recipe.modules.items():
-            if isinstance(config_args := self.config.transformer_args[name], dict):
-                modules[name] = module_class(**config_args)
-            else:
-                modules[name] = module_class(config_args)
+            config_args = self.config.transformer_args[name]
+            modules[name] = module_class(**config_args)
 
         return recipe.fusion_class(**modules)

@@ -399,8 +397,9 @@ def reset_caches(self):
 
 
 class Transformer(nn.Module):
-    def __init__(self, config: TransformerArgs) -> None:
+    def __init__(self, config: Dict[str, Any]) -> None:
         super().__init__()
+        config = TransformerArgs.from_params(config)
         self.config = config
         layers_per_stage = config.n_layers // config.n_stages
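Taken together, the last two hunks mean build_model always splats the per-module dict (module_class(**config_args)) and Transformer converts the raw dict to TransformerArgs itself. Below is a minimal sketch under that assumption; the Fake* classes are stand-ins for the real torchchat types.

```python
# Sketch of the unified construction path. FakeTransformerArgs and
# FakeTransformer are stand-ins for TransformerArgs and Transformer.
from typing import Any, Dict


class FakeTransformerArgs:
    def __init__(self, params: Dict[str, Any]):
        # stands in for TransformerArgs.from_params(params)
        self.n_layers = params.get("num_layers", 1)


class FakeTransformer:
    # mirrors the new Transformer.__init__(self, config: Dict[str, Any])
    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = FakeTransformerArgs(config)


# transformer_args as from_params would produce for a legacy flat params file.
transformer_args = {"text": {"config": {"num_layers": 32}}}
recipe_modules = {"text": FakeTransformer}

modules = {}
for name, module_class in recipe_modules.items():
    config_args = transformer_args[name]         # e.g. {"config": {...}}
    modules[name] = module_class(**config_args)  # no isinstance branching needed

print(modules["text"].config.n_layers)  # 32
```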

torchchat/model_params/Meta-Llama-3.1-70B-Tune.json
Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 {
     "model_type": "llama3_1",
+    "use_tiktoken": true,
     "text": {
         "vocab_size": 128256,
         "num_layers": 80,

torchchat/model_params/Meta-Llama-3.1-8B-Tune.json
Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 {
     "model_type": "llama3_1",
+    "use_tiktoken": true,
     "text": {
         "vocab_size": 128256,
         "num_layers": 32,
