From f4bf00bb435717cb2d95fde37243bd4e1c6046db Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sun, 15 Sep 2024 11:24:25 -0700 Subject: [PATCH 01/27] llava init --- torchchat/model.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/torchchat/model.py b/torchchat/model.py index 2cba9a032..85ebdb81d 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -41,10 +41,41 @@ def identity(**kwargs): raise ValueError("Only one argument is expected") return list(kwargs.values())[0] + +class ConcateFusion(nn.Module): + def __init__(self, encoder: nn.Module, decoder: nn.Module): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, + tokens: Tensor, + *, + post_tokens: Optional[Tensor] = None, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None,) -> Tensor: + # split prompt from img tag into before img and after img + # concate before img, image result and after img into a large prompt + # forward that to text transformer + # resturn result + + if encoder_input: + encoder_output = self.encoder( + encoder_input, + ) + + + + + + class ModelType(Enum): TextOnly = "text_only" Llama3_1 = "llama3_1" Flamingo = "flamingo" + Llava = "llava" # Type for objects that can generate nn.Module instance ModuleLike = Union[nn.Module, Callable[..., nn.Module]] @@ -99,6 +130,17 @@ def _flamingo(cls): fusion_class=DeepFusionModel, ) + @classmethod + def _llava(cls): + return cls( + model_type=ModelType.Llava, + modules={ + 'te': flamingo_vision_encoder, + 'decoder': llama3_1_builder + }, + fusion_class=DeepFusionModel, + ) + @classmethod def get_recipe(cls, model_type): if model_type == ModelType.TextOnly: @@ -107,6 +149,8 @@ def get_recipe(cls, model_type): return cls._flamingo() elif model_type == ModelType.Llama3_1: return cls._llama3_1() + elif model_type == ModelType.Llava: + return cls._llava() else: raise ValueError(f"Can not find the model recipe for {model_type}") From 2cabbe76d0fe83831299a9bf8de4c29d268e990a Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sun, 15 Sep 2024 13:08:31 -0700 Subject: [PATCH 02/27] 2/n llava init --- torchchat/model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchchat/model.py b/torchchat/model.py index 85ebdb81d..af6c6207e 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -65,8 +65,13 @@ def forward(self, encoder_output = self.encoder( encoder_input, ) + + def _gen_mm_embedding(self, tokens: Tensor, *, encoder_input: Optional[Tensor], post_tokens: Optional[Tensor]): + assert bool(encoder_input) == bool(post_tokens), "encoder_input and post_tokens must be both None or not None" + if encoder_input is None: + return tokens - + From 353fafe89dca51bbc5b92dc9d4268843ae4fc233 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sun, 15 Sep 2024 17:38:33 -0700 Subject: [PATCH 03/27] 3/n llava init --- torchchat/model.py | 46 ++++++++++++++++++++------- torchchat/model_params/llava-1.5.json | 23 ++++++++++++++ 2 files changed, 58 insertions(+), 11 deletions(-) create mode 100644 torchchat/model_params/llava-1.5.json diff --git a/torchchat/model.py b/torchchat/model.py index af6c6207e..3904fe632 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -42,40 +42,64 @@ def identity(**kwargs): return list(kwargs.values())[0] +class MultiModalProjector(nn.Module): + def __init__(self, args: ProjectorArgs): + super().__init__() + + self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True) + self.act = args.activation + self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + class ConcateFusion(nn.Module): - def __init__(self, encoder: nn.Module, decoder: nn.Module): + def __init__(self, encoder: nn.Module, decoder: nn.Module, token_embedding_name="tok_embeddings", mm_proj_in_channels=1024, mm_proj_out_channels=4096, mm_proj_activation=nn.GELU): super().__init__() self.encoder = encoder self.decoder = decoder + # esclate the embedding layer outside decoder llava model need to fuse + # the text and image embedding together before passing to decoder. + self.tok_embeddings = getattr(self.decoder, token_embedding_name) + + # set the embedding layer in decoder to None to jump the embedding layer over in decoder + self.decoder.__setattr__(token_embedding_name) = None + + self.mm_projector = MultiModalProjector(ProjectorArgs(in_channels=mm_proj_in_channels, out_channels=mm_proj_out_channels, activation=mm_proj_activation)) + def forward(self, tokens: Tensor, *, post_tokens: Optional[Tensor] = None, - mask: Optional[torch.Tensor] = None, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None,) -> Tensor: - # split prompt from img tag into before img and after img - # concate before img, image result and after img into a large prompt - # forward that to text transformer - # resturn result - if encoder_input: encoder_output = self.encoder( encoder_input, ) + else: + encoder_output = None + + decoder_input = self._get_decoder_input(tokens, encoder_input=encoder_input, post_tokens=post_tokens) + return self.decoder(decoder_input) - def _gen_mm_embedding(self, tokens: Tensor, *, encoder_input: Optional[Tensor], post_tokens: Optional[Tensor]): + def _get_decoder_input(self, tokens: Tensor, *, encoder_input: Optional[Tensor], post_tokens: Optional[Tensor]): assert bool(encoder_input) == bool(post_tokens), "encoder_input and post_tokens must be both None or not None" if encoder_input is None: - return tokens + return self.tok_embeddings(tokens) + else: + pre_img_embed = self.tok_embeddings(tokens) + post_img_embed = self.tok_embeddings(post_tokens) + return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1) - - class ModelType(Enum): TextOnly = "text_only" Llama3_1 = "llama3_1" diff --git a/torchchat/model_params/llava-1.5.json b/torchchat/model_params/llava-1.5.json new file mode 100644 index 000000000..6629b4dab --- /dev/null +++ b/torchchat/model_params/llava-1.5.json @@ -0,0 +1,23 @@ +@dataclass +class VisionArgs: + tile_size: int = 336 + patch_size: int = 14 + embed_dim: int = 1024 + num_layers: int = 24 + num_heads: int = 16 + out_indices: List[int] = field(default_factory=list) + output_cls_projection: bool = False + max_num_tiles: int = 1 + in_channels: int = 3 + intermediate_act: nn.Module = QuickGELUActivation() + + def __post_init__(self): + if not self.out_indices: + self.out_indices = [self.num_layers - 1] + + +@dataclass +class ProjectorArgs: + in_channels: int = 1024 + out_channels: int = 4096 + activation: nn.Module = nn.GELU() From 728fc462d94c2c24170872f2690c19cdeeac615f Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sun, 15 Sep 2024 17:42:44 -0700 Subject: [PATCH 04/27] reformat llava --- torchchat/model.py | 58 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index 3904fe632..5c993e67a 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -1,24 +1,19 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. import json import os import warnings +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, Optional, Union -from abc import ABC, abstractmethod import torch import torch.nn as nn from torch import Tensor -from torch.distributed._tensor import Replicate, Shard, DTensor +from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -31,11 +26,12 @@ from torchchat.utils.build_utils import find_multiple, get_precision from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder -from torchtune.modules.model_fusion import DeepFusionModel from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder +from torchtune.modules.model_fusion import DeepFusionModel config_path = Path(f"{str(Path(__file__).parent)}/model_params") + def identity(**kwargs): if len(kwargs) != 1: raise ValueError("Only one argument is expected") @@ -56,8 +52,17 @@ def forward(self, image_features): hidden_states = self.linear_2(hidden_states) return hidden_states + class ConcateFusion(nn.Module): - def __init__(self, encoder: nn.Module, decoder: nn.Module, token_embedding_name="tok_embeddings", mm_proj_in_channels=1024, mm_proj_out_channels=4096, mm_proj_activation=nn.GELU): + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + token_embedding_name="tok_embeddings", + mm_proj_in_channels=1024, + mm_proj_out_channels=4096, + mm_proj_activation=nn.GELU, + ): super().__init__() self.encoder = encoder self.decoder = decoder @@ -67,36 +72,53 @@ def __init__(self, encoder: nn.Module, decoder: nn.Module, token_embedding_name= self.tok_embeddings = getattr(self.decoder, token_embedding_name) # set the embedding layer in decoder to None to jump the embedding layer over in decoder - self.decoder.__setattr__(token_embedding_name) = None + self.decoder.__setattr__(token_embedding_name, None) - self.mm_projector = MultiModalProjector(ProjectorArgs(in_channels=mm_proj_in_channels, out_channels=mm_proj_out_channels, activation=mm_proj_activation)) + self.mm_projector = MultiModalProjector( + ProjectorArgs( + in_channels=mm_proj_in_channels, + out_channels=mm_proj_out_channels, + activation=mm_proj_activation, + ) + ) - def forward(self, + def forward( + self, tokens: Tensor, *, post_tokens: Optional[Tensor] = None, encoder_input: Optional[Tensor] = None, encoder_mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None,) -> Tensor: + input_pos: Optional[torch.Tensor] = None, + ) -> Tensor: if encoder_input: encoder_output = self.encoder( encoder_input, ) else: encoder_output = None - - decoder_input = self._get_decoder_input(tokens, encoder_input=encoder_input, post_tokens=post_tokens) + + decoder_input = self._get_decoder_input( + tokens, encoder_input=encoder_input, post_tokens=post_tokens + ) return self.decoder(decoder_input) - def _get_decoder_input(self, tokens: Tensor, *, encoder_input: Optional[Tensor], post_tokens: Optional[Tensor]): - assert bool(encoder_input) == bool(post_tokens), "encoder_input and post_tokens must be both None or not None" + def _get_decoder_input( + self, + tokens: Tensor, + *, + encoder_input: Optional[Tensor], + post_tokens: Optional[Tensor], + ): + assert bool(encoder_input) == bool( + post_tokens + ), "encoder_input and post_tokens must be both None or not None" if encoder_input is None: return self.tok_embeddings(tokens) else: pre_img_embed = self.tok_embeddings(tokens) post_img_embed = self.tok_embeddings(post_tokens) return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1) - From 215331d2e6c7d2ef64efe084b8d4b3c329c663be Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 01:40:24 -0700 Subject: [PATCH 05/27] 3/n llava --- torchchat/model.py | 31 ++++++++++++++---- torchchat/model_params/llava-1.5.json | 47 ++++++++++++++------------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index 5c993e67a..7c89b53b2 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -28,10 +28,20 @@ from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder from torchtune.modules.model_fusion import DeepFusionModel +from torchtune.models.clip import clip_vision_encoder config_path = Path(f"{str(Path(__file__).parent)}/model_params") +class QuickGELUActivation(nn.Module): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + + def forward(self, input): + return input * torch.sigmoid(1.702 * input) + + def identity(**kwargs): if len(kwargs) != 1: raise ValueError("Only one argument is expected") @@ -99,7 +109,7 @@ def forward( encoder_output = None decoder_input = self._get_decoder_input( - tokens, encoder_input=encoder_input, post_tokens=post_tokens + tokens, encoder_output=encoder_output, post_tokens=post_tokens ) return self.decoder(decoder_input) @@ -107,16 +117,17 @@ def _get_decoder_input( self, tokens: Tensor, *, - encoder_input: Optional[Tensor], + encoder_output: Optional[Tensor], post_tokens: Optional[Tensor], ): - assert bool(encoder_input) == bool( + assert bool(encoder_output) == bool( post_tokens ), "encoder_input and post_tokens must be both None or not None" - if encoder_input is None: + if encoder_output is None: return self.tok_embeddings(tokens) else: pre_img_embed = self.tok_embeddings(tokens) + image_embeds = self.mm_projector(encoder_output) post_img_embed = self.tok_embeddings(post_tokens) return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1) @@ -261,7 +272,7 @@ class ModelArgs: def __init__( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Union[TransformerArgs, Dict[str, Dict[str, Any]]], model_type: ModelType = ModelType.TextOnly, ) -> None: self._sanity_check(transformer_args, model_type) @@ -275,7 +286,7 @@ def __init__( def _sanity_check( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Union[TransformerArgs, Dict[str, Dict[str, Any]]], model_type: ModelType, ) -> None: assert isinstance(model_type, ModelType) @@ -393,12 +404,20 @@ def build_model(self) -> nn.Module: modules = {} for name, module_class in recipe.modules.items(): if isinstance(config_args := self.config.transformer_args[name], dict): + config_args = self._replace_know_params(config_args) modules[name] = module_class(**config_args) else: modules[name] = module_class(config_args) return recipe.fusion_class(**modules) + def _replace_know_params(self, params): + patterns = {"QuickGELUActivation()": QuickGELUActivation(), "False": False, "True": True} + for key, value in params.items(): + if value in patterns: + params[key] = patterns[value] + return params + @abstractmethod def forward(self, *args, **kwargs): raise NotImplementedError("forward method is not implemented") diff --git a/torchchat/model_params/llava-1.5.json b/torchchat/model_params/llava-1.5.json index 6629b4dab..888f690a9 100644 --- a/torchchat/model_params/llava-1.5.json +++ b/torchchat/model_params/llava-1.5.json @@ -1,23 +1,24 @@ -@dataclass -class VisionArgs: - tile_size: int = 336 - patch_size: int = 14 - embed_dim: int = 1024 - num_layers: int = 24 - num_heads: int = 16 - out_indices: List[int] = field(default_factory=list) - output_cls_projection: bool = False - max_num_tiles: int = 1 - in_channels: int = 3 - intermediate_act: nn.Module = QuickGELUActivation() - - def __post_init__(self): - if not self.out_indices: - self.out_indices = [self.num_layers - 1] - - -@dataclass -class ProjectorArgs: - in_channels: int = 1024 - out_channels: int = 4096 - activation: nn.Module = nn.GELU() +{ + "model_type": "llava", + "encoder": { + "tile_size": 336, + "patch_size": 14, + "embed_dim": 1024, + "num_layers": 24, + "num_heads": 16, + "out_indices": [ + 23 + ], + "output_cls_projection": False, + "max_num_tiles": 1, + "in_channels": 3, + "intermediate_act": QuickGELUActivation() + }, + "decoder": { + "n_layers": 32, + "n_heads": 32, + "dim": 4096, + "vocab_size": 32064, + "max_seq_length": 768 + } +} From 23d65043bbacbb3f7a093bf638d04d10761e8263 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 01:40:55 -0700 Subject: [PATCH 06/27] llava config update --- torchchat/model_params/llava-1.5.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchchat/model_params/llava-1.5.json b/torchchat/model_params/llava-1.5.json index 888f690a9..685904d9d 100644 --- a/torchchat/model_params/llava-1.5.json +++ b/torchchat/model_params/llava-1.5.json @@ -9,10 +9,10 @@ "out_indices": [ 23 ], - "output_cls_projection": False, + "output_cls_projection": "False", "max_num_tiles": 1, "in_channels": 3, - "intermediate_act": QuickGELUActivation() + "intermediate_act": "QuickGELUActivation()" }, "decoder": { "n_layers": 32, From 22fd2a5f0ab70dde5ef76f57d3cd9a23d8ca4e4b Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 10:47:35 -0700 Subject: [PATCH 07/27] 4/n llava init --- torchchat/model.py | 87 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 4 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index 7c89b53b2..e3af197fa 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -6,8 +6,12 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path +from PIL import Image +import requests +import torchvision from typing import Any, Callable, Dict, Optional, Union +from collections.abc import Hashable import torch import torch.nn as nn @@ -48,6 +52,13 @@ def identity(**kwargs): return list(kwargs.values())[0] +@dataclass +class ProjectorArgs: + in_channels: int = 1024 + out_channels: int = 4096 + activation: nn.Module = nn.GELU() + + class MultiModalProjector(nn.Module): def __init__(self, args: ProjectorArgs): super().__init__() @@ -105,13 +116,22 @@ def forward( encoder_output = self.encoder( encoder_input, ) + encoder_output = self._encoder_feature_select(encoder_output) else: encoder_output = None decoder_input = self._get_decoder_input( tokens, encoder_output=encoder_output, post_tokens=post_tokens ) - return self.decoder(decoder_input) + return self.decoder(decoder_input, input_pos=input_pos) + + def _encoder_feature_select(self, encoder_output): + selected_image_feature = encoder_output[1][0].view( + *encoder_output[1][0].shape[2:] + ) + + selected_image_feature = selected_image_feature[:, 1:] + return selected_image_feature def _get_decoder_input( self, @@ -197,8 +217,8 @@ def _llava(cls): return cls( model_type=ModelType.Llava, modules={ - 'te': flamingo_vision_encoder, - 'decoder': llama3_1_builder + 'encoder': clip_vision_encoder, + 'decoder': Transformer }, fusion_class=DeepFusionModel, ) @@ -414,7 +434,7 @@ def build_model(self) -> nn.Module: def _replace_know_params(self, params): patterns = {"QuickGELUActivation()": QuickGELUActivation(), "False": False, "True": True} for key, value in params.items(): - if value in patterns: + if isinstance(value, Hashable) and value in patterns: params[key] = patterns[value] return params @@ -496,10 +516,26 @@ def reset_caches(self): self.model.reset_caches() +class LlavaModel(Model): + def forward( + self, + tokens: Tensor, + *, + encoder_input: Optional[Dict[str, Tensor]] = None, + post_tokens: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos) + + def setup_caches(self, max_batch_size, max_seq_length): + self.model.setup_caches(max_batch_size, max_seq_length) + + MODEL_TYPE_TO_CLASS = { ModelType.TextOnly: TextOnlyModel, ModelType.Flamingo: FlamingoModel, ModelType.Llama3_1: Llama31Model, + ModelType.Llava: LlavaModel, } class Transformer(nn.Module): @@ -882,3 +918,46 @@ def setup_caches(self, max_batch_size, max_seq_length): except: pass + + +if __name__ == "__main__": + def prepare_image(target_h: int, target_w: int) -> torch.Tensor: + """Read image into a tensor and resize the image so that it fits in + a target_h x target_w canvas. + + Args: + image (Image): An Image object. + target_h (int): Target height. + target_w (int): Target width. + + Returns: + torch.Tensor: resized image tensor. + """ + image = Image.open( + requests.get( + "https://llava-vl.github.io/static/images/view.jpg", stream=True + ).raw) + + img = torchvision.transforms.functional.pil_to_tensor(image) + # height ratio + ratio_h = img.shape[1] / target_h + # width ratio + ratio_w = img.shape[2] / target_w + # resize the image so that it fits in a target_h x target_w canvas + ratio = max(ratio_h, ratio_w) + output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) + img = torchvision.transforms.Resize(size=output_size)(img) + return img + + pre_tokens = torch.tensor([[ 1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, + 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, + 322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155, + 29889, 3148, 1001, 29901, 29871]]) + img = prepare_image(336, 336) + post_tokens = torch.tensor([[29871, 13, 462, 9651, 1724, 526, 278, 2712, 306, 881, + 367, 274, 1300, 2738, 1048, 746, 306, 6493, 1244, 29973, + 319, 1799, 9047, 13566, 29901]]) + + llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json") + + llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens) From fff864782889107ac574fed929844b71aa14cdfb Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 16:52:11 -0700 Subject: [PATCH 08/27] unify model construction ppl --- torchchat/cli/builder.py | 7 +-- torchchat/generate.py | 38 +++++-------- torchchat/model.py | 53 +++++++++---------- .../model_params/Meta-Llama-3.1-70B-Tune.json | 1 + .../model_params/Meta-Llama-3.1-8B-Tune.json | 1 + 5 files changed, 43 insertions(+), 57 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index f600fe7f2..8ccdf6b57 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -240,12 +240,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - text_args = model.config.transformer_args.get("text") - if text_args is None: - # TODO: Will be refactored: Currently, the only model that doesn't have text in transfomer_args is Flamingo - use_tiktoken = model.config.model_type == ModelType.Flamingo - else: - use_tiktoken = text_args.use_tiktoken + use_tiktoken = model.config.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( diff --git a/torchchat/generate.py b/torchchat/generate.py index cea9531e8..20b35c6be 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -27,12 +27,6 @@ from PIL import Image -# torchtune model definition dependencies -from torchtune.data import Message -from torchtune.generation._generation import sample as tune_sample -from torchtune.models.llama3 import llama3_tokenizer -from torchtune.training import set_default_dtype - from torchchat.cli.builder import ( _initialize_model, _initialize_tokenizer, @@ -43,6 +37,12 @@ from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info +# torchtune model definition dependencies +from torchtune.data import Message +from torchtune.generation._generation import sample as tune_sample +from torchtune.models.llama3 import llama3_tokenizer +from torchtune.training import set_default_dtype + class _ChatFormatter(ABC): def __init__(self, tokenizer): @@ -790,16 +790,12 @@ def chat( # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong # TODO: unify the max_seq_length config representation. - if generator_args.is_torchtune_model: - max_seq_length = self.model.config.transformer_args.get("text", {}).get( - "max_seq_len", 2048 - ) - elif generator_args.chat_mode: - if ( - max_seq_length := self.model.config.transformer_args.get("text", None) - is None - ): - max_seq_length = 2048 + text_transformer_args = getattr(self.model.model, "config", None) + max_seq_length = ( + text_transformer_args.max_seq_length if text_transformer_args else 2048 + ) + + if generator_args.chat_mode: print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -809,15 +805,9 @@ def chat( if get_system_prompt == "y" or get_system_prompt == "Y": self.system_prompt = input("What is your system prompt? \n") - else: - text_transformer_args = self.model.config.transformer_args.get("text", None) + elif not generator_args.is_torchtune_model: max_seq_length = min( - encoded.size(0) + generator_args.max_new_tokens, - ( - text_transformer_args.block_size - if text_transformer_args is not None - else 2048 - ), + encoded.size(0) + generator_args.max_new_tokens, max_seq_length ) max_seq_length = ( diff --git a/torchchat/model.py b/torchchat/model.py index 10c036a36..7f5082da7 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -164,49 +164,49 @@ def from_params(cls, params): @dataclass class ModelArgs: model_type: ModelType - transformer_args: Dict[str, Union[Dict, TransformerArgs]] + transformer_args: Dict[str, Dict[str, Any]] + use_tiktoken: bool def __init__( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType = ModelType.TextOnly, + use_tiktoken: bool = False, ) -> None: self._sanity_check(transformer_args, model_type) self.model_type = model_type - if isinstance(transformer_args, TransformerArgs): - assert model_type == ModelType.TextOnly - self.transformer_args = {"text": transformer_args} - else: - self.transformer_args = transformer_args + self.transformer_args = transformer_args + + # Model-level attributes + self.use_tiktoken = use_tiktoken def _sanity_check( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType, ) -> None: - assert isinstance(model_type, ModelType) - assert isinstance(transformer_args, (TransformerArgs, dict)) + assert isinstance(model_type, ModelType), model_type + assert isinstance(transformer_args, dict) @classmethod def from_params(cls, params_path): with open(params_path, "r") as f: loaded_params = json.loads(f.read()) - - try: - # try to interpret as a single transformer config - transformer_args: Dict[str, TransformerArgs] = {} - transformer_args["text"] = TransformerArgs.from_params(loaded_params) - if (model_type := loaded_params.get("model_type", None)) is None: - model_type = ModelType.TextOnly - - except TypeError: - # try to interpret as a dict of transformer configs - model_type = ModelType(loaded_params["model_type"]) + + if (model_type_name := loaded_params.get("model_type", None)) is None: + # The model params is in the transformer_args format + # set the model_type to TextOnly and reformat the params + model_type = ModelType.TextOnly + transformer_args = {"text": {"config": loaded_params}} + else: + model_type = ModelType(model_type_name) transformer_args = { k: v for k, v in loaded_params.items() if k != "model_type" } - return cls(transformer_args, model_type) + + use_tiktoken = loaded_params.get("use_tiktoken", False) + return cls(transformer_args, model_type, use_tiktoken) @classmethod def from_table(cls, name: str): @@ -304,10 +304,8 @@ def build_model(self) -> nn.Module: recipe = ModelRecipe.get_recipe(self.config.model_type) modules = {} for name, module_class in recipe.modules.items(): - if isinstance(config_args := self.config.transformer_args[name], dict): - modules[name] = module_class(**config_args) - else: - modules[name] = module_class(config_args) + config_args = self.config.transformer_args[name] + modules[name] = module_class(**config_args) return recipe.fusion_class(**modules) @@ -399,8 +397,9 @@ def reset_caches(self): class Transformer(nn.Module): - def __init__(self, config: TransformerArgs) -> None: + def __init__(self, config: Dict[str, Any]) -> None: super().__init__() + config = TransformerArgs.from_params(config) self.config = config layers_per_stage = config.n_layers // config.n_stages diff --git a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json index c59961c63..3c611a753 100644 --- a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json +++ b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json @@ -1,5 +1,6 @@ { "model_type": "llama3_1", + "use_tiktoken": true, "text": { "vocab_size": 128256, "num_layers": 80, diff --git a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json index e9ded77bd..adc9e4e8e 100644 --- a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json +++ b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json @@ -1,5 +1,6 @@ { "model_type": "llama3_1", + "use_tiktoken": true, "text": { "vocab_size": 128256, "num_layers": 32, From 4b666a77de305e671c6c3d2b931a8682b2b0286f Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 16:59:29 -0700 Subject: [PATCH 09/27] update transformer config --- distributed/parallelize_llama.py | 2 +- torchchat/cli/builder.py | 2 +- torchchat/export.py | 2 +- torchchat/usages/eval.py | 2 +- torchchat/usages/openai_api.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index 5f358865d..9f1338b9d 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -62,7 +62,7 @@ def apply_tp( # after we apply TP to the model. Because we don't want to change model code # when applying TP. We need to have change to ensure KVCache has the correct # size as k and v. - model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size() + model.model.config.n_local_heads = model.model.config.n_local_heads // tp_mesh.size() # Apply tensor parallelism to every transformer block for transformer_block in model.layers: diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 8ccdf6b57..da659d236 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -563,7 +563,7 @@ def _initialize_model( model.setup_caches( max_batch_size=1, max_seq_length=max_seq_length - or model.config.transformer_args["text"].max_seq_length, + or model.model.config.max_seq_length, ) model.to(dtype=builder_args.precision) diff --git a/torchchat/export.py b/torchchat/export.py index efc791dc8..d5dbb8fd6 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -54,7 +54,7 @@ def export_for_server( torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) - seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length) + seq = Dim("seq", min=1, max=model.model.config.max_seq_length) # Specify that the first dimension of each input is that batch size dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}} else: diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index f8ac6fbe1..edf439c66 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( T = prompt.size(0) T_new = T + max_new_tokens if max_seq_length is None: - max_seq_length = min(T_new, model.config.transformer_args["text"].block_size) + max_seq_length = min(T_new, model.model.config.block_size) device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index 8f4367453..448777c24 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -233,11 +233,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_seq_length = ( - self.model.config.transformer_args["text"].max_seq_length + self.model.model.config.max_seq_length + self.speculative_builder_args.speculate_k + 1 if self.draft_model is not None - else self.model.config.transformer_args["text"].max_seq_length + else self.model.model.config.max_seq_length ) # The System fingerprint is a unique identifier for the model and its configuration. self.system_fingerprint = ( From cc8b4d654d865cdee1508590d83ad953418323f4 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 17:14:30 -0700 Subject: [PATCH 10/27] update model config for gguf --- torchchat/utils/gguf_loader.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/torchchat/utils/gguf_loader.py b/torchchat/utils/gguf_loader.py index c7b931dae..8fdadf5bf 100644 --- a/torchchat/utils/gguf_loader.py +++ b/torchchat/utils/gguf_loader.py @@ -542,15 +542,19 @@ def load_model(gguf_file: str) -> torch.nn.Module: assert arch == "llama", "Only LLaMa models are supported by this converter." model_args = ModelArgs( - TransformerArgs( - dim=metadata[f"{arch}.embedding_length"], - n_layers=metadata[f"{arch}.block_count"], - n_heads=metadata[f"{arch}.attention.head_count"], - n_local_heads=metadata[f"{arch}.attention.head_count_kv"], - vocab_size=len(metadata["tokenizer.ggml.tokens"]), - norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"], - hidden_dim=metadata[f"{arch}.feed_forward_length"], - ) + { + "text": { + "config": { + "dim": metadata[f"{arch}.embedding_length"], + "n_layers": metadata[f"{arch}.block_count"], + "n_heads": metadata[f"{arch}.attention.head_count"], + "n_local_heads": metadata[f"{arch}.attention.head_count_kv"], + "vocab_size": len(metadata["tokenizer.ggml.tokens"]), + "norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"], + "hidden_dim": metadata[f"{arch}.feed_forward_length"], + } + } + } ) # TODO: what to do with rope args like From 7ec018aca4671de2710a80f50140811988f9993d Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 17:32:08 -0700 Subject: [PATCH 11/27] hack PTEModel to have same config hirearchy as Model --- torchchat/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchchat/model.py b/torchchat/model.py index 7f5082da7..680499878 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -779,6 +779,10 @@ def __init__(self, config, path) -> None: super().__init__() self.config = config self.model_ = exec_lib._load_for_executorch(str(path)) + + # A hacky way to get the model config from the self.model, making it consistent with Model class + # TODO: remove the hacky way once get rid of model.model + self.model = type('model', (), {'config': self.config}) def forward(self, x, input_pos): # model_.forward expects inputs to be wrapped in a tuple From 94e56f1cfe956ea4163d8cca1157acb1873633d8 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 16:52:11 -0700 Subject: [PATCH 12/27] unify model construction ppl --- torchchat/cli/builder.py | 7 +-- torchchat/generate.py | 38 +++++-------- torchchat/model.py | 53 +++++++++---------- .../model_params/Meta-Llama-3.1-70B-Tune.json | 1 + .../model_params/Meta-Llama-3.1-8B-Tune.json | 1 + 5 files changed, 43 insertions(+), 57 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index f600fe7f2..8ccdf6b57 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -240,12 +240,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - text_args = model.config.transformer_args.get("text") - if text_args is None: - # TODO: Will be refactored: Currently, the only model that doesn't have text in transfomer_args is Flamingo - use_tiktoken = model.config.model_type == ModelType.Flamingo - else: - use_tiktoken = text_args.use_tiktoken + use_tiktoken = model.config.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( diff --git a/torchchat/generate.py b/torchchat/generate.py index cea9531e8..20b35c6be 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -27,12 +27,6 @@ from PIL import Image -# torchtune model definition dependencies -from torchtune.data import Message -from torchtune.generation._generation import sample as tune_sample -from torchtune.models.llama3 import llama3_tokenizer -from torchtune.training import set_default_dtype - from torchchat.cli.builder import ( _initialize_model, _initialize_tokenizer, @@ -43,6 +37,12 @@ from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info +# torchtune model definition dependencies +from torchtune.data import Message +from torchtune.generation._generation import sample as tune_sample +from torchtune.models.llama3 import llama3_tokenizer +from torchtune.training import set_default_dtype + class _ChatFormatter(ABC): def __init__(self, tokenizer): @@ -790,16 +790,12 @@ def chat( # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong # TODO: unify the max_seq_length config representation. - if generator_args.is_torchtune_model: - max_seq_length = self.model.config.transformer_args.get("text", {}).get( - "max_seq_len", 2048 - ) - elif generator_args.chat_mode: - if ( - max_seq_length := self.model.config.transformer_args.get("text", None) - is None - ): - max_seq_length = 2048 + text_transformer_args = getattr(self.model.model, "config", None) + max_seq_length = ( + text_transformer_args.max_seq_length if text_transformer_args else 2048 + ) + + if generator_args.chat_mode: print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -809,15 +805,9 @@ def chat( if get_system_prompt == "y" or get_system_prompt == "Y": self.system_prompt = input("What is your system prompt? \n") - else: - text_transformer_args = self.model.config.transformer_args.get("text", None) + elif not generator_args.is_torchtune_model: max_seq_length = min( - encoded.size(0) + generator_args.max_new_tokens, - ( - text_transformer_args.block_size - if text_transformer_args is not None - else 2048 - ), + encoded.size(0) + generator_args.max_new_tokens, max_seq_length ) max_seq_length = ( diff --git a/torchchat/model.py b/torchchat/model.py index 10c036a36..7f5082da7 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -164,49 +164,49 @@ def from_params(cls, params): @dataclass class ModelArgs: model_type: ModelType - transformer_args: Dict[str, Union[Dict, TransformerArgs]] + transformer_args: Dict[str, Dict[str, Any]] + use_tiktoken: bool def __init__( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType = ModelType.TextOnly, + use_tiktoken: bool = False, ) -> None: self._sanity_check(transformer_args, model_type) self.model_type = model_type - if isinstance(transformer_args, TransformerArgs): - assert model_type == ModelType.TextOnly - self.transformer_args = {"text": transformer_args} - else: - self.transformer_args = transformer_args + self.transformer_args = transformer_args + + # Model-level attributes + self.use_tiktoken = use_tiktoken def _sanity_check( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType, ) -> None: - assert isinstance(model_type, ModelType) - assert isinstance(transformer_args, (TransformerArgs, dict)) + assert isinstance(model_type, ModelType), model_type + assert isinstance(transformer_args, dict) @classmethod def from_params(cls, params_path): with open(params_path, "r") as f: loaded_params = json.loads(f.read()) - - try: - # try to interpret as a single transformer config - transformer_args: Dict[str, TransformerArgs] = {} - transformer_args["text"] = TransformerArgs.from_params(loaded_params) - if (model_type := loaded_params.get("model_type", None)) is None: - model_type = ModelType.TextOnly - - except TypeError: - # try to interpret as a dict of transformer configs - model_type = ModelType(loaded_params["model_type"]) + + if (model_type_name := loaded_params.get("model_type", None)) is None: + # The model params is in the transformer_args format + # set the model_type to TextOnly and reformat the params + model_type = ModelType.TextOnly + transformer_args = {"text": {"config": loaded_params}} + else: + model_type = ModelType(model_type_name) transformer_args = { k: v for k, v in loaded_params.items() if k != "model_type" } - return cls(transformer_args, model_type) + + use_tiktoken = loaded_params.get("use_tiktoken", False) + return cls(transformer_args, model_type, use_tiktoken) @classmethod def from_table(cls, name: str): @@ -304,10 +304,8 @@ def build_model(self) -> nn.Module: recipe = ModelRecipe.get_recipe(self.config.model_type) modules = {} for name, module_class in recipe.modules.items(): - if isinstance(config_args := self.config.transformer_args[name], dict): - modules[name] = module_class(**config_args) - else: - modules[name] = module_class(config_args) + config_args = self.config.transformer_args[name] + modules[name] = module_class(**config_args) return recipe.fusion_class(**modules) @@ -399,8 +397,9 @@ def reset_caches(self): class Transformer(nn.Module): - def __init__(self, config: TransformerArgs) -> None: + def __init__(self, config: Dict[str, Any]) -> None: super().__init__() + config = TransformerArgs.from_params(config) self.config = config layers_per_stage = config.n_layers // config.n_stages diff --git a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json index c59961c63..3c611a753 100644 --- a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json +++ b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json @@ -1,5 +1,6 @@ { "model_type": "llama3_1", + "use_tiktoken": true, "text": { "vocab_size": 128256, "num_layers": 80, diff --git a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json index e9ded77bd..adc9e4e8e 100644 --- a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json +++ b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json @@ -1,5 +1,6 @@ { "model_type": "llama3_1", + "use_tiktoken": true, "text": { "vocab_size": 128256, "num_layers": 32, From 43dfdc71e49b94001c6a4c39fdf907ece86b3366 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 17:54:38 -0700 Subject: [PATCH 13/27] 5/n torchchat init --- torchchat/model.py | 2 +- torchchat/model_params/llava-1.5.json | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index 5c4cd2838..ce7b3f18e 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -438,7 +438,7 @@ def build_model(self) -> nn.Module: return recipe.fusion_class(**modules) def _replace_know_params(self, params): - patterns = {"QuickGELUActivation()": QuickGELUActivation(), "False": False, "True": True} + patterns = {"QuickGELUActivation()": QuickGELUActivation()} for key, value in params.items(): if isinstance(value, Hashable) and value in patterns: params[key] = patterns[value] diff --git a/torchchat/model_params/llava-1.5.json b/torchchat/model_params/llava-1.5.json index 685904d9d..992cc2c69 100644 --- a/torchchat/model_params/llava-1.5.json +++ b/torchchat/model_params/llava-1.5.json @@ -1,5 +1,6 @@ { "model_type": "llava", + "use_tiktoken": true, "encoder": { "tile_size": 336, "patch_size": 14, @@ -9,7 +10,7 @@ "out_indices": [ 23 ], - "output_cls_projection": "False", + "output_cls_projection": false, "max_num_tiles": 1, "in_channels": 3, "intermediate_act": "QuickGELUActivation()" From 63d76a1dd81ca4ead3cb5695acf7cac28497cb89 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 18:04:07 -0700 Subject: [PATCH 14/27] hack PTEModel to support current ppl --- torchchat/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchchat/model.py b/torchchat/model.py index 680499878..131c5bc03 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -782,7 +782,11 @@ def __init__(self, config, path) -> None: # A hacky way to get the model config from the self.model, making it consistent with Model class # TODO: remove the hacky way once get rid of model.model - self.model = type('model', (), {'config': self.config}) + try: + text_transformer_config = TransformerArgs.from_params(self.config.transformer_args["text"]) + else: + text_transformer_config = None + self.model = type('model', (), {'config': text_transformer_config}) def forward(self, x, input_pos): # model_.forward expects inputs to be wrapped in a tuple From 01bb624cc30ad7659b713af13fa0b8c5aaa17661 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 18:22:35 -0700 Subject: [PATCH 15/27] fix a typo --- torchchat/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchchat/model.py b/torchchat/model.py index 131c5bc03..380b42569 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -784,7 +784,7 @@ def __init__(self, config, path) -> None: # TODO: remove the hacky way once get rid of model.model try: text_transformer_config = TransformerArgs.from_params(self.config.transformer_args["text"]) - else: + except: text_transformer_config = None self.model = type('model', (), {'config': text_transformer_config}) From 319ac86bb01a524e153cf6268e71d101e36d8803 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 16 Sep 2024 16:52:11 -0700 Subject: [PATCH 16/27] unify model construction ppl --- torchchat/cli/builder.py | 7 +-- torchchat/generate.py | 38 +++++-------- torchchat/model.py | 53 +++++++++---------- .../model_params/Meta-Llama-3.1-70B-Tune.json | 1 + .../model_params/Meta-Llama-3.1-8B-Tune.json | 1 + 5 files changed, 43 insertions(+), 57 deletions(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index f600fe7f2..8ccdf6b57 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -240,12 +240,7 @@ def validate_model( is_tiktoken = self.is_tiktoken is_sentencepiece = self.is_sentencepiece - text_args = model.config.transformer_args.get("text") - if text_args is None: - # TODO: Will be refactored: Currently, the only model that doesn't have text in transfomer_args is Flamingo - use_tiktoken = model.config.model_type == ModelType.Flamingo - else: - use_tiktoken = text_args.use_tiktoken + use_tiktoken = model.config.use_tiktoken if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): raise RuntimeError( diff --git a/torchchat/generate.py b/torchchat/generate.py index 30490d396..6f6b3a082 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -27,12 +27,6 @@ from PIL import Image -# torchtune model definition dependencies -from torchtune.data import Message -from torchtune.generation._generation import sample as tune_sample -from torchtune.models.llama3 import llama3_tokenizer -from torchtune.training import set_default_dtype - from torchchat.cli.builder import ( _initialize_model, _initialize_tokenizer, @@ -43,6 +37,12 @@ from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info +# torchtune model definition dependencies +from torchtune.data import Message +from torchtune.generation._generation import sample as tune_sample +from torchtune.models.llama3 import llama3_tokenizer +from torchtune.training import set_default_dtype + class _ChatFormatter(ABC): def __init__(self, tokenizer): @@ -795,16 +795,12 @@ def chat( # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong # TODO: unify the max_seq_length config representation. - if generator_args.is_torchtune_model: - max_seq_length = self.model.config.transformer_args.get("text", {}).get( - "max_seq_len", 2048 - ) - elif generator_args.chat_mode: - if ( - max_seq_length := self.model.config.transformer_args.get("text", None) - is None - ): - max_seq_length = 2048 + text_transformer_args = getattr(self.model.model, "config", None) + max_seq_length = ( + text_transformer_args.max_seq_length if text_transformer_args else 2048 + ) + + if generator_args.chat_mode: print( f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye" ) @@ -814,15 +810,9 @@ def chat( if get_system_prompt == "y" or get_system_prompt == "Y": self.system_prompt = input("What is your system prompt? \n") - else: - text_transformer_args = self.model.config.transformer_args.get("text", None) + elif not generator_args.is_torchtune_model: max_seq_length = min( - encoded.size(0) + generator_args.max_new_tokens, - ( - text_transformer_args.block_size - if text_transformer_args is not None - else 2048 - ), + encoded.size(0) + generator_args.max_new_tokens, max_seq_length ) max_seq_length = ( diff --git a/torchchat/model.py b/torchchat/model.py index 10c036a36..7f5082da7 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -164,49 +164,49 @@ def from_params(cls, params): @dataclass class ModelArgs: model_type: ModelType - transformer_args: Dict[str, Union[Dict, TransformerArgs]] + transformer_args: Dict[str, Dict[str, Any]] + use_tiktoken: bool def __init__( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType = ModelType.TextOnly, + use_tiktoken: bool = False, ) -> None: self._sanity_check(transformer_args, model_type) self.model_type = model_type - if isinstance(transformer_args, TransformerArgs): - assert model_type == ModelType.TextOnly - self.transformer_args = {"text": transformer_args} - else: - self.transformer_args = transformer_args + self.transformer_args = transformer_args + + # Model-level attributes + self.use_tiktoken = use_tiktoken def _sanity_check( self, - transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], + transformer_args: Dict[str, Dict[str, Any]], model_type: ModelType, ) -> None: - assert isinstance(model_type, ModelType) - assert isinstance(transformer_args, (TransformerArgs, dict)) + assert isinstance(model_type, ModelType), model_type + assert isinstance(transformer_args, dict) @classmethod def from_params(cls, params_path): with open(params_path, "r") as f: loaded_params = json.loads(f.read()) - - try: - # try to interpret as a single transformer config - transformer_args: Dict[str, TransformerArgs] = {} - transformer_args["text"] = TransformerArgs.from_params(loaded_params) - if (model_type := loaded_params.get("model_type", None)) is None: - model_type = ModelType.TextOnly - - except TypeError: - # try to interpret as a dict of transformer configs - model_type = ModelType(loaded_params["model_type"]) + + if (model_type_name := loaded_params.get("model_type", None)) is None: + # The model params is in the transformer_args format + # set the model_type to TextOnly and reformat the params + model_type = ModelType.TextOnly + transformer_args = {"text": {"config": loaded_params}} + else: + model_type = ModelType(model_type_name) transformer_args = { k: v for k, v in loaded_params.items() if k != "model_type" } - return cls(transformer_args, model_type) + + use_tiktoken = loaded_params.get("use_tiktoken", False) + return cls(transformer_args, model_type, use_tiktoken) @classmethod def from_table(cls, name: str): @@ -304,10 +304,8 @@ def build_model(self) -> nn.Module: recipe = ModelRecipe.get_recipe(self.config.model_type) modules = {} for name, module_class in recipe.modules.items(): - if isinstance(config_args := self.config.transformer_args[name], dict): - modules[name] = module_class(**config_args) - else: - modules[name] = module_class(config_args) + config_args = self.config.transformer_args[name] + modules[name] = module_class(**config_args) return recipe.fusion_class(**modules) @@ -399,8 +397,9 @@ def reset_caches(self): class Transformer(nn.Module): - def __init__(self, config: TransformerArgs) -> None: + def __init__(self, config: Dict[str, Any]) -> None: super().__init__() + config = TransformerArgs.from_params(config) self.config = config layers_per_stage = config.n_layers // config.n_stages diff --git a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json index c59961c63..3c611a753 100644 --- a/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json +++ b/torchchat/model_params/Meta-Llama-3.1-70B-Tune.json @@ -1,5 +1,6 @@ { "model_type": "llama3_1", + "use_tiktoken": true, "text": { "vocab_size": 128256, "num_layers": 80, diff --git a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json index e9ded77bd..adc9e4e8e 100644 --- a/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json +++ b/torchchat/model_params/Meta-Llama-3.1-8B-Tune.json @@ -1,5 +1,6 @@ { "model_type": "llama3_1", + "use_tiktoken": true, "text": { "vocab_size": 128256, "num_layers": 32, From 8cd093691fd857f47f4f89f51e7c882633f9bd11 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 00:02:18 -0700 Subject: [PATCH 17/27] bring TransformerArgs back to Transformer --- torchchat/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index fc95c99cc..c5aebd6bf 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -210,7 +210,7 @@ def from_params(cls, params_path): # The model params is in the transformer_args format # set the model_type to TextOnly and reformat the params model_type = ModelType.TextOnly - transformer_args = {"text": {"config": loaded_params}} + transformer_args = {"text": loaded_params} else: model_type = ModelType(model_type_name) transformer_args = { @@ -317,7 +317,10 @@ def build_model(self) -> nn.Module: modules = {} for name, module_class in recipe.modules.items(): config_args = self.config.transformer_args[name] - modules[name] = module_class(**config_args) + if module_class == Transformer: + modules[name] = module_class(TransformerArgs.from_params(config_args)) + else: + modules[name] = module_class(**config_args) return recipe.fusion_class(**modules) @@ -424,9 +427,8 @@ def get_text_transformer_args(self): class Transformer(nn.Module): - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: TransformerArgs) -> None: super().__init__() - config = TransformerArgs.from_params(config) self.config = config layers_per_stage = config.n_layers // config.n_stages From 304fece7419b38ffd35fff2f0b454729187f5a1e Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 00:13:40 -0700 Subject: [PATCH 18/27] rename get_text_transformer_args as text_transformer_args for readibility --- _torchchat_test_script.py | 293 ------------------------------- distributed/parallelize_llama.py | 2 +- torchchat/cli/builder.py | 2 +- torchchat/export.py | 2 +- torchchat/model.py | 27 ++- torchchat/usages/eval.py | 2 +- torchchat/usages/openai_api.py | 4 +- 7 files changed, 18 insertions(+), 314 deletions(-) delete mode 100644 _torchchat_test_script.py diff --git a/_torchchat_test_script.py b/_torchchat_test_script.py deleted file mode 100644 index 686e6b3a7..000000000 --- a/_torchchat_test_script.py +++ /dev/null @@ -1,293 +0,0 @@ - -import torch -import sys -import os - -from torchtune import training -from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder, FlamingoTransform -from torchtune.modules.model_fusion import DeepFusionModel - -from torchchat.model import Model - -import re - -from typing import Dict -from torchtune.generation._generation import sample -from torchtune.training import set_default_dtype -import numpy as np -import PIL - -from torchtune.data import Message - -def flamingo_transform(tokenizer_path): - return FlamingoTransform( - tokenizer_path, - tile_size=448, - patch_size=14, - max_num_tiles=4, - max_seq_len=8192, - encoder_max_seq_len=4100, - image_mean=(0.48145466, 0.4578275, 0.40821073), - image_std=(0.26862954, 0.26130258, 0.27577711), - prompt_template=None, - ) - -def padded_collate(batch, device='cuda', dtype=torch.bfloat16, padding_idx=0): - # Placeholder Collator until https://github.com/pytorch/torchtune/pull/1156 lands - assert len(batch) == 1, "Test collate function only supports bs = 1" - sample = batch[0] - sample["tokens"] = torch.Tensor(sample["tokens"])[None, ...].to(device).long() - sample["mask"] = torch.Tensor(sample["mask"])[None, ...].to(device).bool() - sample["encoder_input"]["images"] = torch.stack(sample["encoder_input"]["images"])[None, ...].to(device) - sample["encoder_input"]["aspect_ratio"] = torch.stack(sample["encoder_input"]["aspect_ratio"])[None, ...].to(device) - assert len(sample["encoder_mask"]), "Test collate function only supports 1 image per sequence" - # Pad encoder mask to max_num_tiles sequence length (4100) - s_x, s_y = sample["encoder_mask"][0].shape - mask_padding = torch.zeros((s_x, 4100 - s_y), dtype=torch.bool) - encoder_mask = torch.cat([sample["encoder_mask"][0], mask_padding], dim=1) - sample["encoder_mask"] = encoder_mask[None, ...].to(device) - return sample - - - -_FROM_META = { - "text_model.tok_embeddings.weight": "decoder.tok_embeddings.weight", - "text_model.learnable_embedding.weight": "decoder.tok_embeddings.fusion_embedding.weight", - "text_model.norm.weight": "decoder.norm.scale", - "text_model.output.weight": "decoder.output.weight", - - "text_model.layers.{}.attention_norm.weight": "decoder.layers.{}.sa_norm.scale", - "text_model.layers.{}.attention.wq.weight": "decoder.layers.{}.attn.q_proj.weight", - "text_model.layers.{}.attention.wk.weight": "decoder.layers.{}.attn.k_proj.weight", - "text_model.layers.{}.attention.wv.weight": "decoder.layers.{}.attn.v_proj.weight", - "text_model.layers.{}.attention.wo.weight": "decoder.layers.{}.attn.output_proj.weight", - "text_model.layers.{}.ffn_norm.weight": "decoder.layers.{}.mlp_norm.scale", - "text_model.layers.{}.feed_forward.w1.weight": "decoder.layers.{}.mlp.w1.weight", - "text_model.layers.{}.feed_forward.w3.weight": "decoder.layers.{}.mlp.w3.weight", - "text_model.layers.{}.feed_forward.w2.weight": "decoder.layers.{}.mlp.w2.weight", - - "text_model.cross_attention_layers.{}.gate_attn": "decoder.layers.{}.fusion_layer.ca_scale.scale", - "text_model.cross_attention_layers.{}.gate_ffwd": "decoder.layers.{}.fusion_layer.mlp_scale.scale", - "text_model.cross_attention_layers.{}.attention_norm.weight": "decoder.layers.{}.fusion_layer.ca_norm.scale", - "text_model.cross_attention_layers.{}.ffn_norm.weight": "decoder.layers.{}.fusion_layer.mlp_norm.scale", - "text_model.cross_attention_layers.{}.attention.wq.weight": "decoder.layers.{}.fusion_layer.attn.q_proj.weight", - "text_model.cross_attention_layers.{}.attention.wk.weight": "decoder.layers.{}.fusion_layer.attn.k_proj.weight", - "text_model.cross_attention_layers.{}.attention.wv.weight": "decoder.layers.{}.fusion_layer.attn.v_proj.weight", - "text_model.cross_attention_layers.{}.attention.wo.weight": "decoder.layers.{}.fusion_layer.attn.output_proj.weight", - "text_model.cross_attention_layers.{}.attention.inner_attention.q_norm.weight": "decoder.layers.{}.fusion_layer.attn.q_norm.scale", - "text_model.cross_attention_layers.{}.attention.inner_attention.k_norm.weight": "decoder.layers.{}.fusion_layer.attn.k_norm.scale", - "text_model.cross_attention_layers.{}.feed_forward.w1.weight": "decoder.layers.{}.fusion_layer.mlp.w1.weight", - "text_model.cross_attention_layers.{}.feed_forward.w3.weight": "decoder.layers.{}.fusion_layer.mlp.w3.weight", - "text_model.cross_attention_layers.{}.feed_forward.w2.weight": "decoder.layers.{}.fusion_layer.mlp.w2.weight", - - "vision_model.vision_encoder.positional_embedding": "encoder.clip.token_pos_embedding.local_token_positional_embedding", - "vision_model.vision_encoder.gated_positional_embedding": "encoder.clip.token_pos_embedding.global_token_positional_embedding", - "vision_model.vision_encoder.gated_positional_embedding_gate": "encoder.clip.token_pos_embedding.gate", - "vision_model.vision_encoder.ln_pre.weight": "encoder.clip.ln_pre.weight", - "vision_model.vision_encoder.ln_pre.bias": "encoder.clip.ln_pre.bias", - "vision_model.vision_encoder.ln_post.weight": "encoder.clip.ln_post.weight", - "vision_model.vision_encoder.ln_post.bias": "encoder.clip.ln_post.bias", - "vision_model.vision_encoder.pre_tile_pos_embed.embedding": "encoder.clip.pre_tile_pos_embed.embedding", - "vision_model.vision_encoder.pre_tile_pos_embed.gate": "encoder.clip.pre_tile_pos_embed.gate", - "vision_model.vision_encoder.post_tile_pos_embed.embedding": "encoder.clip.post_tile_pos_embed.embedding", - "vision_model.vision_encoder.post_tile_pos_embed.gate": "encoder.clip.post_tile_pos_embed.gate", - "vision_model.vision_encoder.class_embedding" : "encoder.clip.cls_token_embedding.weight", - "vision_model.vision_encoder.conv1._linear.weight" : "encoder.clip.conv.weight", - - "vision_model.vision_encoder.transformer.resblocks.{}.attn.wq.weight": "encoder.clip.layers.{}.attn.q_proj.weight", - "vision_model.vision_encoder.transformer.resblocks.{}.attn.wk.weight": "encoder.clip.layers.{}.attn.k_proj.weight", - "vision_model.vision_encoder.transformer.resblocks.{}.attn.wv.weight": "encoder.clip.layers.{}.attn.v_proj.weight", - "vision_model.vision_encoder.transformer.resblocks.{}.attn.wo.weight": "encoder.clip.layers.{}.attn.output_proj.weight", - "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_fc.weight": "encoder.clip.layers.{}.mlp.w1.weight", - "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_fc.bias": "encoder.clip.layers.{}.mlp.w1.bias", - "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_proj.weight": "encoder.clip.layers.{}.mlp.w2.weight", - "vision_model.vision_encoder.transformer.resblocks.{}.mlp.c_proj.bias": "encoder.clip.layers.{}.mlp.w2.bias", - "vision_model.vision_encoder.transformer.resblocks.{}.ln_1.weight": "encoder.clip.layers.{}.sa_norm.weight", - "vision_model.vision_encoder.transformer.resblocks.{}.ln_1.bias": "encoder.clip.layers.{}.sa_norm.bias", - "vision_model.vision_encoder.transformer.resblocks.{}.ln_2.weight": "encoder.clip.layers.{}.mlp_norm.weight", - "vision_model.vision_encoder.transformer.resblocks.{}.ln_2.bias": "encoder.clip.layers.{}.mlp_norm.bias", - - "vision_model.vision_projection.weight" : "encoder.projection.output.weight", - "vision_model.vision_projection.bias" : "encoder.projection.output.bias", - - "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wq.weight": "encoder.projection.layers.{}.attn.q_proj.weight", - "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wk.weight": "encoder.projection.layers.{}.attn.k_proj.weight", - "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wv.weight": "encoder.projection.layers.{}.attn.v_proj.weight", - "vision_model.vision_encoder.global_transformer.resblocks.{}.attn.wo.weight": "encoder.projection.layers.{}.attn.output_proj.weight", - "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_fc.weight": "encoder.projection.layers.{}.mlp.w1.weight", - "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_fc.bias": "encoder.projection.layers.{}.mlp.w1.bias", - "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_proj.weight": "encoder.projection.layers.{}.mlp.w2.weight", - "vision_model.vision_encoder.global_transformer.resblocks.{}.mlp.c_proj.bias": "encoder.projection.layers.{}.mlp.w2.bias", - "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_1.weight": "encoder.projection.layers.{}.sa_norm.weight", - "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_1.bias": "encoder.projection.layers.{}.sa_norm.bias", - "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_2.weight": "encoder.projection.layers.{}.mlp_norm.weight", - "vision_model.vision_encoder.global_transformer.resblocks.{}.ln_2.bias": "encoder.projection.layers.{}.mlp_norm.bias", - "vision_model.vision_encoder.global_transformer.resblocks.{}.gate_attn": "encoder.projection.layers.{}.sa_scale.scale", - "vision_model.vision_encoder.global_transformer.resblocks.{}.gate_ffn": "encoder.projection.layers.{}.mlp_scale.scale", -} - - -def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str: - try: - if any(k.isdigit() for k in key.split(".")): - # Replace layer number with "{}" to create key for lookup - abstract_key = re.sub(r"(\.\d+)", ".{}", key) - layer_num = re.search(r"\d+", key).group(0) - new_key = mapping_dict[abstract_key] - new_key = new_key.format(layer_num) - else: - new_key = mapping_dict[key] - except KeyError as e: - raise Exception( - f'Error converting the state dict. Found unexpected key: "{key}". ' - "Please make sure you're loading a checkpoint with the right format. " - ) from e - - return new_key - - -def flamingo_meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - Convertor from Meta state dict to torchtune state dict. This handles: - - Updateing the cross attention layer numbers - """ - converted_state_dict = {} - - for key, value in state_dict.items(): - if key == "text_model.rope.freqs": - continue - new_key = get_mapped_key(key, _FROM_META) - if "cross_attention_layers" in key: - layer = int(key.split(".")[2]) - # TODO: grab num_layers and generalize this - new_layer = (layer + 1) * 4 - 1 - key_lst = new_key.split(".") - key_lst[2] = str(new_layer) - new_key = ".".join(key_lst) - if "gate_ffwd" in key or "gate_attn" in key: - value = value[:1] - elif "conv1" in key: - # TODO: get patch size and generalize - value = value.reshape(-1, 3, 14, 14) - converted_state_dict[new_key] = value - return converted_state_dict - - - -if __name__ == "__main__": - llava3_2_dir = str(sys.argv[1]) - param_path = os.path.join(llava3_2_dir, "flamingo.json") - tokenizer_path = os.path.join(llava3_2_dir, "tokenizer.model") - checkpoint_path = os.path.join(llava3_2_dir, "consolidated.pth") - image_path = os.path.join(llava3_2_dir, "dog.jpg") - - if len(sys.argv) > 2: - device = torch.device(str(sys.argv[2])) - elif torch.cuda.is_available(): - device = torch.device('cuda:0') - else: - device = torch.device("cpu") - - print(f"Using device: {device}") - print(f"Loading model from {param_path}") - - dtype = torch.bfloat16 - with set_default_dtype(dtype), device: - model = Model.from_params(param_path) - - transform = flamingo_transform(tokenizer_path) - - print(f"Loading checkpoint from {checkpoint_path}") - state_dict = torch.load(checkpoint_path) - print("Converting state dict into flamingo format") - state_dict = flamingo_meta_to_tune(state_dict) - print("Loading state dict into model") - model.model.load_state_dict(state_dict) - - model = torch.compile(model) - images = [PIL.Image.open(image_path)] - - dialog = [ - Message( - role="user", - content=[ - {"type": "image"}, - {"type": "text", "content": "What's in this image?"}, - ], - eot=True, - ), - Message(role="assistant", content="") - ] - - data = transform({"images": images, "messages": dialog}, inference=True) - - model.eval() - with device: - model.setup_caches(1, dtype=torch.bfloat16) - - - max_generated_tokens = 100 - temperature = .6 - top_k = 500 - - print("Generating...") - - generated_tokens = [] - model.reset_caches() - with torch.no_grad(): - batch = padded_collate([data], device, dtype) - batch.pop("mask") - - logits = model(**batch)[:, -1] - tok = sample(logits, temperature, top_k) - generated_tokens.append(tok.item()) - - cache_mask = batch["encoder_mask"][:, -1:] - for _ in range(max_generated_tokens): - if tok.item() in transform.stop_tokens: - break - logits = model(tok, encoder_mask=cache_mask)[:, -1] - tok = sample(logits, temperature, top_k) - generated_tokens.append(tok.item()) - - print(transform.decode(generated_tokens)) - - - -""":md -## Chat Pseudo Code - -This approach guarantees that there's only one image cached at a time so that there's no need for cross attention masking. -This works because Llama3v is trained such that each token is only allowed to attend to the previous image and the rest are -masked during training/finetuning. Since consecutive images are treated as one image for Llama3v, you can control the maximum -encoder sequence length by setting max_consecuitve here, as well as by settin max_num_tiles and max_resolution for the image input. - -```python -model.eval() -model.setup_caches(1, torch.bf16) - -with torch.no_grad(): - # Prefill system prompt - toks, _ = transform(parse_prompt(system_prompt)) - model(toks) - while True: - # Prefill user prompt split over images - user_prompt = input(">>> ") - toks, imgs = transform(parse_prompt(user_prompt)) - for i, tok in enumerate(split(toks, image_token, max_consecutive=1)): - img = None - if imgs is not None: - img = imgs[i] - reset_attn_cache(model) - logits = model(tok, img) - - # Decode assitant response - tok = sample_tok(logits) # only ouptput single token logits when model.cache_enabled=True - while tok != EOS: - logits = model(tok) - tok = sample_tok(logits) - sys.stdout.buffer.write(transform.decode(tok)) -``` -""" - -""":py""" diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index feffb418f..0b1dca4cf 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -62,7 +62,7 @@ def apply_tp( # after we apply TP to the model. Because we don't want to change model code # when applying TP. We need to have change to ensure KVCache has the correct # size as k and v. - model.get_text_transformer_args.n_local_heads = model.get_text_transformer_args.n_local_heads // tp_mesh.size() + model.text_transformer_args.n_local_heads = model.text_transformer_args.n_local_heads // tp_mesh.size() # Apply tensor parallelism to every transformer block for transformer_block in model.layers: diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index a3248ea2f..2933edade 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -563,7 +563,7 @@ def _initialize_model( model.setup_caches( max_batch_size=1, max_seq_length=max_seq_length - or model.get_text_transformer_args.max_seq_length, + or model.text_transformer_args.max_seq_length, ) model.to(dtype=builder_args.precision) diff --git a/torchchat/export.py b/torchchat/export.py index 17282d179..affb8b871 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -54,7 +54,7 @@ def export_for_server( torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) - seq = Dim("seq", min=1, max=model.get_text_transformer_args.max_seq_length) + seq = Dim("seq", min=1, max=model.text_transformer_args.max_seq_length) # Specify that the first dimension of each input is that batch size dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}} else: diff --git a/torchchat/model.py b/torchchat/model.py index c5aebd6bf..39b10ff68 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -332,9 +332,10 @@ def forward(self, *args, **kwargs): def setup_caches(self, *args, **kwargs): raise NotImplementedError("setup_caches method is not implemented") + @property @abstractmethod - def get_text_transformer_args(self): - raise NotImplementedError("get_text_transformer_args method is not implemented") + def text_transformer_args(self): + raise NotImplementedError("no text_transformer_args is created") @classmethod def _get_model_instance(cls, config: ModelArgs): @@ -376,7 +377,8 @@ def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: def setup_caches(self, max_batch_size, max_seq_length): self.model.setup_caches(max_batch_size, max_seq_length) - def get_text_transformer_args(self): + @property + def text_transformer_args(self): return self.model.model.config @@ -390,7 +392,8 @@ def setup_caches(self, max_batch_size, dtype): def reset_caches(self): self.model.reset_caches() - def get_text_transformer_args(self): + @property + def text_transformer_args(self): # TODO: add support for llama3_1 return None @@ -414,7 +417,8 @@ def setup_caches(self, max_batch_size, dtype): def reset_caches(self): self.model.reset_caches() - def get_text_transformer_args(self): + @property + def text_transformer_args(self): # TODO: add support for flamingo return None @@ -807,7 +811,9 @@ class PTEModel(nn.Module): def __init__(self, config, path) -> None: super().__init__() self.config = config - self.model_ = exec_lib._load_for_executorch(str(path)) + self.model_ = exec_lib._load_for_executorch(str(path)) + + self.text_transformer_config = TransformerArgs.from_params(self.config.transformer_args["text"]) def forward(self, x, input_pos): # model_.forward expects inputs to be wrapped in a tuple @@ -823,14 +829,5 @@ def forward(self, x, input_pos): def setup_caches(self, max_batch_size, max_seq_length): pass - def get_text_transformer_args(self): - # A hacky way to get the model config from the self.model, making it consistent with Model class - # TODO: remove the hacky way once get rid of model.model - try: - text_transformer_config = TransformerArgs.from_params(self.config.transformer_args["text"]) - except: - text_transformer_config = None - return text_transformer_config - except: pass diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 34de88a68..5993c3781 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( T = prompt.size(0) T_new = T + max_new_tokens if max_seq_length is None: - max_seq_length = min(T_new, model.get_text_transformer_args.block_size) + max_seq_length = min(T_new, model.text_transformer_args.block_size) device, dtype = prompt.device, prompt.dtype # create an empty tensor of the expected final shape and diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py index d694afcd7..6381e8112 100644 --- a/torchchat/usages/openai_api.py +++ b/torchchat/usages/openai_api.py @@ -284,11 +284,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) try: self.max_seq_length = ( - self.model.get_text_transformer_args.max_seq_length + self.model.text_transformer_args.max_seq_length + self.speculative_builder_args.speculate_k + 1 if self.draft_model is not None - else self.model.get_text_transformer_args.max_seq_length + else self.model.text_transformer_args.max_seq_length ) except: # can not find max_seq_length in model config, use default value From 1eff9391bdc197e9a4cf90e0fe5ace8f35a298cb Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 01:25:50 -0700 Subject: [PATCH 19/27] make text_transformer_args a real attribute --- torchchat/model.py | 27 +++++++-------------------- torchchat/utils/gguf_loader.py | 16 +++++++--------- 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index 39b10ff68..f34db26bc 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -304,6 +304,7 @@ def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config self.model = self.build_model() + self.text_transformer_args = None def build_model(self) -> nn.Module: """ @@ -331,11 +332,6 @@ def forward(self, *args, **kwargs): @abstractmethod def setup_caches(self, *args, **kwargs): raise NotImplementedError("setup_caches method is not implemented") - - @property - @abstractmethod - def text_transformer_args(self): - raise NotImplementedError("no text_transformer_args is created") @classmethod def _get_model_instance(cls, config: ModelArgs): @@ -371,15 +367,15 @@ def from_gguf(cls, gguf_path: str, **kwargs): class TextOnlyModel(Model): + def __init__(self, config: ModelArgs) -> None: + super().__init__(config) + self.text_transformer_args = self.model.config + def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: return self.model(tokens, input_pos) def setup_caches(self, max_batch_size, max_seq_length): self.model.setup_caches(max_batch_size, max_seq_length) - - @property - def text_transformer_args(self): - return self.model.model.config class Llama31Model(Model): @@ -391,11 +387,6 @@ def setup_caches(self, max_batch_size, dtype): def reset_caches(self): self.model.reset_caches() - - @property - def text_transformer_args(self): - # TODO: add support for llama3_1 - return None class FlamingoModel(Model): @@ -416,11 +407,7 @@ def setup_caches(self, max_batch_size, dtype): def reset_caches(self): self.model.reset_caches() - - @property - def text_transformer_args(self): - # TODO: add support for flamingo - return None + MODEL_TYPE_TO_CLASS = { @@ -813,7 +800,7 @@ def __init__(self, config, path) -> None: self.config = config self.model_ = exec_lib._load_for_executorch(str(path)) - self.text_transformer_config = TransformerArgs.from_params(self.config.transformer_args["text"]) + self.text_transformer_args = TransformerArgs.from_params(self.config.transformer_args["text"]) def forward(self, x, input_pos): # model_.forward expects inputs to be wrapped in a tuple diff --git a/torchchat/utils/gguf_loader.py b/torchchat/utils/gguf_loader.py index 8fdadf5bf..309ff807c 100644 --- a/torchchat/utils/gguf_loader.py +++ b/torchchat/utils/gguf_loader.py @@ -544,15 +544,13 @@ def load_model(gguf_file: str) -> torch.nn.Module: model_args = ModelArgs( { "text": { - "config": { - "dim": metadata[f"{arch}.embedding_length"], - "n_layers": metadata[f"{arch}.block_count"], - "n_heads": metadata[f"{arch}.attention.head_count"], - "n_local_heads": metadata[f"{arch}.attention.head_count_kv"], - "vocab_size": len(metadata["tokenizer.ggml.tokens"]), - "norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"], - "hidden_dim": metadata[f"{arch}.feed_forward_length"], - } + "dim": metadata[f"{arch}.embedding_length"], + "n_layers": metadata[f"{arch}.block_count"], + "n_heads": metadata[f"{arch}.attention.head_count"], + "n_local_heads": metadata[f"{arch}.attention.head_count_kv"], + "vocab_size": len(metadata["tokenizer.ggml.tokens"]), + "norm_eps": metadata[f"{arch}.attention.layer_norm_rms_epsilon"], + "hidden_dim": metadata[f"{arch}.feed_forward_length"], } } ) From a356897c6c3bc9c31c9dec6a09c9cd5588e72a82 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 01:28:16 -0700 Subject: [PATCH 20/27] get rid of model.model --- torchchat/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 6f6b3a082..591882917 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -795,7 +795,7 @@ def chat( # This is a hack to get around the fact that different models have different ways to record their max_seq_length and might be wrong # TODO: unify the max_seq_length config representation. - text_transformer_args = getattr(self.model.model, "config", None) + text_transformer_args = self.model.text_transformer_args max_seq_length = ( text_transformer_args.max_seq_length if text_transformer_args else 2048 ) From cbda8795f46a4d575d047419aa4aedad75ba02c8 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 02:00:47 -0700 Subject: [PATCH 21/27] llava model constuction support --- torchchat/model.py | 48 +++------------------------------------------- 1 file changed, 3 insertions(+), 45 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index f793a3d43..e219d1bfc 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -116,7 +116,8 @@ def forward( encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> Tensor: - if encoder_input: + if encoder_input is not None: + encoder_input = encoder_input.view(1, 1, *encoder_input.shape) encoder_output = self.encoder( encoder_input, ) @@ -223,7 +224,7 @@ def _llava(cls): 'encoder': clip_vision_encoder, 'decoder': Transformer }, - fusion_class=DeepFusionModel, + fusion_class=ConcateFusion, ) @classmethod @@ -968,46 +969,3 @@ def setup_caches(self, max_batch_size, max_seq_length): except: pass - - -if __name__ == "__main__": - def prepare_image(target_h: int, target_w: int) -> torch.Tensor: - """Read image into a tensor and resize the image so that it fits in - a target_h x target_w canvas. - - Args: - image (Image): An Image object. - target_h (int): Target height. - target_w (int): Target width. - - Returns: - torch.Tensor: resized image tensor. - """ - image = Image.open( - requests.get( - "https://llava-vl.github.io/static/images/view.jpg", stream=True - ).raw) - - img = torchvision.transforms.functional.pil_to_tensor(image) - # height ratio - ratio_h = img.shape[1] / target_h - # width ratio - ratio_w = img.shape[2] / target_w - # resize the image so that it fits in a target_h x target_w canvas - ratio = max(ratio_h, ratio_w) - output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) - img = torchvision.transforms.Resize(size=output_size)(img) - return img - - pre_tokens = torch.tensor([[ 1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, - 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, - 322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155, - 29889, 3148, 1001, 29901, 29871]]) - img = prepare_image(336, 336) - post_tokens = torch.tensor([[29871, 13, 462, 9651, 1724, 526, 278, 2712, 306, 881, - 367, 274, 1300, 2738, 1048, 746, 306, 6493, 1244, 29973, - 319, 1799, 9047, 13566, 29901]]) - - llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json") - - llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens) From 6fbb4605b233f32c8503e26bd74cf65228011e2e Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 11:22:00 -0700 Subject: [PATCH 22/27] 1/2 solve cache issue --- torchchat/model.py | 108 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 5 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index e219d1bfc..07f85b499 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -86,7 +86,7 @@ def __init__( token_embedding_name="tok_embeddings", mm_proj_in_channels=1024, mm_proj_out_channels=4096, - mm_proj_activation=nn.GELU, + mm_proj_activation=nn.GELU(), ): super().__init__() self.encoder = encoder @@ -130,6 +130,9 @@ def forward( ) return self.decoder(decoder_input, input_pos=input_pos) + def setup_caches(self, batch_size, max_seq_len): + self.decoder.setup_caches(batch_size, max_seq_len) + def _encoder_feature_select(self, encoder_output): selected_image_feature = encoder_output[1][0].view( *encoder_output[1][0].shape[2:] @@ -145,19 +148,19 @@ def _get_decoder_input( encoder_output: Optional[Tensor], post_tokens: Optional[Tensor], ): - assert bool(encoder_output) == bool( - post_tokens - ), "encoder_input and post_tokens must be both None or not None" if encoder_output is None: + assert post_tokens is None return self.tok_embeddings(tokens) else: pre_img_embed = self.tok_embeddings(tokens) image_embeds = self.mm_projector(encoder_output) + if post_tokens is None: + return torch.cat((pre_img_embed, image_embeds), dim=1) + post_img_embed = self.tok_embeddings(post_tokens) return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1) - class ModelType(Enum): TextOnly = "text_only" Llama3_1 = "llama3_1" @@ -969,3 +972,98 @@ def setup_caches(self, max_batch_size, max_seq_length): except: pass + + +if __name__ == "__main__": + def prepare_image(target_h: int, target_w: int) -> torch.Tensor: + """Read image into a tensor and resize the image so that it fits in + a target_h x target_w canvas. + + Args: + image (Image): An Image object. + target_h (int): Target height. + target_w (int): Target width. + + Returns: + torch.Tensor: resized image tensor. + """ + image = Image.open( + requests.get( + "https://llava-vl.github.io/static/images/view.jpg", stream=True + ).raw) + + img = torchvision.transforms.functional.pil_to_tensor(image) + # height ratio + ratio_h = img.shape[1] / target_h + # width ratio + ratio_w = img.shape[2] / target_w + # resize the image so that it fits in a target_h x target_w canvas + ratio = max(ratio_h, ratio_w) + output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) + img = torchvision.transforms.Resize(size=output_size)(img) + return img + + + def image_preprocess(img: torch.Tensor, target_h: int, target_w: int, rescale_factor, image_mean, image_std) -> torch.Tensor: + # pad the image with median rgb value, to make a square + l_pad = (target_w - img.shape[2]) // 2 + t_pad = (target_h - img.shape[1]) // 2 + # ceil division + r_pad = -((target_w - img.shape[2]) // -2) + b_pad = -((target_h - img.shape[1]) // -2) + + torch._check(l_pad >= 0) + torch._check(t_pad >= 0) + torch._check(r_pad >= 0) + torch._check(b_pad >= 0) + + # This is different from the original implementation, due to export limitations. + resized = torch.nn.functional.pad( + img, + (l_pad, r_pad, t_pad, b_pad), + ) + # originally: + # resized = F.pad( + # img, + # padding=(l_pad, t_pad, r_pad, b_pad), + # fill=tuple(int(x * 255) for x in self.image_mean), + # ) + + # TODO: implement _upsample_bicubic_aa.out in portable kernel library. + # here padded shape should be max(h, w) x max(h, w) + # skipping resize for now due to missing _upsample_bicubic_aa kernel in portable + # resized = resize( + # padded, + # size=[ + # self.image_processor.crop_size["height"], + # self.image_processor.crop_size["width"], + # ], + # interpolation="bicubic", + # ) + # torch._check(resized.size(1) == self.config.crop_size["height"]) + # torch._check(resized.size(2) == self.config.crop_size["width"]) + # print(resized.shape) + # cropped = F.center_crop(img, output_size=[w, w]) + # print(cropped.shape) + scaled = resized * rescale_factor + # print(scaled) + from torchvision.transforms.v2 import functional as tvF + normed = tvF.normalize( + scaled, image_mean, image_std + ) + # print(normed) + return normed.unsqueeze(0) + + pre_tokens = torch.tensor([[ 1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, + 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, + 322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155, + 29889, 3148, 1001, 29901, 29871]]) + img = prepare_image(336, 336) + post_tokens = torch.tensor([[29871, 13, 462, 9651, 1724, 526, 278, 2712, 306, 881, + 367, 274, 1300, 2738, 1048, 746, 306, 6493, 1244, 29973, + 319, 1799, 9047, 13566, 29901]]) + + llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json") + llava_model.setup_caches(1, 2048) + img = image_preprocess(img=img, target_h=336, target_w=336, image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711], rescale_factor=0.00392156862745098) + llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens) From f224da756c61563ac94aab3cac3f8f13b30a05ec Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 11:31:04 -0700 Subject: [PATCH 23/27] solve comments --- torchchat/generate.py | 7 ++++++- torchchat/model.py | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 591882917..4c8b02e91 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -812,7 +812,12 @@ def chat( elif not generator_args.is_torchtune_model: max_seq_length = min( - encoded.size(0) + generator_args.max_new_tokens, max_seq_length + encoded.size(0) + generator_args.max_new_tokens, + ( + text_transformer_args.block_size + if text_transformer_args is not None + else 2048 + ), ) max_seq_length = ( diff --git a/torchchat/model.py b/torchchat/model.py index f34db26bc..f936970dd 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -170,6 +170,8 @@ class ModelArgs: transformer_args (Dict[str, Dict[str, Any]]): A dictionary containing the parameters for each transformer in the model. The outer dictionary has transformer names as keys and inner dictionaries as values. Each inner dictionary contains the parameter names and their corresponding values for the respective transformer. + TODO: econcile Dict[str, Any] into tranformer-arg-family classes in future PRs. + use_tiktoken (bool): A flag indicating whether to use TikToken as the tokenizer for the model. Note: It is recommended to use factory functions to create instances of this class instead of directly using the constructor. @@ -304,6 +306,9 @@ def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config self.model = self.build_model() + + # text_transformer_args represents the args for the text transformer in the model. + # It should be assigned in the actual model implementation, if any. self.text_transformer_args = None def build_model(self) -> nn.Module: From 128566c46786e777c3177f7afd2e65667a281827 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 13:57:15 -0700 Subject: [PATCH 24/27] prepare for rebase --- torchchat/model.py | 103 ++++----------------------------------------- 1 file changed, 8 insertions(+), 95 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index d7b19b1c5..15731baad 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -128,6 +128,14 @@ def forward( decoder_input = self._get_decoder_input( tokens, encoder_output=encoder_output, post_tokens=post_tokens ) + + if input_pos is None: + input_pos = torch.arange( + decoder_input.shape[1], + device=decoder_input.device, + dtype=torch.int, + ) + return self.decoder(decoder_input, input_pos=input_pos) def setup_caches(self, batch_size, max_seq_len): @@ -977,98 +985,3 @@ def setup_caches(self, max_batch_size, max_seq_length): except: pass - - -if __name__ == "__main__": - def prepare_image(target_h: int, target_w: int) -> torch.Tensor: - """Read image into a tensor and resize the image so that it fits in - a target_h x target_w canvas. - - Args: - image (Image): An Image object. - target_h (int): Target height. - target_w (int): Target width. - - Returns: - torch.Tensor: resized image tensor. - """ - image = Image.open( - requests.get( - "https://llava-vl.github.io/static/images/view.jpg", stream=True - ).raw) - - img = torchvision.transforms.functional.pil_to_tensor(image) - # height ratio - ratio_h = img.shape[1] / target_h - # width ratio - ratio_w = img.shape[2] / target_w - # resize the image so that it fits in a target_h x target_w canvas - ratio = max(ratio_h, ratio_w) - output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) - img = torchvision.transforms.Resize(size=output_size)(img) - return img - - - def image_preprocess(img: torch.Tensor, target_h: int, target_w: int, rescale_factor, image_mean, image_std) -> torch.Tensor: - # pad the image with median rgb value, to make a square - l_pad = (target_w - img.shape[2]) // 2 - t_pad = (target_h - img.shape[1]) // 2 - # ceil division - r_pad = -((target_w - img.shape[2]) // -2) - b_pad = -((target_h - img.shape[1]) // -2) - - torch._check(l_pad >= 0) - torch._check(t_pad >= 0) - torch._check(r_pad >= 0) - torch._check(b_pad >= 0) - - # This is different from the original implementation, due to export limitations. - resized = torch.nn.functional.pad( - img, - (l_pad, r_pad, t_pad, b_pad), - ) - # originally: - # resized = F.pad( - # img, - # padding=(l_pad, t_pad, r_pad, b_pad), - # fill=tuple(int(x * 255) for x in self.image_mean), - # ) - - # TODO: implement _upsample_bicubic_aa.out in portable kernel library. - # here padded shape should be max(h, w) x max(h, w) - # skipping resize for now due to missing _upsample_bicubic_aa kernel in portable - # resized = resize( - # padded, - # size=[ - # self.image_processor.crop_size["height"], - # self.image_processor.crop_size["width"], - # ], - # interpolation="bicubic", - # ) - # torch._check(resized.size(1) == self.config.crop_size["height"]) - # torch._check(resized.size(2) == self.config.crop_size["width"]) - # print(resized.shape) - # cropped = F.center_crop(img, output_size=[w, w]) - # print(cropped.shape) - scaled = resized * rescale_factor - # print(scaled) - from torchvision.transforms.v2 import functional as tvF - normed = tvF.normalize( - scaled, image_mean, image_std - ) - # print(normed) - return normed.unsqueeze(0) - - pre_tokens = torch.tensor([[ 1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, - 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, - 322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155, - 29889, 3148, 1001, 29901, 29871]]) - img = prepare_image(336, 336) - post_tokens = torch.tensor([[29871, 13, 462, 9651, 1724, 526, 278, 2712, 306, 881, - 367, 274, 1300, 2738, 1048, 746, 306, 6493, 1244, 29973, - 319, 1799, 9047, 13566, 29901]]) - - llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json") - llava_model.setup_caches(1, 2048) - img = image_preprocess(img=img, target_h=336, target_w=336, image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711], rescale_factor=0.00392156862745098) - llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens) From 7aab3b4d9b6fe4d4a7e0bb5209ecb39e07ba0898 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 14:03:26 -0700 Subject: [PATCH 25/27] bring license back --- torchchat/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchchat/model.py b/torchchat/model.py index e47dd2048..f211533c9 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. import json import os import warnings From 7ffec7384693119d4d01096ac9fbcf30c775bd70 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 16:32:30 -0700 Subject: [PATCH 26/27] solve comments --- torchchat/model.py | 56 +++++++++++++++++----------------------------- 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index f211533c9..f4f6624f9 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -11,8 +11,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from PIL import Image -import requests + import torchvision from typing import Any, Callable, Dict, Optional, Union @@ -35,14 +34,10 @@ from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder from torchtune.modules.model_fusion import DeepFusionModel +from torchtune.models.clip import clip_vision_encoder from torchchat.utils.build_utils import find_multiple, get_precision -from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder -from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder -from torchtune.modules.model_fusion import DeepFusionModel -from torchtune.models.clip import clip_vision_encoder - config_path = Path(f"{str(Path(__file__).parent)}/model_params") @@ -61,19 +56,13 @@ def identity(**kwargs): return list(kwargs.values())[0] -@dataclass -class ProjectorArgs: - in_channels: int = 1024 - out_channels: int = 4096 - activation: nn.Module = nn.GELU() - class MultiModalProjector(nn.Module): - def __init__(self, args: ProjectorArgs): + def __init__(self, in_channels: int, out_channels: int, act: nn.Module): super().__init__() self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True) - self.act = args.activation + self.act = act self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True) def forward(self, image_features): @@ -105,11 +94,9 @@ def __init__( self.decoder.__setattr__(token_embedding_name, None) self.mm_projector = MultiModalProjector( - ProjectorArgs( in_channels=mm_proj_in_channels, out_channels=mm_proj_out_channels, - activation=mm_proj_activation, - ) + act=mm_proj_activation, ) def forward( @@ -123,9 +110,7 @@ def forward( ) -> Tensor: if encoder_input is not None: encoder_input = encoder_input.view(1, 1, *encoder_input.shape) - encoder_output = self.encoder( - encoder_input, - ) + encoder_output = self.encoder(encoder_input) encoder_output = self._encoder_feature_select(encoder_output) else: encoder_output = None @@ -143,10 +128,10 @@ def forward( return self.decoder(decoder_input, input_pos=input_pos) - def setup_caches(self, batch_size, max_seq_len): + def setup_caches(self, batch_size, max_seq_len) -> None: self.decoder.setup_caches(batch_size, max_seq_len) - def _encoder_feature_select(self, encoder_output): + def _encoder_feature_select(self, encoder_output) -> Tensor: selected_image_feature = encoder_output[1][0].view( *encoder_output[1][0].shape[2:] ) @@ -160,7 +145,7 @@ def _get_decoder_input( *, encoder_output: Optional[Tensor], post_tokens: Optional[Tensor], - ): + ) -> Tensor: if encoder_output is None: assert post_tokens is None return self.tok_embeddings(tokens) @@ -245,16 +230,17 @@ def _llava(cls): @classmethod def get_recipe(cls, model_type): - if model_type == ModelType.TextOnly: - return cls._text_only() - elif model_type == ModelType.Flamingo: - return cls._flamingo() - elif model_type == ModelType.Llama3_1: - return cls._llama3_1() - elif model_type == ModelType.Llava: - return cls._llava() - else: - raise ValueError(f"Can not find the model recipe for {model_type}") + match model_type: + case ModelType.TextOnly: + return cls._text_only() + case ModelType.Flamingo: + return cls._flamingo() + case ModelType.Llama3_1: + return cls._llama3_1() + case ModelType.Llava: + return cls._llava() + case _: + raise ValueError(f"Can not find the model recipe for {model_type}") @dataclass @@ -475,7 +461,7 @@ def build_model(self) -> nn.Module: return recipe.fusion_class(**modules) - def _replace_know_params(self, params): + def _replace_known_params(self, params): patterns = {"QuickGELUActivation()": QuickGELUActivation()} for key, value in params.items(): if isinstance(value, Hashable) and value in patterns: From 672915ac34f34809a168b54960ffb19b488b0245 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 17 Sep 2024 16:43:45 -0700 Subject: [PATCH 27/27] remove extra arg. --- torchchat/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchchat/model.py b/torchchat/model.py index f4f6624f9..a576d5036 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -61,9 +61,9 @@ class MultiModalProjector(nn.Module): def __init__(self, in_channels: int, out_channels: int, act: nn.Module): super().__init__() - self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True) + self.linear_1 = nn.Linear(in_channels, out_channels, bias=True) self.act = act - self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True) + self.linear_2 = nn.Linear(out_channels, out_channels, bias=True) def forward(self, image_features): hidden_states = self.linear_1(image_features)