Merged
Changes from 14 commits
2 changes: 1 addition & 1 deletion distributed/parallelize_llama.py
@@ -62,7 +62,7 @@ def apply_tp(
# after we apply TP to the model, because we don't want to change model code
# when applying TP. We need this change to ensure KVCache has the correct
# size for k and v.
model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size()
model.text_transformer_args.n_local_heads = model.text_transformer_args.n_local_heads // tp_mesh.size()

# Apply tensor parallelism to every transformer block
for transformer_block in model.layers:
9 changes: 2 additions & 7 deletions torchchat/cli/builder.py
@@ -240,12 +240,7 @@ def validate_model(

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
text_args = model.config.transformer_args.get("text")
if text_args is None:
# TODO: Will be refactored: Currently, the only model that doesn't have text in transformer_args is Flamingo
use_tiktoken = model.config.model_type == ModelType.Flamingo
else:
use_tiktoken = text_args.use_tiktoken
use_tiktoken = model.config.use_tiktoken

if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
raise RuntimeError(
@@ -568,7 +563,7 @@ def _initialize_model(
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
or model.config.transformer_args["text"].max_seq_length,
or model.text_transformer_args.max_seq_length,
)

model.to(dtype=builder_args.precision)
2 changes: 1 addition & 1 deletion torchchat/export.py
@@ -54,7 +54,7 @@ def export_for_server(
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
)

seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length)
seq = Dim("seq", min=1, max=model.text_transformer_args.max_seq_length)
# Specify that the first dimension of each input is the batch size
dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}
else:
38 changes: 14 additions & 24 deletions torchchat/generate.py
@@ -27,12 +27,6 @@

from PIL import Image

# torchtune model definition dependencies
from torchtune.data import Message
from torchtune.generation._generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.training import set_default_dtype

from torchchat.cli.builder import (
_initialize_model,
_initialize_tokenizer,
@@ -43,6 +37,12 @@
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info

# torchtune model definition dependencies
from torchtune.data import Message
from torchtune.generation._generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.training import set_default_dtype


class _ChatFormatter(ABC):
def __init__(self, tokenizer):
@@ -795,16 +795,12 @@ def chat(

# This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong
# TODO: unify the max_seq_length config representation.
if generator_args.is_torchtune_model:
max_seq_length = self.model.config.transformer_args.get("text", {}).get(
"max_seq_len", 2048
)
elif generator_args.chat_mode:
if (
max_seq_length := self.model.config.transformer_args.get("text", None)
Contributor:
Your changes are right; just calling out that the old implementation was broken in 26c1d8b

is None
):
max_seq_length = 2048
text_transformer_args = self.model.text_transformer_args
max_seq_length = (
text_transformer_args.max_seq_length if text_transformer_args else 2048
)

if generator_args.chat_mode:
print(
f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
)
@@ -814,15 +810,9 @@
if get_system_prompt == "y" or get_system_prompt == "Y":
self.system_prompt = input("What is your system prompt? \n")

else:
text_transformer_args = self.model.config.transformer_args.get("text", None)
elif not generator_args.is_torchtune_model:
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
(
text_transformer_args.block_size
if text_transformer_args is not None
else 2048
),
encoded.size(0) + generator_args.max_new_tokens, max_seq_length
@Jack-Khuu (Contributor), Sep 17, 2024:
Note that this is a departure from the original code, where the second argument to min is block_size (which represents a different max_seq_length; confusing, I know).
While we want to move away from using block_size, let's not do it in this diff.

Contributor Author:
Good catch! Not sure why this happened; probably a typo. Will fix it.

)

max_seq_length = (
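For reference on Jack-Khuu's note above: a minimal sketch of the cap the reviewer wants preserved on the non-torchtune path, assuming text_transformer_args still exposes block_size as in the removed branch. capped_seq_length is a hypothetical helper, not code from this PR.

# Hypothetical helper: cap the generated length by block_size rather than
# max_seq_length, falling back to 2048 when no text transformer args exist.
def capped_seq_length(prompt_len: int, max_new_tokens: int, text_transformer_args) -> int:
    block_size = (
        text_transformer_args.block_size
        if text_transformer_args is not None
        else 2048
    )
    return min(prompt_len + max_new_tokens, block_size)

# Example: a 100-token prompt requesting 4000 new tokens is capped at 2048.
assert capped_seq_length(100, 4000, None) == 2048
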
73 changes: 47 additions & 26 deletions torchchat/model.py
@@ -163,50 +163,62 @@ def from_params(cls, params):

@dataclass
class ModelArgs:
"""
A data class to describe the structure of a model.
Attributes:
model_type (ModelType): The type of the model. This attribute is used to categorize the model into different classes.
transformer_args (Dict[str, Dict[str, Any]]): A dictionary containing the parameters for each transformer in the model.
The outer dictionary has transformer names as keys and inner dictionaries as values. Each inner dictionary contains
the parameter names and their corresponding values for the respective transformer.
Contributor:
"Each inner dictionary contains the parameter names and their corresponding values for the respective transformer."
This sounds like the intent of TransformerArgs; why can't we use that instead of Dict[str, Any]?

Contributor Author:
For unification: this arg is responsible for describing the architecture of all models, including tune backends, chat backends, and even mixed backends, so we need a unified way to describe how to set them up.
For chat-backend modules, the inner Dict will be converted into TransformerArgs afterwards.

use_tiktoken (bool): A flag indicating whether to use TikToken as the tokenizer for the model.
Note:
It is recommended to use factory functions to create instances of this class instead of directly using the constructor.
"""

model_type: ModelType
transformer_args: Dict[str, Union[Dict, TransformerArgs]]
transformer_args: Dict[str, Dict[str, Any]]
use_tiktoken: bool

def __init__(
self,
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
transformer_args: Dict[str, Dict[str, Any]],
Contributor:
We should find a way to reconcile Dict[str, Any] into a TransformerArgs in a future PR.
This makes this work well since we have 3 "cases", but storing and passing around an untyped Dict makes me nervous.

Contributor Author:
More than agree. My mental model would be an abstract class containing the essential APIs for all module configurations, with a different implementation for each transformer (e.g. ours, torchtune's, etc.). Dict[str, Any] is not a great way.
Let me add some comments in our codebase to highlight that.

model_type: ModelType = ModelType.TextOnly,
use_tiktoken: bool = False,
) -> None:
self._sanity_check(transformer_args, model_type)

self.model_type = model_type
if isinstance(transformer_args, TransformerArgs):
assert model_type == ModelType.TextOnly
self.transformer_args = {"text": transformer_args}
else:
self.transformer_args = transformer_args
self.transformer_args = transformer_args

# Model-level attributes
self.use_tiktoken = use_tiktoken

def _sanity_check(
self,
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
transformer_args: Dict[str, Dict[str, Any]],
model_type: ModelType,
) -> None:
assert isinstance(model_type, ModelType)
assert isinstance(transformer_args, (TransformerArgs, dict))
assert isinstance(model_type, ModelType), model_type
assert isinstance(transformer_args, dict)

@classmethod
def from_params(cls, params_path):
with open(params_path, "r") as f:
loaded_params = json.loads(f.read())

try:
# try to interpret as a single transformer config
transformer_args: Dict[str, TransformerArgs] = {}
transformer_args["text"] = TransformerArgs.from_params(loaded_params)
if (model_type := loaded_params.get("model_type", None)) is None:
model_type = ModelType.TextOnly

except TypeError:
# try to interpret as a dict of transformer configs
model_type = ModelType(loaded_params["model_type"])

if (model_type_name := loaded_params.get("model_type", None)) is None:
# The model params is in the transformer_args format
# set the model_type to TextOnly and reformat the params
model_type = ModelType.TextOnly
transformer_args = {"text": loaded_params}
else:
model_type = ModelType(model_type_name)
transformer_args = {
k: v for k, v in loaded_params.items() if k != "model_type"
}
return cls(transformer_args, model_type)

use_tiktoken = loaded_params.get("use_tiktoken", False)
return cls(transformer_args, model_type, use_tiktoken)

@classmethod
def from_table(cls, name: str):
@@ -292,6 +304,7 @@ def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.model = self.build_model()
self.text_transformer_args = None
Contributor:
Comment on this since it is a special case


def build_model(self) -> nn.Module:
"""
@@ -304,10 +317,11 @@ def build_model(self) -> nn.Module:
recipe = ModelRecipe.get_recipe(self.config.model_type)
modules = {}
for name, module_class in recipe.modules.items():
if isinstance(config_args := self.config.transformer_args[name], dict):
modules[name] = module_class(**config_args)
config_args = self.config.transformer_args[name]
if module_class == Transformer:
modules[name] = module_class(TransformerArgs.from_params(config_args))
else:
modules[name] = module_class(config_args)
modules[name] = module_class(**config_args)

return recipe.fusion_class(**modules)

@@ -353,6 +367,10 @@ def from_gguf(cls, gguf_path: str, **kwargs):


class TextOnlyModel(Model):
def __init__(self, config: ModelArgs) -> None:
super().__init__(config)
self.text_transformer_args = self.model.config

def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return self.model(tokens, input_pos)

@@ -391,6 +409,7 @@ def reset_caches(self):
self.model.reset_caches()



MODEL_TYPE_TO_CLASS = {
ModelType.TextOnly: TextOnlyModel,
ModelType.Flamingo: FlamingoModel,
@@ -781,6 +800,8 @@ def __init__(self, config, path) -> None:
self.config = config
self.model_ = exec_lib._load_for_executorch(str(path))

self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])

def forward(self, x, input_pos):
# model_.forward expects inputs to be wrapped in a tuple
forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
@@ -794,6 +815,6 @@ def forward(self, x, input_pos):

def setup_caches(self, max_batch_size, max_seq_length):
pass

except:
pass
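Picking up the review thread on ModelArgs.__init__ above: the author's idea of an abstract class for module configurations could look roughly like the sketch below. All names here (ModuleConfig, NativeTextConfig) are invented for illustration and are not part of this PR.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict


class ModuleConfig(ABC):
    """Hypothetical shared interface for per-module configs (native, torchtune, ...)."""

    @classmethod
    @abstractmethod
    def from_params(cls, params: Dict[str, Any]) -> "ModuleConfig":
        """Build a typed config from a raw params dict."""

    @property
    @abstractmethod
    def max_seq_length(self) -> int:
        """Expose the fields callers need instead of untyped dict lookups."""


@dataclass
class NativeTextConfig(ModuleConfig):
    dim: int
    n_layers: int
    seq_len: int = 2048

    @classmethod
    def from_params(cls, params: Dict[str, Any]) -> "NativeTextConfig":
        return cls(
            dim=params["dim"],
            n_layers=params["n_layers"],
            seq_len=params.get("max_seq_length", 2048),
        )

    @property
    def max_seq_length(self) -> int:
        return self.seq_len


# Usage: build_model-style code could then construct typed configs uniformly.
cfg = NativeTextConfig.from_params({"dim": 4096, "n_layers": 32})
assert cfg.max_seq_length == 2048
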
1 change: 1 addition & 0 deletions torchchat/model_params/Meta-Llama-3.1-70B-Tune.json
@@ -1,5 +1,6 @@
{
"model_type": "llama3_1",
"use_tiktoken": true,
"text": {
"vocab_size": 128256,
"num_layers": 80,
1 change: 1 addition & 0 deletions torchchat/model_params/Meta-Llama-3.1-8B-Tune.json
@@ -1,5 +1,6 @@
{
"model_type": "llama3_1",
"use_tiktoken": true,
"text": {
"vocab_size": 128256,
"num_layers": 32,
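The two *-Tune.json files above use the keyed layout that the reworked ModelArgs.from_params accepts alongside the legacy flat layout. A hedged usage sketch, assuming ModelArgs and ModelType are importable from torchchat.model; load_args is a throwaway helper and the parameter values are toy numbers, not a real checkpoint.

import json
import tempfile

from torchchat.model import ModelArgs, ModelType


def load_args(params: dict) -> ModelArgs:
    # from_params expects a path, so round-trip the dict through a temp file.
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(params, f)
    return ModelArgs.from_params(f.name)


# Legacy flat layout: no "model_type" key, so the params are wrapped as
# {"text": ...}, model_type defaults to TextOnly, and use_tiktoken to False.
flat = load_args({"dim": 4096, "n_layers": 32, "n_heads": 32})
assert flat.model_type == ModelType.TextOnly and not flat.use_tiktoken

# Keyed layout (as in the JSON files above): "model_type" and model-level
# flags such as "use_tiktoken" sit next to per-module dicts like "text".
keyed = load_args(
    {"model_type": "llama3_1", "use_tiktoken": True, "text": {"num_layers": 32}}
)
assert keyed.use_tiktoken and "text" in keyed.transformer_args
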
2 changes: 1 addition & 1 deletion torchchat/usages/eval.py
@@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
T = prompt.size(0)
T_new = T + max_new_tokens
if max_seq_length is None:
max_seq_length = min(T_new, model.config.transformer_args["text"].block_size)
max_seq_length = min(T_new, model.text_transformer_args.block_size)

device, dtype = prompt.device, prompt.dtype
# create an empty tensor of the expected final shape and
12 changes: 7 additions & 5 deletions torchchat/usages/openai_api.py
@@ -282,15 +282,17 @@ def __init__(self, *args, **kwargs):
"""

super().__init__(*args, **kwargs)
self.max_seq_length = 128
if self.model.config.transformer_args.get("text", None):
self.max_seq_len = (
self.model.config.transformer_args["text"].max_seq_length
try:
self.max_seq_length = (
self.model.text_transformer_args.max_seq_length
+ self.speculative_builder_args.speculate_k
+ 1
if self.draft_model is not None
else self.model.config.transformer_args["text"].max_seq_length
else self.model.text_transformer_args.max_seq_length
)
except:
# can not find max_seq_length in model config, use default value
self.max_seq_length = 128
# The System fingerprint is a unique identifier for the model and its configuration.
self.system_fingerprint = (
f"{self.builder_args.device}_{self.builder_args.precision}"
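The arithmetic in the new try block above, pulled out as a tiny self-contained sketch (server_max_seq_length is a hypothetical name): with a draft model, the server budgets room for the speculate_k drafted tokens plus one bonus token per step.

# Hypothetical helper mirroring the max_seq_length computation in __init__ above.
def server_max_seq_length(model_max: int, speculate_k: int, has_draft_model: bool) -> int:
    return model_max + speculate_k + 1 if has_draft_model else model_max

assert server_max_seq_length(8192, 5, True) == 8198
assert server_max_seq_length(8192, 5, False) == 8192
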
20 changes: 11 additions & 9 deletions torchchat/utils/gguf_loader.py
@@ -542,15 +542,17 @@ def load_model(gguf_file: str) -> torch.nn.Module:
assert arch == "llama", "Only LLaMa models are supported by this converter."

model_args = ModelArgs(
TransformerArgs(
dim=metadata[f"{arch}.embedding_length"],
n_layers=metadata[f"{arch}.block_count"],
n_heads=metadata[f"{arch}.attention.head_count"],
n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
vocab_size=len(metadata["tokenizer.ggml.tokens"]),
norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
hidden_dim=metadata[f"{arch}.feed_forward_length"],
)
{
"text": {
"dim": metadata[f"{arch}.embedding_length"],
"n_layers": metadata[f"{arch}.block_count"],
"n_heads": metadata[f"{arch}.attention.head_count"],
"n_local_heads": metadata[f"{arch}.attention.head_count_kv"],
"vocab_size": len(metadata["tokenizer.ggml.tokens"]),
"norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
"hidden_dim": metadata[f"{arch}.feed_forward_length"],
}
}
)

# TODO: what to do with rope args like