Merged
Changes from 8 commits
2 changes: 1 addition & 1 deletion distributed/parallelize_llama.py
@@ -62,7 +62,7 @@ def apply_tp(
# after we apply TP to the model. Because we don't want to change model code
# when applying TP. We need to have change to ensure KVCache has the correct
# size as k and v.
model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size()
model.model.config.n_local_heads = model.model.config.n_local_heads // tp_mesh.size()
@Jack-Khuu (Contributor) commented on Sep 17, 2024

model.model is really hard to reason about... what type is it?

The former was clunky, but legible. I'm not sure about this one.

@Jack-Khuu (Contributor) commented on Sep 17, 2024

I'm not happy with "text" either; it was not sustainable, especially if the number of modules increases.
It needs fixing; model.model might not be perfect yet, but it's close.

The PR author (Contributor) replied:

That's annoying; I 100% agree.
I will remove model.model as soon as I can.
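To make the trade-off concrete, here is a minimal sketch (not code from this PR) of one way the wrapper could expose a typed accessor so call sites never have to spell out model.model.config; the TransformerArgs fields and the text_transformer_args property are assumptions for illustration only.

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class TransformerArgs:
    # Only the fields touched in this diff; the real class carries many more.
    n_local_heads: int = 8
    max_seq_length: int = 2048


class Transformer(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        self.config = config


class Model(nn.Module):
    def __init__(self, text_config: TransformerArgs) -> None:
        super().__init__()
        self.model = Transformer(text_config)  # inner text transformer

    @property
    def text_transformer_args(self) -> TransformerArgs:
        # Typed accessor: callers read model.text_transformer_args.n_local_heads
        # instead of reaching through model.model.config.
        return self.model.config


m = Model(TransformerArgs())
print(m.text_transformer_args.n_local_heads)  # 8
```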


# Apply tensor parallelism to every transformer block
for transformer_block in model.layers:
9 changes: 2 additions & 7 deletions torchchat/cli/builder.py
@@ -240,12 +240,7 @@ def validate_model(

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
text_args = model.config.transformer_args.get("text")
if text_args is None:
# TODO: Will be refactored: Currently, the only model that doesn't have text in transfomer_args is Flamingo
use_tiktoken = model.config.model_type == ModelType.Flamingo
else:
use_tiktoken = text_args.use_tiktoken
use_tiktoken = model.config.use_tiktoken

if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
raise RuntimeError(
@@ -568,7 +563,7 @@ def _initialize_model(
model.setup_caches(
max_batch_size=1,
max_seq_length=max_seq_length
or model.config.transformer_args["text"].max_seq_length,
or model.model.config.max_seq_length,
)

model.to(dtype=builder_args.precision)
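For readers untangling the double negation in the validate_model check above: with use_tiktoken now read from model.config, the error fires whenever the tokenizer found on disk disagrees with the tokenizer the model expects. Below is a small stand-alone restatement of that condition; the helper function is illustrative, not part of the codebase.

```python
def tokenizer_matches_model(
    is_tiktoken: bool, is_sentencepiece: bool, use_tiktoken: bool
) -> bool:
    # Equivalent to the diff's check: exactly one tokenizer type is present,
    # and it is the one the model was built for.
    return is_tiktoken == use_tiktoken and is_sentencepiece != use_tiktoken


assert tokenizer_matches_model(is_tiktoken=True, is_sentencepiece=False, use_tiktoken=True)
assert not tokenizer_matches_model(is_tiktoken=False, is_sentencepiece=True, use_tiktoken=True)
```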
2 changes: 1 addition & 1 deletion torchchat/export.py
@@ -54,7 +54,7 @@ def export_for_server(
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
)

seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length)
seq = Dim("seq", min=1, max=model.model.config.max_seq_length)
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}
else:
38 changes: 14 additions & 24 deletions torchchat/generate.py
@@ -27,12 +27,6 @@

from PIL import Image

# torchtune model definition dependencies
from torchtune.data import Message
from torchtune.generation._generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.training import set_default_dtype

from torchchat.cli.builder import (
_initialize_model,
_initialize_tokenizer,
@@ -43,6 +37,12 @@
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info

# torchtune model definition dependencies
from torchtune.data import Message
from torchtune.generation._generation import sample as tune_sample
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.training import set_default_dtype


class _ChatFormatter(ABC):
def __init__(self, tokenizer):
@@ -790,16 +790,12 @@ def chat(

# This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong
# TODO: unify the max_seq_length config representation.
if generator_args.is_torchtune_model:
max_seq_length = self.model.config.transformer_args.get("text", {}).get(
"max_seq_len", 2048
)
elif generator_args.chat_mode:
if (
max_seq_length := self.model.config.transformer_args.get("text", None)
A reviewer (Contributor) commented:

Your changes are right; just calling out that the old implementation was broken in 26c1d8b

is None
):
max_seq_length = 2048
text_transformer_args = getattr(self.model.model, "config", None)
max_seq_length = (
text_transformer_args.max_seq_length if text_transformer_args else 2048
)

if generator_args.chat_mode:
print(
f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
)
@@ -809,15 +805,9 @@
if get_system_prompt == "y" or get_system_prompt == "Y":
self.system_prompt = input("What is your system prompt? \n")

else:
text_transformer_args = self.model.config.transformer_args.get("text", None)
elif not generator_args.is_torchtune_model:
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
(
text_transformer_args.block_size
if text_transformer_args is not None
else 2048
),
encoded.size(0) + generator_args.max_new_tokens, max_seq_length
@Jack-Khuu (Contributor) commented on Sep 17, 2024

Note that this is a departure from the original code, where the second argument to min is block_size (which represents a different max_seq_length, confusing I know).

While we want to move away from using block_size, let's not do it in this diff.

The PR author (Contributor) replied:

Good catch! Not sure why this happened, probably a typo. Will fix it.

)

max_seq_length = (
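As context for the block_size thread above, here is a minimal sketch of the behavior the reviewer is asking to preserve, under the assumption that block_size comes from the text transformer config when one exists; the helper name is hypothetical.

```python
from typing import Optional

import torch


def cap_max_seq_length(
    encoded: torch.Tensor, max_new_tokens: int, block_size: Optional[int]
) -> int:
    # Non-chat, non-torchtune branch: cap prompt length plus requested new tokens
    # by the transformer's block_size, falling back to 2048 when no text
    # transformer config is available.
    return min(encoded.size(0) + max_new_tokens, block_size if block_size is not None else 2048)


prompt = torch.zeros(100, dtype=torch.int)
print(cap_max_seq_length(prompt, max_new_tokens=200, block_size=256))  # 256
```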
61 changes: 34 additions & 27 deletions torchchat/model.py
@@ -164,49 +164,49 @@ def from_params(cls, params)
@dataclass
class ModelArgs:
model_type: ModelType
transformer_args: Dict[str, Union[Dict, TransformerArgs]]
transformer_args: Dict[str, Dict[str, Any]]
use_tiktoken: bool

def __init__(
self,
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
transformer_args: Dict[str, Dict[str, Any]],
A reviewer (Contributor) commented:

We should find a way to reconcile Dict[str, Any] into a TransformerArgs in a future PR.

This works well since we have 3 "cases", but storing/passing around an untyped Dict makes me nervous.

The PR author (Contributor) replied:

More than agree. My mental model would be creating an abstract class containing the essential APIs for all module configurations, with a different implementation for each transformer (e.g. ours, torchtune's, etc.). Dict[str, Any] is not a great approach.
Let me add some comments in our codebase to highlight that.
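A rough sketch of what that mental model could look like; ModuleConfig and TextTransformerConfig are hypothetical names invented here for illustration, not classes in torchchat or torchtune.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict


class ModuleConfig(ABC):
    """Shared interface each per-module configuration would implement."""

    @classmethod
    @abstractmethod
    def from_params(cls, params: Dict[str, Any]) -> "ModuleConfig":
        ...


@dataclass
class TextTransformerConfig(ModuleConfig):
    dim: int = 4096
    n_layers: int = 32
    max_seq_length: int = 2048

    @classmethod
    def from_params(cls, params: Dict[str, Any]) -> "TextTransformerConfig":
        # Drop unknown keys so extra fields in a params file do not break loading.
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in params.items() if k in known})


# ModelArgs.transformer_args could then map module name -> typed config
# instead of an untyped Dict[str, Any].
transformer_args: Dict[str, ModuleConfig] = {
    "text": TextTransformerConfig.from_params({"dim": 8192, "n_layers": 80, "rope_theta": 500000}),
}
print(transformer_args["text"])
```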

model_type: ModelType = ModelType.TextOnly,
use_tiktoken: bool = False,
) -> None:
self._sanity_check(transformer_args, model_type)

self.model_type = model_type
if isinstance(transformer_args, TransformerArgs):
assert model_type == ModelType.TextOnly
self.transformer_args = {"text": transformer_args}
else:
self.transformer_args = transformer_args
self.transformer_args = transformer_args

# Model-level attributes
self.use_tiktoken = use_tiktoken

def _sanity_check(
self,
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
transformer_args: Dict[str, Dict[str, Any]],
model_type: ModelType,
) -> None:
assert isinstance(model_type, ModelType)
assert isinstance(transformer_args, (TransformerArgs, dict))
assert isinstance(model_type, ModelType), model_type
assert isinstance(transformer_args, dict)

@classmethod
def from_params(cls, params_path):
with open(params_path, "r") as f:
loaded_params = json.loads(f.read())

try:
# try to interpret as a single transformer config
transformer_args: Dict[str, TransformerArgs] = {}
transformer_args["text"] = TransformerArgs.from_params(loaded_params)
if (model_type := loaded_params.get("model_type", None)) is None:
model_type = ModelType.TextOnly

except TypeError:
# try to interpret as a dict of transformer configs
model_type = ModelType(loaded_params["model_type"])

if (model_type_name := loaded_params.get("model_type", None)) is None:
# The model params is in the transformer_args format
# set the model_type to TextOnly and reformat the params
model_type = ModelType.TextOnly
transformer_args = {"text": {"config": loaded_params}}
else:
model_type = ModelType(model_type_name)
transformer_args = {
k: v for k, v in loaded_params.items() if k != "model_type"
}
return cls(transformer_args, model_type)

use_tiktoken = loaded_params.get("use_tiktoken", False)
return cls(transformer_args, model_type, use_tiktoken)

@classmethod
def from_table(cls, name: str):
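To make the two from_params branches above concrete, here is a hypothetical pair of params payloads and the values the method would derive from each; the flat-format field values are illustrative, while the typed example mirrors the Tune JSON files later in this diff.

```python
# Flat legacy format: no "model_type" key, so from_params treats the whole file
# as a single text transformer config and wraps it as {"text": {"config": ...}}.
flat_params = {"dim": 4096, "n_layers": 32, "use_tiktoken": True}
# -> model_type = ModelType.TextOnly
# -> transformer_args = {"text": {"config": flat_params}}
# -> use_tiktoken = True

# Typed format: "model_type" names a recipe; every other top-level key is kept
# in transformer_args keyed by module name.
typed_params = {
    "model_type": "llama3_1",
    "use_tiktoken": True,
    "text": {"vocab_size": 128256, "num_layers": 80},
}
# -> model_type = ModelType("llama3_1")
# -> transformer_args = {"use_tiktoken": True, "text": {"vocab_size": 128256, "num_layers": 80}}
# -> use_tiktoken = True
```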
@@ -304,10 +304,8 @@ def build_model(self) -> nn.Module:
recipe = ModelRecipe.get_recipe(self.config.model_type)
modules = {}
for name, module_class in recipe.modules.items():
if isinstance(config_args := self.config.transformer_args[name], dict):
modules[name] = module_class(**config_args)
else:
modules[name] = module_class(config_args)
config_args = self.config.transformer_args[name]
modules[name] = module_class(**config_args)

return recipe.fusion_class(**modules)

@@ -399,8 +397,9 @@ def reset_caches(self):


class Transformer(nn.Module):
def __init__(self, config: TransformerArgs) -> None:
def __init__(self, config: Dict[str, Any]) -> None:
A reviewer (Contributor) commented:

Not a fan of this one; Transformer taking TransformerArgs is the most intuitive setup and matches the other classes.

The PR author (Contributor) replied:

SG. Bring it back
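One possible way to bring TransformerArgs back while still tolerating the dict form during the transition, sketched here as an assumption rather than what the follow-up change actually does.

```python
from dataclasses import dataclass
from typing import Any, Dict, Union

import torch.nn as nn


@dataclass
class TransformerArgs:
    n_layers: int = 32
    n_stages: int = 1

    @classmethod
    def from_params(cls, params: Dict[str, Any]) -> "TransformerArgs":
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in params.items() if k in known})


class Transformer(nn.Module):
    def __init__(self, config: Union[TransformerArgs, Dict[str, Any]]) -> None:
        super().__init__()
        # Accept the legacy dict form but normalize to the typed config right away,
        # so the rest of the class (and external callers) only see TransformerArgs.
        if isinstance(config, dict):
            config = TransformerArgs.from_params(config)
        self.config = config
        self.layers_per_stage = config.n_layers // config.n_stages


t = Transformer({"n_layers": 16, "n_stages": 2})
print(t.config)  # TransformerArgs(n_layers=16, n_stages=2)
```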

super().__init__()
config = TransformerArgs.from_params(config)
self.config = config
layers_per_stage = config.n_layers // config.n_stages

@@ -780,6 +779,14 @@ def __init__(self, config, path) -> None:
super().__init__()
self.config = config
self.model_ = exec_lib._load_for_executorch(str(path))

# A hacky way to get the model config from the self.model, making it consistent with Model class
# TODO: remove the hacky way once get rid of model.model
try:
text_transformer_config = TransformerArgs.from_params(self.config.transformer_args["text"])
except:
text_transformer_config = None
self.model = type('model', (), {'config': text_transformer_config})

def forward(self, x, input_pos):
# model_.forward expects inputs to be wrapped in a tuple
1 change: 1 addition & 0 deletions torchchat/model_params/Meta-Llama-3.1-70B-Tune.json
@@ -1,5 +1,6 @@
{
"model_type": "llama3_1",
"use_tiktoken": true,
"text": {
"vocab_size": 128256,
"num_layers": 80,
1 change: 1 addition & 0 deletions torchchat/model_params/Meta-Llama-3.1-8B-Tune.json
@@ -1,5 +1,6 @@
{
"model_type": "llama3_1",
"use_tiktoken": true,
"text": {
"vocab_size": 128256,
"num_layers": 32,
2 changes: 1 addition & 1 deletion torchchat/usages/eval.py
@@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
T = prompt.size(0)
T_new = T + max_new_tokens
if max_seq_length is None:
max_seq_length = min(T_new, model.config.transformer_args["text"].block_size)
max_seq_length = min(T_new, model.model.config.block_size)

device, dtype = prompt.device, prompt.dtype
# create an empty tensor of the expected final shape and
4 changes: 2 additions & 2 deletions torchchat/usages/openai_api.py
@@ -233,11 +233,11 @@ def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)
self.max_seq_length = (
self.model.config.transformer_args["text"].max_seq_length
self.model.model.config.max_seq_length
+ self.speculative_builder_args.speculate_k
+ 1
if self.draft_model is not None
else self.model.config.transformer_args["text"].max_seq_length
else self.model.model.config.max_seq_length
)
# The System fingerprint is a unique identifier for the model and its configuration.
self.system_fingerprint = (
22 changes: 13 additions & 9 deletions torchchat/utils/gguf_loader.py
@@ -542,15 +542,19 @@ def load_model(gguf_file: str) -> torch.nn.Module:
assert arch == "llama", "Only LLaMa models are supported by this converter."

model_args = ModelArgs(
TransformerArgs(
dim=metadata[f"{arch}.embedding_length"],
n_layers=metadata[f"{arch}.block_count"],
n_heads=metadata[f"{arch}.attention.head_count"],
n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
vocab_size=len(metadata["tokenizer.ggml.tokens"]),
norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
hidden_dim=metadata[f"{arch}.feed_forward_length"],
)
{
"text": {
"config": {
"dim": metadata[f"{arch}.embedding_length"],
"n_layers": metadata[f"{arch}.block_count"],
"n_heads": metadata[f"{arch}.attention.head_count"],
"n_local_heads": metadata[f"{arch}.attention.head_count_kv"],
"vocab_size": len(metadata["tokenizer.ggml.tokens"]),
"norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
"hidden_dim": metadata[f"{arch}.feed_forward_length"],
}
}
}
)

# TODO: what to do with rope args like
Expand Down