diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 5112fa8e1..f42a20e22 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -458,7 +458,6 @@ jobs: pip3 list python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")' python3 -c 'import torchvision;print(f"torchvision: {torchvision.__version__, torchvision.version.git_version}")' - python3 -c 'import torchaudio;print(f"torchaudio: {torchaudio.__version__, torchaudio.version.git_version}")' cd ../.. echo "Inside: ${PWD}" diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 558a267af..f600fe7f2 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -15,15 +15,26 @@ import torch._dynamo.config import torch._inductor.config import torch.nn as nn + +try: + from _torchchat_test_script import flamingo_meta_to_tune +except ImportError: + pass + from distributed import ( init_distributed, launch_distributed, ParallelDims, parallelize_llama, ) + from torch.distributed.device_mesh import DeviceMesh -from torchchat.model import Model +from torchtune.models.convert_weights import meta_to_tune + +from torchtune.training import set_default_dtype + +from torchchat.model import Model, ModelType from torchchat.model_config.model_config import resolve_model_config from torchchat.utils.build_utils import ( @@ -35,10 +46,6 @@ from torchchat.utils.measure_time import measure_time from torchchat.utils.quantize import quantize_model -from torchtune.models.convert_weights import meta_to_tune - - - @dataclass class BuilderArgs: @@ -143,7 +150,6 @@ def from_args(cls, args): # -> BuilderArgs: if "chat" in path_basename or "instruct" in path_basename: is_chat_model = True - output_pte_path = getattr(args, "output_pte_path", None) output_dso_path = getattr(args, "output_dso_path", None) if output_pte_path and args.dtype.startswith("fast"): @@ -234,7 +240,12 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - use_tiktoken = model.config.transformer_args["text"].use_tiktoken + text_args = model.config.transformer_args.get("text") + if text_args is None: + # TODO: Will be refactored: Currently, the only model that doesn't have text in transfomer_args is Flamingo + use_tiktoken = model.config.model_type == ModelType.Flamingo + else: + use_tiktoken = text_args.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( @@ -266,7 +277,9 @@ def from_args(cls, args): # -> TokenizerArgs: raise RuntimeError("cannot find tokenizer model") if not tokenizer_path.is_file(): - raise RuntimeError(f"did not find tokenizer at {tokenizer_path}") + raise RuntimeError( + f"did not find tokenizer at {os.path.abspath(tokenizer_path)}" + ) return cls( tokenizer_path=tokenizer_path, @@ -335,7 +348,9 @@ def _load_model_default(builder_args, only_config=False): if builder_args.params_table and builder_args.params_table.endswith("Tune"): print("Loading Tune checkpoint") - meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) + meta_checkpoint = torch.load( + str(builder_args.checkpoint_path), mmap=True, weights_only=True + ) checkpoint = meta_to_tune(meta_checkpoint) elif builder_args.checkpoint_dir is not None: # Load multiple checkpoint; ignore the single path. 
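Note: the validate_model() change above falls back to the ModelType when a model has no "text" entry in transformer_args, since Flamingo is currently the only such model and it uses tiktoken. A minimal, self-contained sketch of that resolution logic, with stand-in types rather than torchchat's real config classes:

    from dataclasses import dataclass
    from enum import Enum


    class ModelType(Enum):
        TextOnly = "text_only"
        Flamingo = "flamingo"


    @dataclass
    class TextArgs:
        use_tiktoken: bool = False


    def resolve_use_tiktoken(transformer_args: dict, model_type: ModelType) -> bool:
        # Mirror of the patched branch: no "text" args means Flamingo, which is tiktoken-based.
        text_args = transformer_args.get("text")
        if text_args is None:
            return model_type == ModelType.Flamingo
        return text_args.use_tiktoken


    # A Flamingo config without "text" args resolves to tiktoken; text models keep their own flag.
    assert resolve_use_tiktoken({}, ModelType.Flamingo) is True
    assert resolve_use_tiktoken({"text": TextArgs(use_tiktoken=True)}, ModelType.TextOnly) is True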
@@ -372,8 +387,17 @@ def _load_model_default(builder_args, only_config=False): if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): checkpoint = checkpoint["model"] - checkpoint = {"model." + k: v for k, v in checkpoint.items()} - model.load_state_dict(checkpoint, assign=True, strict=True) + if model.config.model_type == ModelType.Flamingo: + # TODO: Refactor this. For now, overwrite the model with model loaded from params_path + with set_default_dtype(builder_args.precision), torch.device( + builder_args.device + ): + model = Model.from_params(builder_args.params_path) + state_dict = flamingo_meta_to_tune(checkpoint) + model.model.load_state_dict(state_dict) + else: + checkpoint = {"model." + k: v for k, v in checkpoint.items()} + model.load_state_dict(checkpoint, assign=True, strict=True) return model diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 80463a37b..170753032 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -46,7 +46,7 @@ def check_args(args, verb: str) -> None: # different semantics. if ( verb not in INVENTORY_VERBS - and args.model + and getattr(args, "model", None) and not is_model_downloaded(args.model, args.model_directory) ): download_and_convert(args.model, args.model_directory, args.hf_token) @@ -320,6 +320,13 @@ def _add_generation_args(parser, verb: str) -> None: help="Number of samples", ) + generator_parser.add_argument( + "--image-prompts", + nargs="+", + type=str, + default=None, + help="Paths to image files used as image prompts for multimodal models. Currently, 1 image input is supported.", + ) generator_parser.add_argument( "--chat", action="store_true", diff --git a/torchchat/generate.py b/torchchat/generate.py index 384d57b96..cea9531e8 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -12,21 +12,34 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from os import PathLike from pathlib import Path -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torch._dynamo.config import torch._inductor.config +try: + from _torchchat_test_script import flamingo_transform, padded_collate +except ImportError: + pass + +from PIL import Image + +# torchtune model definition dependencies +from torchtune.data import Message +from torchtune.generation._generation import sample as tune_sample +from torchtune.models.llama3 import llama3_tokenizer +from torchtune.training import set_default_dtype + from torchchat.cli.builder import ( _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs, ) -from torchchat.cli.cli import add_arguments_for_verb, arg_init, check_args -from torchchat.model import Model +from torchchat.model import Model, ModelType from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info @@ -56,10 +69,18 @@ def encode_header(self, role) -> List[int]: return tokens def encode_message(self, message) -> List[int]: - tokens = self.encode_header(message["role"]) - tokens.extend( - self.tokenizer.encode(message["content"].strip(), bos=False, eos=False) - ) + tokens = self.encode_header(message["role"]) + if type(message["content"]) is str: + tokens.extend( + self.tokenizer.encode(message["content"], bos=False, eos=False) + ) + elif type(message["content"]) is list: + for content in message["content"]: + if content["type"] == "text": + tokens.extend( + self.tokenizer.encode(content["text"], bos=False, eos=False) + ) + 
tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) return tokens @@ -105,6 +126,9 @@ class GeneratorArgs: None # When passed into the Generator, this will be used as the system prompt ) encoded_prompt: Optional[torch.Tensor] = None + image_prompts: Optional[Sequence[Union[str, PathLike, bytes]]] = ( + None # string or Path to an image file or the raw base64 bytes of an image + ) chat_mode: bool = False gui_mode: bool = False num_samples: int = 1 @@ -148,9 +172,15 @@ def from_args(cls, args): pte_path = getattr(args, "pte_path", None) sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path) + # Validate that all image prompts exist before expensive model load + if image_prompts := getattr(args, "image_prompts", None): + if not all(os.path.exists(image_prompt) for image_prompt in image_prompts): + raise RuntimeError(f"Image prompt {image_prompt} does not exist") + return cls( prompt=getattr(args, "prompt", ""), encoded_prompt=None, + image_prompts=image_prompts, chat_mode=args.chat, gui_mode=args.gui, num_samples=getattr(args, "num_samples", 1), @@ -189,7 +219,9 @@ def __init__( quantize: bool, draft_quantize: bool, ): - torch._inductor.config.coordinate_descent_tuning = False if builder_args.device == "cpu" else True + torch._inductor.config.coordinate_descent_tuning = ( + False if builder_args.device == "cpu" else True + ) torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future @@ -241,15 +273,14 @@ def __init__( # Piggy backing off of this flag then for now to identify llama3 # without prompting user. self.is_llama3_model = self.tokenizer_args.is_tiktoken - if generator_args.chat_mode and self.is_llama3_model: - logging.debug( - "Llama3 model detected in chat mode. Using updated sentence schemas" - ) - self.chat_formatter = ( - Llama3ChatFormatter(self.tokenizer) - if self.is_llama3_model - else Llama2ChatFormatter(self.tokenizer) - ) + if self.is_llama3_model: + self.chat_formatter = Llama3ChatFormatter(self.tokenizer) + if generator_args.chat_mode: + logging.debug( + "Llama3 model detected in chat mode. 
Using updated sentence schemas" + ) + else: + self.chat_formatter = Llama2ChatFormatter(self.tokenizer) self.builder_args.setup_caches = False self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer) @@ -300,9 +331,9 @@ def sample( self, logits, need_probs: bool, - temperature: float = 1.0, + temperature: float = 0, top_k: Optional[int] = None, - ): + ): if temperature == 0 and not need_probs: _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1) return (idx_next, None) @@ -315,6 +346,7 @@ def prefill( model: Model, x: torch.Tensor, input_pos: torch.Tensor, + batch: Optional[Dict[str, Any]] = None, # Inputs for multimodal models *, sequential_prefill=True, **sampling_kwargs, @@ -323,7 +355,11 @@ def prefill( width = x.size(1) assert input_pos.size(0) == width - if sequential_prefill: + if batch is not None: + # TODO: Verify sequential prefill works with multimodal models + logits = model(**batch)[:, -1] + return tune_sample(logits, 0, 500) + elif sequential_prefill: for i in range(width): x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}") @@ -342,11 +378,16 @@ def decode_one_token( x: torch.Tensor, input_pos: torch.Tensor, need_probs: bool, + batch: Optional[Dict[str, Any]] = None, # Inputs for multimodal models **sampling_kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # input_pos: [B, 1] assert input_pos.shape[-1] == 1 - logits = model(x.view(1, -1), input_pos) + if model.config.model_type == ModelType.Flamingo and batch is not None: + x = x.view(1, -1) + logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:]) + else: + logits = model(x.view(1, -1), input_pos) # print(f"x: {x},\n input_pos: {input_pos}\n") return self.sample(logits, need_probs=need_probs, **sampling_kwargs) @@ -363,6 +404,7 @@ def decode_n_tokens( input_pos: torch.Tensor, num_new_tokens: int, need_probs: bool, + batch=Optional[Dict[str, Any]], # Inputs for multimodal models callback=lambda _: _, eos_token_id: int = 2, eot_id: Optional[int] = None, @@ -381,6 +423,7 @@ def decode_n_tokens( model, out_token, input_pos, + batch=batch, need_probs=need_probs, **sampling_kwargs, ) @@ -400,7 +443,12 @@ def decode_n_tokens( ): encountered_eos = True final_token, next_prob = self.decode_one_token( - model, cur_token, input_pos, need_probs, **sampling_kwargs + model, + cur_token, + input_pos, + need_probs, + batch=batch, + **sampling_kwargs, ) input_pos += 1 break @@ -413,7 +461,12 @@ def decode_n_tokens( ) new_tokens.append(eos_token.clone()) eos_token, next_prob = self.decode_one_token( - model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs + model, + eos_token.view(1, -1), + input_pos, + need_probs, + batch=batch, + **sampling_kwargs, ) input_pos += 1 yield eos_token.clone(), ( @@ -432,6 +485,7 @@ def speculative_decode( cur_token: torch.Tensor, input_pos: int, speculate_k: int, + batch: Optional[Dict[str, Any]] = None, # Inputs for multimodal models **sampling_kwargs, ) -> torch.Tensor: # draft model inference sequentially @@ -444,6 +498,7 @@ def speculative_decode( cur_token, orig_input_pos.clone(), speculate_k, + batch=batch, need_probs=True, **sampling_kwargs, ) @@ -497,6 +552,9 @@ def generate( max_new_tokens: int, *, chat_mode: bool, + batch: Optional[ + Dict[str, Any] + ] = None, # List of Image prompt tensors for multimodal models start_pos: int = 0, draft_model: Model, speculate_k: Optional[int] = 8, @@ -517,6 +575,8 @@ def generate( # create an empty tensor of the expected final shape and # 
fill in the current tokens + if len(prompt.shape) > 1: + prompt = prompt.squeeze(0) T = prompt.size(0) max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T) T_new = T + max_new_tokens @@ -524,20 +584,26 @@ def generate( if start_pos == 0: model = model.to(device=device) with torch.device(device): - if self.is_torchtune_model: + if ( + self.is_torchtune_model + or self.model.config.model_type == ModelType.Flamingo + ): model.setup_caches(max_batch_size=1, dtype=self.dtype) else: model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) if is_speculative and draft_model is not model: draft_model.setup_caches( - max_batch_size=1, max_seq_length=max_seq_length + max_batch_size=1, + max_seq_length=max_seq_length, ) + if model.config.model_type == ModelType.Flamingo: + model.reset_caches() # create an empty tensor of the expected final shape and # fill in the current tokens empty = torch.empty(T_new, dtype=dtype, device=device) empty[:T] = prompt - seq = empty + input_pos = torch.arange( start_pos, T + start_pos, device=device, dtype=torch.int ) @@ -547,6 +613,7 @@ def generate( model, prompt.view(1, -1), input_pos, + batch=batch, sequential_prefill=sequential_prefill, **sampling_kwargs, ) @@ -558,9 +625,9 @@ def generate( sequential_prefill=sequential_prefill, **sampling_kwargs, ) + time_to_first_token = time.perf_counter() - prefill_t0 yield None, {"time_to_first_token": time_to_first_token} - seq[T] = next_token # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens(). callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2) @@ -582,12 +649,12 @@ def generate( cur_token, input_pos, speculate_k, + batch=batch, **sampling_kwargs, ) accept_counts[len(next_tokens) - 1] += 1 num_added = min(T_new - input_pos - 1, len(next_tokens)) - seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[:num_added] for token in next_tokens[:num_added,]: callback(token) yield token, None @@ -600,6 +667,7 @@ def generate( next_token, input_pos, max_new_tokens - 1, + batch=batch, callback=callback, need_probs=False, eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2, @@ -610,14 +678,9 @@ def generate( ), **sampling_kwargs, ): - generated_tokens.append(generated_token) + generated_tokens.append(generated_token.view(-1)) yield generated_token, None - seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens) - seq = seq[ - : T + 1 + len(generated_tokens) - ] # If we dont generate all the way to max_new_tokens slice off the extra space we allocated. 
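Note: prefill(), decode_one_token(), and generate() above now thread an optional multimodal `batch` through generation. A rough, illustrative sketch of that control flow, using plain greedy argmax as a stand-in for tune_sample(logits, 0, 500) and hypothetical helper names, not the patch itself:

    from typing import Any, Dict, Optional

    import torch


    def prefill_step(model, batch: Dict[str, Any]) -> torch.Tensor:
        # Multimodal prefill: run the fused model on the collated batch and
        # sample from the logits of the last position.
        logits = model(**batch)[:, -1]
        return torch.argmax(logits, dim=-1)


    def decode_step(
        model,
        cur_token: torch.Tensor,
        input_pos: torch.Tensor,
        batch: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        x = cur_token.view(1, -1)
        if batch is not None:
            # Flamingo path: re-use only the last column of the encoder mask.
            logits = model(x, encoder_mask=batch["encoder_mask"][:, -1:])
        else:
            logits = model(x, input_pos)
        return torch.argmax(logits[:, -1], dim=-1)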
- generate_stats = { "accept_counts": accept_counts, } @@ -655,10 +718,36 @@ def chat( ): if generator_args.chat_mode: print("Starting Interactive Chat") - encoded = self.encode_tokens( - generator_args.prompt, bos=True, device=self.builder_args.device - ) - logging.debug(encoded) + + if generator_args.image_prompts is not None: + print("Image prompts", generator_args.image_prompts) + + messages = [ + Message( + role="user", + content=[ + {"type": "image"}, + {"type": "text", "content": generator_args.prompt}, + ], + eot=True, + ), + Message(role="assistant", content=""), + ] + + images = [Image.open(generator_args.image_prompts[0])] + transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path)) + + data = transform({"images": images, "messages": messages}, inference=True) + batch = padded_collate([data], self.builder_args.device) + batch.pop("mask") + encoded = batch["tokens"] + + else: + encoded = self.encode_tokens( + generator_args.prompt, bos=True, device=self.builder_args.device + ) + logging.debug(encoded) + batch = None model_size = sum( [ @@ -692,7 +781,9 @@ def chat( ) if generator_args.compile_prefill: - self.prefill = torch.compile(self.prefill, fullgraph=True, dynamic=True, **kwargs) + self.prefill = torch.compile( + self.prefill, fullgraph=True, dynamic=True, **kwargs + ) self.system_prompt = None # Set up our max_seq_length @@ -700,9 +791,15 @@ def chat( # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong # TODO: unify the max_seq_length config representation. if generator_args.is_torchtune_model: - max_seq_length = self.model.config.transformer_args["text"]["max_seq_len"] + max_seq_length = self.model.config.transformer_args.get("text", {}).get( + "max_seq_len", 2048 + ) elif generator_args.chat_mode: - max_seq_length = self.model.config.transformer_args["text"].max_seq_length + if ( + max_seq_length := self.model.config.transformer_args.get("text", None) + is None + ): + max_seq_length = 2048 print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -713,9 +810,14 @@ def chat( self.system_prompt = input("What is your system prompt? 
\n") else: + text_transformer_args = self.model.config.transformer_args.get("text", None) max_seq_length = min( encoded.size(0) + generator_args.max_new_tokens, - self.model.config.transformer_args["text"].block_size, + ( + text_transformer_args.block_size + if text_transformer_args is not None + else 2048 + ), ) max_seq_length = ( @@ -772,9 +874,7 @@ def chat( encoded = self.chat_formatter.encode_message( {"role": "user", "content": prompt} ) - encoded.extend( - self.chat_formatter.encode_header("assistant") - ) + encoded.extend(self.chat_formatter.encode_header("assistant")) encoded = torch.tensor( encoded, dtype=torch.int, device=self.builder_args.device ) @@ -809,6 +909,7 @@ def callback(x, *, done_generating=False): if self.profile: from torch._inductor import config as inductor_config + torch._inductor.config.profiler_mark_wrapper_call = True torch._inductor.config.cpp.enable_kernel_profile = True if (i != generator_args.num_samples - 1 or not self.profile) or ( @@ -830,6 +931,7 @@ def callback(x, *, done_generating=False): draft_model=self.draft_model, speculate_k=generator_args.speculate_k, chat_mode=generator_args.chat_mode, + batch=batch, callback=callback, temperature=generator_args.temperature, top_k=generator_args.top_k, diff --git a/torchchat/model.py b/torchchat/model.py index 2cba9a032..10c036a36 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -6,19 +6,19 @@ import json import os import warnings +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, Optional, Union -from abc import ABC, abstractmethod import torch import torch.nn as nn from torch import Tensor -from torch.distributed._tensor import Replicate, Shard, DTensor +from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -28,39 +28,43 @@ ) from torch.nn import functional as F -from torchchat.utils.build_utils import find_multiple, get_precision - from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder -from torchtune.modules.model_fusion import DeepFusionModel from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder +from torchtune.modules.model_fusion import DeepFusionModel + +from torchchat.utils.build_utils import find_multiple, get_precision config_path = Path(f"{str(Path(__file__).parent)}/model_params") + def identity(**kwargs): if len(kwargs) != 1: raise ValueError("Only one argument is expected") return list(kwargs.values())[0] + class ModelType(Enum): TextOnly = "text_only" Llama3_1 = "llama3_1" Flamingo = "flamingo" + # Type for objects that can generate nn.Module instance ModuleLike = Union[nn.Module, Callable[..., nn.Module]] + @dataclass class ModelRecipe: """ The class describes and contains all supported model structures in torchchat. - + ModelRecipe represents a model as a collection of Transformer modules and a fusion module, providing a standardized and centralized way to define and build models in torchchat. Attributes: model_type (ModelType): The type of the model. modules (Dict[str, ModuleLike]): - A dictionary of ModuleLike modules, where each key is the module name and each + A dictionary of ModuleLike modules, where each key is the module name and each value is a ModuleLike object that generates the transformer. 
The names of the Transformer modules should match the corresponding names in the fusion class and the JSON file holding model hyperparameters. @@ -76,15 +80,15 @@ class ModelRecipe: def _text_only(cls): return cls( model_type=ModelType.TextOnly, - modules={'text': Transformer}, + modules={"text": Transformer}, fusion_class=identity, ) - + @classmethod def _llama3_1(cls): return cls( model_type=ModelType.Llama3_1, - modules={'text': llama3_1_builder}, + modules={"text": llama3_1_builder}, fusion_class=identity, ) @@ -92,13 +96,10 @@ def _llama3_1(cls): def _flamingo(cls): return cls( model_type=ModelType.Flamingo, - modules={ - 'encoder': flamingo_vision_encoder, - 'decoder': flamingo_decoder - }, + modules={"encoder": flamingo_vision_encoder, "decoder": flamingo_decoder}, fusion_class=DeepFusionModel, ) - + @classmethod def get_recipe(cls, model_type): if model_type == ModelType.TextOnly: @@ -110,6 +111,7 @@ def get_recipe(cls, model_type): else: raise ValueError(f"Can not find the model recipe for {model_type}") + @dataclass class TransformerArgs: block_size: int = 2048 @@ -170,14 +172,14 @@ def __init__( model_type: ModelType = ModelType.TextOnly, ) -> None: self._sanity_check(transformer_args, model_type) - + self.model_type = model_type if isinstance(transformer_args, TransformerArgs): assert model_type == ModelType.TextOnly self.transformer_args = {"text": transformer_args} else: self.transformer_args = transformer_args - + def _sanity_check( self, transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], @@ -195,11 +197,15 @@ def from_params(cls, params_path): # try to interpret as a single transformer config transformer_args: Dict[str, TransformerArgs] = {} transformer_args["text"] = TransformerArgs.from_params(loaded_params) - model_type = ModelType.TextOnly + if (model_type := loaded_params.get("model_type", None)) is None: + model_type = ModelType.TextOnly + except TypeError: # try to interpret as a dict of transformer configs model_type = ModelType(loaded_params["model_type"]) - transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"} + transformer_args = { + k: v for k, v in loaded_params.items() if k != "model_type" + } return cls(transformer_args, model_type) @classmethod @@ -281,11 +287,12 @@ class Model(ABC, nn.Module): """ The entrance for model construction in torchchat. """ + def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config self.model = self.build_model() - + def build_model(self) -> nn.Module: """ Builds a model based on the provided configuration. 
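Note: the ModelRecipe plumbing above composes per-module builders through a fusion class (DeepFusionModel for Flamingo, identity for text-only). A toy sketch of that composition pattern with stand-in modules instead of the real torchtune builders:

    import torch.nn as nn


    def identity(**kwargs):
        # Single-module "fusion": unwrap the one module that was passed in.
        if len(kwargs) != 1:
            raise ValueError("Only one argument is expected")
        return list(kwargs.values())[0]


    class ToyFusion(nn.Module):
        # Stand-in for torchtune's DeepFusionModel(encoder=..., decoder=...).
        def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
            super().__init__()
            self.encoder = encoder
            self.decoder = decoder


    # Text-only recipe: one module, identity fusion returns the module itself.
    text_model = identity(text=nn.Linear(8, 8))
    # Flamingo-style recipe: two modules composed by the fusion class.
    fused_model = ToyFusion(encoder=nn.Linear(8, 8), decoder=nn.Linear(8, 8))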
@@ -303,7 +310,7 @@ def build_model(self) -> nn.Module: modules[name] = module_class(config_args) return recipe.fusion_class(**modules) - + @abstractmethod def forward(self, *args, **kwargs): raise NotImplementedError("forward method is not implemented") @@ -373,7 +380,9 @@ def forward( ) -> Tensor: if encoder_input is None: return self.model(tokens, encoder_mask=encoder_mask) - return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask) + return self.model( + tokens, encoder_input=encoder_input, encoder_mask=encoder_mask + ) def setup_caches(self, max_batch_size, dtype): self.model.setup_caches(max_batch_size, dtype=dtype) @@ -388,6 +397,7 @@ def reset_caches(self): ModelType.Llama3_1: Llama31Model, } + class Transformer(nn.Module): def __init__(self, config: TransformerArgs) -> None: super().__init__() @@ -396,14 +406,16 @@ def __init__(self, config: TransformerArgs) -> None: self.tok_embeddings = ( nn.Embedding(config.vocab_size, config.dim) - if config.stage_idx == 0 else None + if config.stage_idx == 0 + else None ) # Use ModuleDict so that each layer can be assigned its layer ID in the original model self.layers = nn.ModuleDict() for layer_id in range( - layers_per_stage * config.stage_idx, layers_per_stage * (config.stage_idx + 1) + layers_per_stage * config.stage_idx, + layers_per_stage * (config.stage_idx + 1), ): self.layers[str(layer_id)] = TransformerBlock(config) @@ -449,7 +461,8 @@ def setup_caches(self, max_batch_size, max_seq_length): def distribute(self, device_mesh: DeviceMesh): if self.tok_embeddings: parallelize_module( - self.tok_embeddings, device_mesh, + self.tok_embeddings, + device_mesh, RowwiseParallel( input_layouts=Replicate(), output_layouts=Shard(1), @@ -464,7 +477,8 @@ def distribute(self, device_mesh: DeviceMesh): if self.output: parallelize_module( - self.output, device_mesh, + self.output, + device_mesh, ColwiseParallel( input_layouts=Shard(1), output_layouts=Replicate(), @@ -588,7 +602,9 @@ def distribute(self, device_mesh: DeviceMesh): parallelize_module(self.wq, device_mesh, ColwiseParallel()) parallelize_module(self.wk, device_mesh, ColwiseParallel()) parallelize_module(self.wv, device_mesh, ColwiseParallel()) - parallelize_module(self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1))) + parallelize_module( + self.wo, device_mesh, RowwiseParallel(output_layouts=Shard(1)) + ) # TODO: enable kv cache in distributed case self.kv_cache = None @@ -650,7 +666,9 @@ def __init__(self, config: TransformerArgs) -> None: def distribute(self, device_mesh: DeviceMesh): self.device_mesh = device_mesh parallelize_module(self.w1, device_mesh, ColwiseParallel()) - parallelize_module(self.w2, device_mesh, RowwiseParallel(output_layouts=Shard(1))) + parallelize_module( + self.w2, device_mesh, RowwiseParallel(output_layouts=Shard(1)) + ) parallelize_module(self.w3, device_mesh, ColwiseParallel()) def forward(self, x: Tensor) -> Tensor: @@ -677,9 +695,16 @@ def forward(self, x: Tensor) -> Tensor: def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]): # Check for the presence of the required keys - required_keys = {"factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"} + required_keys = { + "factor", + "low_freq_factor", + "high_freq_factor", + "original_max_position_embeddings", + } if not required_keys.issubset(rope_scaling.keys()): - raise ValueError(f"Missing required keys in apply_scaling. Expected: {required_keys}") + raise ValueError( + f"Missing required keys in apply_scaling. 
Expected: {required_keys}" + ) scale_factor = rope_scaling["factor"] low_freq_factor = rope_scaling["low_freq_factor"] @@ -705,7 +730,11 @@ def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]): def precompute_freqs_cis( - n_elem: int, seq_len: int, base: int = 10000, dtype=None, rope_scaling: Optional[Dict[str, Any]] = None + n_elem: int, + seq_len: int, + base: int = 10000, + dtype=None, + rope_scaling: Optional[Dict[str, Any]] = None, ) -> Tensor: if not dtype: dtype = get_precision() @@ -742,7 +771,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: try: from executorch.extension.pybindings import portable_lib as exec_lib - + # ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately. from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa
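Note: the required-keys check in apply_scaling() above guards the Llama 3.1 style rope-frequency scaling. A hedged sketch of that published formula (standard math, not necessarily torchchat's exact implementation):

    import math
    from typing import Any, Dict

    import torch


    def apply_scaling_sketch(freqs: torch.Tensor, rope_scaling: Dict[str, Any]) -> torch.Tensor:
        scale_factor = rope_scaling["factor"]
        low_freq_factor = rope_scaling["low_freq_factor"]
        high_freq_factor = rope_scaling["high_freq_factor"]
        old_context_len = rope_scaling["original_max_position_embeddings"]

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        new_freqs = []
        for freq in freqs:
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                # High-frequency band: leave unchanged.
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                # Low-frequency band: scale down fully.
                new_freqs.append(freq / scale_factor)
            else:
                # Smoothly interpolate between the two regimes.
                smooth = (old_context_len / wavelen - low_freq_factor) / (
                    high_freq_factor - low_freq_factor
                )
                new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
        return torch.stack(new_freqs)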