diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index 5f358865d..0b1dca4cf 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -62,7 +62,7 @@ def apply_tp( # after we apply TP to the model. Because we don't want to change model code # when applying TP. We need to have change to ensure KVCache has the correct # size as k and v. - model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size() + model.text_transformer_args.n_local_heads = model.text_transformer_args.n_local_heads // tp_mesh.size() # Apply tensor parallelism to every transformer block for transformer_block in model.layers: diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index f600fe7f2..2933edade 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -240,12 +240,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - text_args = model.config.transformer_args.get("text") - if text_args is None: - # TODO: Will be refactored: Currently, the only model that doesn't have text in transfomer_args is Flamingo - use_tiktoken = model.config.model_type == ModelType.Flamingo - else: - use_tiktoken = text_args.use_tiktoken + use_tiktoken = model.config.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( @@ -568,7 +563,7 @@ def _initialize_model( model.setup_caches( max_batch_size=1, max_seq_length=max_seq_length - or model.config.transformer_args["text"].max_seq_length, + or model.text_transformer_args.max_seq_length, ) model.to(dtype=builder_args.precision) diff --git a/torchchat/export.py b/torchchat/export.py index efc791dc8..affb8b871 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -54,7 +54,7 @@ def export_for_server( torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) - seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length) + seq = Dim("seq", min=1, max=model.text_transformer_args.max_seq_length) # Specify that the first dimension of each input is that batch size dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}} else: diff --git a/torchchat/generate.py b/torchchat/generate.py index 30490d396..4c8b02e91 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -27,12 +27,6 @@ from PIL import Image -# torchtune model definition dependencies -from torchtune.data import Message -from torchtune.generation._generation import sample as tune_sample -from torchtune.models.llama3 import llama3_tokenizer -from torchtune.training import set_default_dtype - from torchchat.cli.builder import ( _initialize_model, _initialize_tokenizer, @@ -43,6 +37,12 @@ from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info +# torchtune model definition dependencies +from torchtune.data import Message +from torchtune.generation._generation import sample as tune_sample +from torchtune.models.llama3 import llama3_tokenizer +from torchtune.training import set_default_dtype + class _ChatFormatter(ABC): def __init__(self, tokenizer): @@ -795,16 +795,12 @@ def chat( # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong # TODO: unify the max_seq_length config representation. 
- if generator_args.is_torchtune_model:
- max_seq_length = self.model.config.transformer_args.get("text", {}).get(
- "max_seq_len", 2048
- )
- elif generator_args.chat_mode:
- if (
- max_seq_length := self.model.config.transformer_args.get("text", None)
- is None
- ):
- max_seq_length = 2048
+ text_transformer_args = self.model.text_transformer_args
+ max_seq_length = (
+ text_transformer_args.max_seq_length if text_transformer_args else 2048
+ )
+
+ if generator_args.chat_mode:
 print(
 f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
 )
@@ -814,8 +810,7 @@
 if get_system_prompt == "y" or get_system_prompt == "Y":
 self.system_prompt = input("What is your system prompt? \n")

- else:
- text_transformer_args = self.model.config.transformer_args.get("text", None)
+ elif not generator_args.is_torchtune_model:
 max_seq_length = min(
 encoded.size(0) + generator_args.max_new_tokens,
 (
diff --git a/torchchat/model.py b/torchchat/model.py
index 10c036a36..f936970dd 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -163,50 +163,64 @@ def from_params(cls, params):
 @dataclass
 class ModelArgs:
+ """
+ A data class to describe the structure of a model.
+ Attributes:
+ model_type (ModelType): The type of the model. This attribute is used to categorize the model into different classes.
+ transformer_args (Dict[str, Dict[str, Any]]): A dictionary containing the parameters for each transformer in the model.
+ The outer dictionary has transformer names as keys and inner dictionaries as values. Each inner dictionary contains
+ the parameter names and their corresponding values for the respective transformer.
+ TODO: reconcile Dict[str, Any] into transformer-arg-family classes in future PRs.
+
+ use_tiktoken (bool): A flag indicating whether to use tiktoken as the tokenizer for the model.
+ Note:
+ It is recommended to use factory functions to create instances of this class instead of directly using the constructor.
+ """ + model_type: ModelType - transformer_args: Dict[str, Union[Dict, TransformerArgs]] + transformer_args: Dict[str, Dict[str, Any]] + use_tiktoken: bool def __init__( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType = ModelType.TextOnly, + use_tiktoken: bool = False, ) -> None: self._sanity_check(transformer_args, model_type) self.model_type = model_type - if isinstance(transformer_args, TransformerArgs): - assert model_type == ModelType.TextOnly - self.transformer_args = {"text": transformer_args} - else: - self.transformer_args = transformer_args + self.transformer_args = transformer_args + + # Model-level attributes + self.use_tiktoken = use_tiktoken def _sanity_check( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType, ) -> None: - assert isinstance(model_type, ModelType) - assert isinstance(transformer_args, (TransformerArgs, dict)) + assert isinstance(model_type, ModelType), model_type + assert isinstance(transformer_args, dict) @classmethod def from_params(cls, params_path): with open(params_path, "r") as f: loaded_params = json.loads(f.read()) - - try: - # try to interpret as a single transformer config - transformer_args: Dict[str, TransformerArgs] = {} - transformer_args["text"] = TransformerArgs.from_params(loaded_params) - if (model_type := loaded_params.get("model_type", None)) is None: - model_type = ModelType.TextOnly - - except TypeError: - # try to interpret as a dict of transformer configs - model_type = ModelType(loaded_params["model_type"]) + + if (model_type_name := loaded_params.get("model_type", None)) is None: + # The model params is in the transformer_args format + # set the model_type to TextOnly and reformat the params + model_type = ModelType.TextOnly + transformer_args = {"text": loaded_params} + else: + model_type = ModelType(model_type_name) transformer_args = { k: v for k, v in loaded_params.items() if k != "model_type" } - return cls(transformer_args, model_type) + + use_tiktoken = loaded_params.get("use_tiktoken", False) + return cls(transformer_args, model_type, use_tiktoken) @classmethod def from_table(cls, name: str): @@ -293,6 +307,10 @@ def __init__(self, config: ModelArgs) -> None: self.config = config self.model = self.build_model() + # text_transformer_args represents the args for the text transformer in the model. + # It should be assigned in the actual model implementation, if any. + self.text_transformer_args = None + def build_model(self) -> nn.Module: """ Builds a model based on the provided configuration. 
@@ -304,10 +322,11 @@ def build_model(self) -> nn.Module: recipe = ModelRecipe.get_recipe(self.config.model_type) modules = {} for name, module_class in recipe.modules.items(): - if isinstance(config_args := self.config.transformer_args[name], dict): - modules[name] = module_class(**config_args) + config_args = self.config.transformer_args[name] + if module_class == Transformer: + modules[name] = module_class(TransformerArgs.from_params(config_args)) else: - modules[name] = module_class(config_args) + modules[name] = module_class(**config_args) return recipe.fusion_class(**modules) @@ -353,6 +372,10 @@ def from_gguf(cls, gguf_path: str, **kwargs): class TextOnlyModel(Model): + def __init__(self, config: ModelArgs) -> None: + super().__init__(config) + self.text_transformer_args = self.model.config + def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return self.model(tokens, input_pos) @@ -391,6 +414,7 @@ def reset_caches(self): self.model.reset_caches() + MODEL_TYPE_TO_CLASS = { ModelType.TextOnly: TextOnlyModel, ModelType.Flamingo: FlamingoModel, @@ -781,6 +805,8 @@ def __init__(self, config, path) -> None: self.config = config self.model_ = exec_lib._load_for_executorch(str(path)) + self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"]) + def forward(self, x, input_pos): # model_.forward expects inputs to be wrapped in a tuple forward_inputs = (x.to(torch.long), input_pos.to(torch.long)) @@ -794,6 +820,6 @@ def forward(self, x, input_pos): def setup_caches(self, max_batch_size, max_seq_length): pass - + except: pass diff --git a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json index c59961c63..3c611a753 100644 --- a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json +++ b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json @@ -1,5 +1,6 @@ { "model_type": "llama3_1", + "use_tiktoken": true, "text": { "vocab_size": 128256, "num_layers": 80, diff --git a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json index e9ded77bd..adc9e4e8e 100644 --- a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json +++ b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json @@ -1,5 +1,6 @@ { "model_type": "llama3_1", + "use_tiktoken": true, "text": { "vocab_size": 128256, "num_layers": 32, diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index f8ac6fbe1..5993c3781 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( T = prompt.size(0) T_new = T + max_new_tokens if max_seq_length is None: - max_seq_length = min(T_new, model.config.transformer_args["text"].block_size) + max_seq_length = min(T_new, model.text_transformer_args.block_size) device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 269e89f54..6381e8112 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -282,15 +282,17 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.max_seq_length = 128 - if self.model.config.transformer_args.get("text", None): - self.max_seq_len = ( - self.model.config.transformer_args["text"].max_seq_length + try: + self.max_seq_length = ( + self.model.text_transformer_args.max_seq_length + self.speculative_builder_args.speculate_k + 1 if 
self.draft_model is not None
- else self.model.config.transformer_args["text"].max_seq_length
+ else self.model.text_transformer_args.max_seq_length
 )
+ except:
+ # cannot find max_seq_length in the model config, use the default value
+ self.max_seq_length = 128
 # The System fingerprint is a unique identifier for the model and its configuration.
 self.system_fingerprint = (
 f"{self.builder_args.device}_{self.builder_args.precision}"
diff --git a/torchchat/utils/gguf_loader.py b/torchchat/utils/gguf_loader.py
index c7b931dae..309ff807c 100644
--- a/torchchat/utils/gguf_loader.py
+++ b/torchchat/utils/gguf_loader.py
@@ -542,15 +542,17 @@ def load_model(gguf_file: str) -> torch.nn.Module:
 assert arch == "llama", "Only LLaMa models are supported by this converter."
 model_args = ModelArgs(
- TransformerArgs(
- dim=metadata[f"{arch}.embedding_length"],
- n_layers=metadata[f"{arch}.block_count"],
- n_heads=metadata[f"{arch}.attention.head_count"],
- n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
- vocab_size=len(metadata["tokenizer.ggml.tokens"]),
- norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
- hidden_dim=metadata[f"{arch}.feed_forward_length"],
- )
+ {
+ "text": {
+ "dim": metadata[f"{arch}.embedding_length"],
+ "n_layers": metadata[f"{arch}.block_count"],
+ "n_heads": metadata[f"{arch}.attention.head_count"],
+ "n_local_heads": metadata[f"{arch}.attention.head_count_kv"],
+ "vocab_size": len(metadata["tokenizer.ggml.tokens"]),
+ "norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
+ "hidden_dim": metadata[f"{arch}.feed_forward_length"],
+ }
+ }
 )
 # TODO: what to do with rope args like
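
Below is a minimal usage sketch, not part of the patch, showing how downstream code is expected to consume the attributes this change introduces: text_transformer_args on the Model instance and use_tiktoken / model_type on ModelArgs. The resolve_max_seq_length helper and the literal parameter values are hypothetical, and the import path is assumed from the diffed module; the attribute names and the ModelArgs constructor signature follow the diff above.

# Illustrative sketch only; not part of the patch above.
from torchchat.model import ModelArgs, ModelType  # module path assumed from the diff


def resolve_max_seq_length(model, default: int = 2048) -> int:
    # After this change the text transformer args hang off the Model instance
    # (assigned by concrete subclasses such as TextOnlyModel) instead of being
    # looked up via model.config.transformer_args["text"]. The attribute may be
    # None for models without a text transformer, hence the fallback.
    text_args = model.text_transformer_args
    return text_args.max_seq_length if text_args is not None else default


# ModelArgs.from_params now accepts both layouts: a flat params dict without a
# "model_type" key is wrapped as {"text": ...} and treated as TextOnly, while a
# typed dict (e.g. Meta-Llama-3.1-8B-Tune.json) carries "model_type" and
# "use_tiktoken" at the top level. Building the equivalent object directly:
flat_params = {"dim": 4096, "n_layers": 32, "n_heads": 32}  # hypothetical values
args = ModelArgs({"text": flat_params}, ModelType.TextOnly, use_tiktoken=True)
assert args.transformer_args["text"]["dim"] == 4096
assert args.use_tiktoken  # now read at the model level, e.g. by validate_model()

The underlying design choice is to surface tokenizer selection (use_tiktoken) and the text transformer args at the model level, which is what lets validate_model() and the generate/export/eval call sites drop their per-model special cases.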