diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 0661d08f5..001ece603 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -17,13 +17,14 @@
 import torch._inductor.config
 import torch.distributed as dist

-from torchchat.distributed.utils import(
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger

 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -37,15 +38,6 @@

 from torchchat.utils.quantize import quantize_model

-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -188,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -238,12 +234,14 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args

+
 class TokenizerType(Enum):
     NONE = 0
     TIKTOKEN = 1
     SENTENCEPIECE = 2
     HF_TOKENIZER = 3

+
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
@@ -307,9 +305,9 @@ def validate_model(
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)

         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -417,6 +415,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:

 def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
+        from torchtune.models.convert_weights import meta_to_tune
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
             str(builder_args.checkpoint_path), mmap=True, weights_only=True
@@ -469,9 +468,15 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         checkpoint = checkpoint["model"]

     if model.config.model_type == ModelType.Flamingo:
+        from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+        from torchtune.models.llama3_2_vision._convert_weights import (
+            llama3_vision_meta_to_tune,
+        )
+        from torchtune.training import set_default_dtype
         # TODO: Refactor this. For now, overwrite the model with model loaded from params_path
-        with set_default_dtype(builder_args.precision), torch.device(
-            builder_args.device
+        with (
+            set_default_dtype(builder_args.precision),
+            torch.device(builder_args.device),
         ):
             # It doubles the model size the memory, with redundancies of the initialized weights.
             # model = Model.from_params(builder_args.params_path)
@@ -507,6 +512,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -564,6 +570,7 @@ def _initialize_model(

         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing

         model.forward = torch._export.aot_load(
@@ -601,6 +608,7 @@ def do_nothing(max_batch_size, max_seq_length):

         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing

         model.forward = aoti_compiled_model
@@ -652,12 +660,15 @@ def do_nothing(max_batch_size, max_seq_length):
        try:
            model = torch.load(builder_args.snapshot_path, weights_only=False)
        except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
        # _active_backend() does not allow DSO & AOTI to be true.
        # Choose either.
        from torchchat.utils.build_utils import set_backend
-        set_backend (dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
            raise RuntimeError("loaded model architecture mismatch")
        ##
        ## import all libraries with custom kernels ans custom operators
@@ -675,7 +686,9 @@ def do_nothing(max_batch_size, max_seq_length):
         logger = SingletonLogger.get_logger()

         gpu_memory_monitor = GPUMemoryMonitor("cuda")
-        logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+        logger.info(
+            f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+        )

         # Model-level config
         if builder_args.params_table:
@@ -686,20 +699,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
         logger.info(f"Transformer Config: {config}")

-        #TODO: Move into head of file after solving circular import
-        from torchchat.distributed.checkpoint_utils import (
-            load_model_weights,
-        )
+        # TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import load_model_weights

         # Validate pipeline degree
         assert config.n_layers % pp_degree == 0

         # Create device mesh
         device_mesh = dist.init_device_mesh(
-            "cuda",
-            (pp_degree, tp_degree),
-            mesh_dim_names=("pp", "tp")
-            )
+            "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+        )
         tp_mesh = device_mesh["tp"]
         pp_mesh = device_mesh["pp"]
         logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -728,7 +737,13 @@ def do_nothing(max_batch_size, max_seq_length):
         # Load weights
         logger.info(f"Loading weights for {pp_rank=} on {device=}")
         with CUDATrackTime() as timer:
-            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+            load_model_weights(
+                model,
+                builder_args.distribution_path,
+                device,
+                config,
+                builder_args.chpt_from,
+            )

         logger.info(
             f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -742,7 +757,7 @@ def do_nothing(max_batch_size, max_seq_length):
         # lanes.
         # TODO: bump up the lane count
         pipeline_lanes = 1
-        seqlen_prefill=1024
+        seqlen_prefill = 1024
         with device:
             model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index 4f90b316f..ed1b27fa6 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -10,11 +10,11 @@
 import os
 import textwrap
 import time
-from concurrent import futures
-from functools import partial

 from abc import ABC, abstractmethod
+from concurrent import futures
 from dataclasses import dataclass
+from functools import partial
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
@@ -25,18 +25,10 @@
 import torch._inductor.config
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-from torch._C import _SDPBackend as SDPBackend

 from PIL import Image
-
-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-from torchtune.training import set_default_dtype
+from torch._C import _SDPBackend as SDPBackend
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

 from torchchat.cli.builder import (
     _initialize_model,
@@ -44,10 +36,7 @@
     BuilderArgs,
     TokenizerArgs,
 )
-from torchchat.distributed.utils import (
-    Color as color,
-    run_in_dist_env,
-)
+from torchchat.distributed.utils import Color as color, run_in_dist_env
 from torchchat.model import Model, ModelType
 from torchchat.utils.build_utils import device_sync, set_precision
 from torchchat.utils.device_info import get_device_info
@@ -59,17 +48,16 @@
 class NoOpLogger:
     def __no_op(self, *_, **__):
         pass
+
     def __getattr__(self, name):
         return self.__no_op


-logger = (
-    NoOpLogger() if os.getenv("LOG_LEVEL") is None
-    else logging.getLogger(__name__)
-)
+logger = NoOpLogger() if os.getenv("LOG_LEVEL") is None else logging.getLogger(__name__)

 ## Chat Formatters #############################################################

+
 class _ChatFormatter(ABC):

     # Messages can arrive as a standard dict with "role" and "content" as
@@ -145,7 +133,9 @@ def encode_dialog_prompt(
             tokens.extend(self._encode_message(message))

        # Add the start of an assistant message for the model to complete.
        if add_generation_prompt and dialog and dialog[-1]["role"] != "assistant":
-            tokens.extend(self._encode_header("assistant")) # Pass role directly as a string
+            tokens.extend(
+                self._encode_header("assistant")
+            )  # Pass role directly as a string
        return tokens

@@ -167,7 +157,7 @@ def _get_content_str(message: _ChatFormatter.MESSAGE_TYPE) -> str:
     def encode_dialog_prompt(
         self,
         dialog: _ChatFormatter.DIALOG_TYPE,
-        add_generation_prompt: bool = True, # UNUSED
+        add_generation_prompt: bool = True,  # UNUSED
     ) -> List[int]:
         new_turn = True
         tokens = []
@@ -190,11 +180,11 @@ def encode_dialog_prompt(
         return tokens


-
 class HFTokenizerChatFormatter(_ChatFormatter):
     """Chat formatter that uses the built-in formatting capabilities of an HF
     tokenizer instance
     """
+
     def encode_dialog_prompt(
         self,
         dialog: _ChatFormatter.DIALOG_TYPE,
@@ -206,8 +196,10 @@ def encode_dialog_prompt(
         logger.debug("Formatted chat prompt:\n%s", rendered)
         return self.tokenizer.encode(rendered)

+
 ## Generation ##################################################################

+
 @dataclass
 class GeneratorArgs:
     prompt: Optional[str] = (
@@ -264,7 +256,10 @@ def from_args(cls, args):
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
         sequential_prefill = (
-            args.sequential_prefill or bool(aoti_package_path) or bool(pte_path) or bool(dso_path)
+            args.sequential_prefill
+            or bool(aoti_package_path)
+            or bool(pte_path)
+            or bool(dso_path)
         )

         # Validate that all image prompts exist before expensive model load
@@ -322,9 +317,10 @@ def __init__(
         quantize: bool,
         draft_quantize: bool,
     ):
-        torch._inductor.config.coordinate_descent_tuning = (
-            builder_args.device not in ["cpu", "mps"]
-        )
+        torch._inductor.config.coordinate_descent_tuning = builder_args.device not in [
+            "cpu",
+            "mps",
+        ]
         torch._inductor.config.triton.unique_kernel_names = True
         torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

@@ -336,7 +332,7 @@ def __init__(
         self.draft_quantize = draft_quantize
         self.is_torchtune_model = generator_args.is_torchtune_model
         self.dtype = builder_args.precision
-        self.get_user_input : Callable = input
+        self.get_user_input: Callable = input

         self.rank: Optional[int] = None

@@ -455,6 +451,7 @@ def prefill(
         assert input_pos.size(0) == width

         if self.model.config.model_type == ModelType.Flamingo:
+            from torchtune.generation import sample as tune_sample
             assert batch is not None, "Flamingo requires batch"

             # TODO: Verify sequential prefill works with multimodal models
@@ -744,7 +741,9 @@ def generate(
                         decoder_max_seq_len=max_seq_length,
                     )
                 else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                    model.setup_caches(
+                        max_batch_size=1, max_seq_length=max_seq_length
+                    )
                 if is_speculative and draft_model is not model:
                     draft_model.setup_caches(
                         max_batch_size=1,
@@ -870,6 +869,13 @@ def _gen_model_input(
         max_new_tokens: Optional[int] = None,
         max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        # torchtune model definition dependencies
+        from torchtune.data import Message, padded_collate_tiled_images_and_mask
+        from torchtune.models.llama3_2_vision._model_builders import (
+            llama3_2_vision_transform,
+        )
+        from torchtune.training import set_default_dtype
+
         """
         Convert prompt and image prompts into consumable model input args.
@@ -890,7 +896,9 @@ def _gen_model_input(
         # Single String prompt
         if isinstance(prompt, str):
             encoded = self.encode_tokens(
-                prompt, bos=self.model.config.tokenizer_prepend_bos, device=self.builder_args.device
+                prompt,
+                bos=self.model.config.tokenizer_prepend_bos,
+                device=self.builder_args.device,
             )
         # List of dialog
         else:
@@ -983,7 +991,10 @@ def _gen_model_input(

         if image_found:
             batch = padded_collate_tiled_images_and_mask(
-                [data], pad_direction="left", pad_max_images=1, pad_max_tiles=transform.max_num_tiles
+                [data],
+                pad_direction="left",
+                pad_max_images=1,
+                pad_max_tiles=transform.max_num_tiles,
             )
             encoded = batch.pop("tokens").to(device).view(-1)
             seq_len = encoded.size(0)
@@ -1099,7 +1110,9 @@ def chat(
                 "Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"
             )
             if get_system_prompt == "y" or get_system_prompt == "Y":
-                self.system_prompt = self.get_user_input("What is your system prompt? \n")
+                self.system_prompt = self.get_user_input(
+                    "What is your system prompt? \n"
+                )

         # `is_torchtune_model` is a misnomer since it doesn't capture all
         # torchtune models (i.e. Flamingo)
@@ -1152,7 +1165,8 @@ def chat(
                 )
                 messages_to_encode.append({"role": "user", "content": prompt})
                 encoded = self.chat_formatter.encode_dialog_prompt(
-                    messages_to_encode, add_generation_prompt=True,
+                    messages_to_encode,
+                    add_generation_prompt=True,
                 )
                 encoded = torch.tensor(
                     encoded, dtype=torch.int, device=self.builder_args.device
@@ -1331,7 +1345,6 @@ def callback(x, *, done_generating=False):
            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")


-
 class DistributedGenerator(LocalGenerator):
     def __init__(
         self,
@@ -1342,10 +1355,12 @@ def __init__(
         profile: Optional[Path],
         quantize: bool,
         draft_quantize: bool,
-        ):
+    ):
         is_speculative = speculative_builder_args.checkpoint_path is not None

-        assert is_speculative == False, "Distributed inference with pp > 1 does not support speculative inference yet."
+        assert (
+            is_speculative == False
+        ), "Distributed inference with pp > 1 does not support speculative inference yet."
         super().__init__(
             builder_args,
             speculative_builder_args,
@@ -1373,7 +1388,9 @@ def distributed_input(prompt: str) -> str:

         if builder_args.pp > 1:
             self.seqlen_prefill = 1024  # sequence length for prefill stage
-            logger.warn(f"{color.yellow}Pipeline parallelism is still experimental and might be slow{color.reset}")
+            logger.warn(
+                f"{color.yellow}Pipeline parallelism is still experimental and might be slow{color.reset}"
+            )

             pp_mesh = self.model.device_mesh["pp"]
             self.pp_rank = pp_mesh.get_local_rank()
@@ -1385,9 +1402,12 @@ def distributed_input(prompt: str) -> str:
             self.first_pp_rank = 0
             self.last_pp_rank = self.pp_degree - 1

-
-            self.first_pp_rank_global_id = dist.get_global_rank(self.pp_group, self.first_pp_rank)
-            self.last_pp_rank_global_id = dist.get_global_rank(self.pp_group, self.last_pp_rank)
+            self.first_pp_rank_global_id = dist.get_global_rank(
+                self.pp_group, self.first_pp_rank
+            )
+            self.last_pp_rank_global_id = dist.get_global_rank(
+                self.pp_group, self.last_pp_rank
+            )

             self.prefiller = self.create_prefill_stage()
             self.decoder = self.create_decode_stage()
@@ -1396,7 +1416,9 @@ def __del__(self):
        dist.destroy_process_group()

    # Helper function to get example inputs and outputs for the stages.
-    def get_example_ins_outs(self, batch_size: int , seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def get_example_ins_outs(
+        self, batch_size: int, seqlen: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This function generates example inputs and outputs for the
        prefill and decode stages.
@@ -1408,10 +1430,18 @@ def get_example_ins_outs(self, batch_size: int , seqlen: int) -> Tuple[torch.Ten
            0, self.model.config.vocab_size, (batch_size, seqlen), device=self.device
        )
        activation = torch.rand(
-            batch_size, seqlen, self.model.config.dim, device=self.device, dtype=model_dtype
+            batch_size,
+            seqlen,
+            self.model.config.dim,
+            device=self.device,
+            dtype=model_dtype,
        )
        logits = torch.rand(
-            batch_size, seqlen, self.model.config.vocab_size, device=self.device, dtype=model_dtype
+            batch_size,
+            seqlen,
+            self.model.config.vocab_size,
+            device=self.device,
+            dtype=model_dtype,
        )
        example_inputs = (mb_ids if self.pp_rank == self.first_pp_rank else activation,)
        example_outputs = (logits if self.pp_rank == self.last_pp_rank else activation,)
@@ -1427,8 +1457,12 @@ def create_prefill_stage(self):
         batch_size = 1

         # Create prefill stage
-        logger.debug(f"Creating pipeline stage for prefill {self.pp_rank=}, {self.pp_degree=}")
-        example_inputs, example_outputs = self.get_example_ins_outs(batch_size, self.seqlen_prefill)
+        logger.debug(
+            f"Creating pipeline stage for prefill {self.pp_rank=}, {self.pp_degree=}"
+        )
+        example_inputs, example_outputs = self.get_example_ins_outs(
+            batch_size, self.seqlen_prefill
+        )
         prefill_stage = PipelineStage(
             self.model,
             self.pp_rank,
@@ -1460,7 +1494,9 @@ def create_decode_stage(self):

         # Create decode stage
         # logger.info(f"Creating pipeline stage for decode {self.pp_rank=}, {self.pp_degree=}")
-        example_inputs, example_outputs = self.get_example_ins_outs(batch_size, seqlen_decode)
+        example_inputs, example_outputs = self.get_example_ins_outs(
+            batch_size, seqlen_decode
+        )
         decode_stage = PipelineStage(
             self.model,
             self.pp_rank,
@@ -1501,18 +1537,25 @@ def prefill(
             **sampling_kwargs,
         )

-        pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else self.tokenizer.eos_id
+        pad_token_id = (
+            self.tokenizer.pad_id
+            if self.tokenizer.pad_id is not None
+            else self.tokenizer.eos_id
+        )
         prompt_length = x.size(1)
         padded_seq = torch.full(
-            (1, self.seqlen_prefill), pad_token_id, dtype=torch.int64, device=self.device
-            )
-        padded_seq[:,:prompt_length] = x
+            (1, self.seqlen_prefill),
+            pad_token_id,
+            dtype=torch.int64,
+            device=self.device,
+        )
+        padded_seq[:, :prompt_length] = x
         input_pos = torch.arange(
             self.seqlen_prefill,
             device=self.device,
             dtype=torch.int,
-            )
+        )

         # Prefill phase
         # Run context input through pipeline
@@ -1528,7 +1571,9 @@ def prefill(
             self.prefiller.step(**kwargs)

         if self.pp_rank == self.last_pp_rank:
-            new_token = self.sample(logits[:,:prompt_length], need_probs=False, **sampling_kwargs)[0]
+            new_token = self.sample(
+                logits[:, :prompt_length], need_probs=False, **sampling_kwargs
+            )[0]
             if self.pp_rank != self.first_pp_rank:
                 dist.send(
                     new_token,
@@ -1604,8 +1649,8 @@ def decode_one_token(
                 src=self.last_pp_rank_global_id,
                 group=self.pp_group,
             )
-        #TODO: Why do we get 2d tensor here?
-        new_token=new_token[0]
+        # TODO: Why do we get 2d tensor here?
+        new_token = new_token[0]
         return new_token, None

     def sample(
@@ -1623,10 +1668,8 @@ def sample(

         return idx_next, probs

-def run_generator(
-    args,
-    rank: Optional[int] =None
-    ):
+
+def run_generator(args, rank: Optional[int] = None):
     """
     This function creates and executes a generator
     """
@@ -1634,7 +1677,7 @@ def run_generator(
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
     generator_args = GeneratorArgs.from_args(args)
-    #Setup rank 1 and up to suppress log messages and print messages
+    # Setup rank 1 and up to suppress log messages and print messages
     if builder_args.distributed and rank != 0:
         logger.setLevel(logging.CRITICAL)
         context = contextlib.redirect_stdout(None)
@@ -1663,18 +1706,21 @@ def run_generator(
        for _ in gen.chat(generator_args):
            pass

+
 def main(args):
     builder_args = BuilderArgs.from_args(args)

     if builder_args.distributed:
         world_size = builder_args.tp * builder_args.pp

-        ctx = mp.get_context('spawn')
-        with futures.ProcessPoolExecutor(max_workers=world_size-1, mp_context=ctx) as executor:
-            for i in range(1,world_size):
+        ctx = mp.get_context("spawn")
+        with futures.ProcessPoolExecutor(
+            max_workers=world_size - 1, mp_context=ctx
+        ) as executor:
+            for i in range(1, world_size):
                 fn = partial(run_generator, args, i)
                 executor.submit(run_in_dist_env, world_size, i, fn)
-        #Starting rank 0
+        # Starting rank 0
         fn = partial(run_generator, args, 0)
         run_in_dist_env(world_size, 0, fn)
     else:

diff --git a/torchchat/model.py b/torchchat/model.py
index 9722ca240..4605aea33 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -8,13 +8,13 @@
 import os
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Hashable
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Union

-from collections.abc import Hashable

 import torch
 import torch.nn as nn

@@ -37,14 +37,6 @@
 except Exception:
     pass

-from torchtune.models.clip import clip_vision_encoder
-from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
-from torchtune.models.llama3_2_vision._component_builders import (
-    llama3_2_vision_decoder,
-    llama3_2_vision_encoder,
-)
-from torchtune.modules.model_fusion import DeepFusionModel
-
 from torchchat.utils.build_utils import find_multiple, get_precision

 config_path = Path(f"{str(Path(__file__).parent)}/model_params")
@@ -67,7 +59,6 @@ def identity(**kwargs):
     return list(kwargs.values())[0]


-
 class MultiModalProjector(nn.Module):
     def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
         super().__init__()
@@ -105,9 +96,9 @@ def __init__(
             self.decoder.__setattr__(token_embedding_name, None)

         self.mm_projector = MultiModalProjector(
-                in_channels=mm_proj_in_channels,
-                out_channels=mm_proj_out_channels,
-                act=mm_proj_activation,
+            in_channels=mm_proj_in_channels,
+            out_channels=mm_proj_out_channels,
+            act=mm_proj_activation,
         )

     def forward(
@@ -214,6 +205,10 @@ def _text_only(cls):

     @classmethod
     def _llama3_1(cls):
+        from torchtune.models.llama3_1._component_builders import (
+            llama3_1 as llama3_1_builder,
+        )
+
         return cls(
             model_type=ModelType.Llama3_1,
             modules={"text": llama3_1_builder},
@@ -222,23 +217,28 @@ def _llama3_1(cls):

     @classmethod
     def _flamingo(cls):
+        from torchtune.models.llama3_2_vision._component_builders import (
+            llama3_2_vision_decoder,
+            llama3_2_vision_encoder,
+        )
+        from torchtune.modules.model_fusion import DeepFusionModel
+
         return cls(
             model_type=ModelType.Flamingo,
             modules={
                 "encoder": llama3_2_vision_encoder,
-                "decoder": llama3_2_vision_decoder
+                "decoder": llama3_2_vision_decoder,
             },
             fusion_class=DeepFusionModel,
         )

     @classmethod
     def _llava(cls):
+        from torchtune.models.clip import clip_vision_encoder
+
         return cls(
             model_type=ModelType.Llava,
-            modules={
-                'encoder': clip_vision_encoder,
-                'decoder': Transformer
-            },
+            modules={"encoder": clip_vision_encoder, "decoder": Transformer},
             fusion_class=ConcateFusion,
         )
@@ -504,10 +504,17 @@ def build_model(self) -> nn.Module:

         # Temporary add extra params to the DeepFusionModel.
         # TODO: Remove it once we can make fusion model configurable in model_param.
-        if recipe.fusion_class == DeepFusionModel:
-            modules["encoder_trainable"] = False
-            modules["decoder_trainable"] = False
-            modules["fusion_trainable"] = False
+        try:
+            from torchtune.modules.model_fusion import DeepFusionModel
+
+            if recipe.fusion_class == DeepFusionModel:
+                modules["encoder_trainable"] = False
+                modules["decoder_trainable"] = False
+                modules["fusion_trainable"] = False
+        except ModuleNotFoundError:
+            # In case it is actually DeepFusionModel and torchtune is not installed,
+            # it will fail with an error further without unexpected behavior.
+            pass

         return recipe.fusion_class(**modules)

@@ -627,7 +634,12 @@ def forward(
         post_tokens: Optional[Tensor] = None,
         input_pos: Optional[Tensor] = None,
     ) -> Tensor:
-        return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos)
+        return self.model(
+            tokens,
+            encoder_input=encoder_input,
+            post_tokens=post_tokens,
+            input_pos=input_pos,
+        )

     def setup_caches(self, max_batch_size, max_seq_length):
         self.model.setup_caches(max_batch_size, max_seq_length)
@@ -678,7 +690,9 @@ def __init__(self, config: TransformerArgs) -> None:
     def load_hook(self, state_dict, prefix, *args):
         """Handle tied embeddings at load time"""
         if self.config.tie_word_embeddings:
-            state_dict.setdefault("model.output.weight", state_dict["model.tok_embeddings.weight"])
+            state_dict.setdefault(
+                "model.output.weight", state_dict["model.tok_embeddings.weight"]
+            )

     def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if (
@@ -727,7 +741,9 @@ def distribute(self, device_mesh: DeviceMesh):
             ColwiseParallel(output_layouts=Replicate()),
         )

-    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0) -> Tensor:
+    def forward(
+        self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 0
+    ) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
@@ -771,11 +787,24 @@ def distribute(self, device_mesh: DeviceMesh):
         self.feed_forward.distribute(device_mesh)

     def forward(
-        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0
+        self,
+        x: Tensor,
+        input_pos: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        cache_lane: int = 0,
     ) -> Tensor:
-        h = x + self.attention(
-            self.attention_norm(x), freqs_cis, mask, input_pos, cache_lane=cache_lane
-        ) * self.residual_multiplier
+        h = (
+            x
+            + self.attention(
+                self.attention_norm(x),
+                freqs_cis,
+                mask,
+                input_pos,
+                cache_lane=cache_lane,
+            )
+            * self.residual_multiplier
+        )
         out = h + self.feed_forward(self.ffn_norm(h)) * self.residual_multiplier
         return out

@@ -788,12 +817,18 @@ def __init__(self, config: TransformerArgs):
         # key, query, value projections for all heads, but in a batch
         # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
         # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias)
-        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias)
+        self.wq = nn.Linear(
+            config.dim, config.n_heads * config.head_dim, bias=config.attention_bias
+        )
         self.wk = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
+            config.dim,
+            config.n_local_heads * config.head_dim,
+            bias=config.attention_bias,
         )
         self.wv = nn.Linear(
-            config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
+            config.dim,
+            config.n_local_heads * config.head_dim,
+            bias=config.attention_bias,
         )
         self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias)

@@ -812,10 +847,12 @@ def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1):
         if hasattr(self, "tp_degree"):
             n_local_heads = self.n_local_heads // self.tp_degree

-        self.kv_cache = nn.ModuleList([
-            KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = nn.ModuleList(
+            [
+                KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim)
+                for _ in range(cache_lanes)
+            ]
+        )

     def load_hook(self, state_dict, prefix, *args):
         # if prefix + "wq.weight" in state_dict:
@@ -921,9 +958,15 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
-        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias)
-        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
+        self.w1 = nn.Linear(
+            config.dim, config.hidden_dim, bias=config.feed_forward_bias
+        )
+        self.w2 = nn.Linear(
+            config.hidden_dim, config.dim, bias=config.feed_forward_bias
+        )
+        self.w3 = nn.Linear(
+            config.dim, config.hidden_dim, bias=config.feed_forward_bias
+        )

     def distribute(self, device_mesh: DeviceMesh):
         parallelize_module(self.w1, device_mesh, ColwiseParallel())
@@ -1011,13 +1054,13 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:

 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 try:
+    # For llama::sdpa_with_kv_cache.out, preprocess ops
+    from executorch.extension.llm.custom_ops import custom_ops  # no-qa
     from executorch.extension.pybindings import portable_lib as exec_lib

     # ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately.
     # For quantized_decomposed ops
     from executorch.kernels import quantized  # no-qa
-    # For llama::sdpa_with_kv_cache.out, preprocess ops
-    from executorch.extension.llm.custom_ops import custom_ops  # no-qa

     class PTEModel(nn.Module):
         def __init__(self, config, path) -> None:
@@ -1025,7 +1068,9 @@ def __init__(self, config, path) -> None:
             self.config = config
             self.model_ = exec_lib._load_for_executorch(str(path))

-            self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"])
+            self.text_transformer_args = TransformerArgs.from_params(
+                self.config.transformer_args["text"]
+            )
             # TODO: attempt to use "get_max_seq_len" method on the model after
             # ExecuTorch bug is fixed.
            max_seq_len = 128

diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py
index 0d1d3dce7..42c885b9e 100644
--- a/torchchat/usages/openai_api.py
+++ b/torchchat/usages/openai_api.py
@@ -13,18 +13,14 @@
 from dataclasses import dataclass
 from io import BytesIO
 from pwd import getpwuid
-from typing import Any, Dict, List, Optional, Union, Type
+from typing import Any, Dict, List, Optional, Type, Union

 import torch

 from PIL import Image

-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
 from torchchat.cli.download import is_model_downloaded, load_model_configs
-from torchchat.generate import LocalGenerator, DistributedGenerator, GeneratorArgs
+from torchchat.generate import DistributedGenerator, GeneratorArgs, LocalGenerator
 from torchchat.model import FlamingoModel
 from torchchat.utils.build_utils import device_sync

@@ -304,7 +300,7 @@ def __init__(self, *args, **kwargs):

     def _gen_model_inputs_from_openai_completion_request(
         self, completion_request: CompletionRequest
-    ) -> List[Message]:
+    ) -> List:
         """Generate model inputs from an OpenAI completion request.

         Args:
@@ -392,7 +388,7 @@ def callback(x, *, done_generating=False):
         device_sync(device=self.builder_args.device)

         buffer = []
-        ILLEGAL_CHAR = '\ufffd'
+        ILLEGAL_CHAR = "\ufffd"
         # Process each token, metrics tuple yielded by Generator.generate.
         for y, _ in self.generate(
             model=self.model,
@@ -495,7 +491,14 @@ def create_openai_api_generator(distributed: bool) -> Type:
     """

     # Base class order matters to make sure OpenAiApiGeneratorMixin overrides methods in DistributedGenerator and Generator
-    return type('OpenAiApiGenerator', (OpenAiApiGeneratorMixin, DistributedGenerator if distributed else LocalGenerator), {})
+    return type(
+        "OpenAiApiGenerator",
+        (
+            OpenAiApiGeneratorMixin,
+            DistributedGenerator if distributed else LocalGenerator,
+        ),
+        {},
+    )
"""