Merged

Changes from all commits
33 commits
f4bf00b
llava init
Gasoonjia Sep 15, 2024
2cabbe7
2/n llava init
Gasoonjia Sep 15, 2024
353fafe
3/n llava init
Gasoonjia Sep 16, 2024
728fc46
reformat llava
Gasoonjia Sep 16, 2024
215331d
3/n llava
Gasoonjia Sep 16, 2024
23d6504
llava config update
Gasoonjia Sep 16, 2024
22fd2a5
4/n llava init
Gasoonjia Sep 16, 2024
fff8647
unify model construction ppl
Gasoonjia Sep 16, 2024
4b666a7
update transformer config
Gasoonjia Sep 16, 2024
cc8b4d6
update model config for gguf
Gasoonjia Sep 17, 2024
7ec018a
hack PTEModel to have same config hirearchy as Model
Gasoonjia Sep 17, 2024
94e56f1
unify model construction ppl
Gasoonjia Sep 16, 2024
2e3d1dc
Merge branch 'main' into unify-constuct-model
Gasoonjia Sep 17, 2024
cbadc92
merge with unified model contruction pipeline
Gasoonjia Sep 17, 2024
43dfdc7
5/n torchchat init
Gasoonjia Sep 17, 2024
63d76a1
hack PTEModel to support current ppl
Gasoonjia Sep 17, 2024
01bb624
fix a typo
Gasoonjia Sep 17, 2024
319ac86
unify model construction ppl
Gasoonjia Sep 16, 2024
141fea0
rebase and solve comments
Gasoonjia Sep 17, 2024
8cd0936
bring TransformerArgs back to Transformer
Gasoonjia Sep 17, 2024
304fece
rename get_text_transformer_args as text_transformer_args for readibi…
Gasoonjia Sep 17, 2024
1eff939
make text_transformer_args a real attribute
Gasoonjia Sep 17, 2024
a356897
get rid of model.model
Gasoonjia Sep 17, 2024
a190b0f
merge with unified model contruction pipeline
Gasoonjia Sep 17, 2024
cbda879
llava model constuction support
Gasoonjia Sep 17, 2024
6fbb460
1/2 solve cache issue
Gasoonjia Sep 17, 2024
f224da7
solve comments
Gasoonjia Sep 17, 2024
83f8501
Merge branch 'unify-constuct-model' into llava-support
Gasoonjia Sep 17, 2024
128566c
prepare for rebase
Gasoonjia Sep 17, 2024
f3cbd53
merge with main
Gasoonjia Sep 17, 2024
7aab3b4
bring license back
Gasoonjia Sep 17, 2024
7ffec73
solve comments
Gasoonjia Sep 17, 2024
672915a
remove extra arg.
Gasoonjia Sep 17, 2024
171 changes: 162 additions & 9 deletions torchchat/model.py
@@ -12,7 +12,10 @@
from enum import Enum
from pathlib import Path

import torchvision
Contributor:
this torchvision import seems unused
from typing import Any, Callable, Dict, Optional, Union
from collections.abc import Hashable

import torch
import torch.nn as nn
@@ -31,22 +34,136 @@
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.models.clip import clip_vision_encoder

from torchchat.utils.build_utils import find_multiple, get_precision

config_path = Path(f"{str(Path(__file__).parent)}/model_params")


class QuickGELUActivation(nn.Module):
    """
    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input):
        return input * torch.sigmoid(1.702 * input)
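
As a quick, hedged sanity check (not part of the PR): the sigmoid form tracks exact GELU to within roughly 0.02 at its worst:

import torch

x = torch.linspace(-3, 3, 7)
exact = torch.nn.functional.gelu(x)
quick = x * torch.sigmoid(1.702 * x)
print((exact - quick).abs().max())  # ≈ 0.02, worst near |x| ≈ 2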


def identity(**kwargs):
    if len(kwargs) != 1:
        raise ValueError("Only one argument is expected")
    return list(kwargs.values())[0]


class MultiModalProjector(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, out_channels, bias=True)
        self.act = act
        self.linear_2 = nn.Linear(out_channels, out_channels, bias=True)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class ConcateFusion(nn.Module):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        token_embedding_name="tok_embeddings",
        mm_proj_in_channels=1024,
        mm_proj_out_channels=4096,
        mm_proj_activation=nn.GELU(),
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        # Hoist the embedding layer out of the decoder: the llava model needs to
        # fuse the text and image embeddings before they reach the decoder.
        self.tok_embeddings = getattr(self.decoder, token_embedding_name)

        # Set the decoder's embedding layer to None so the decoder skips
        # embedding and consumes the fused embeddings directly.
        self.decoder.__setattr__(token_embedding_name, None)

        self.mm_projector = MultiModalProjector(
            in_channels=mm_proj_in_channels,
            out_channels=mm_proj_out_channels,
            act=mm_proj_activation,
        )

    def forward(
        self,
        tokens: Tensor,
        *,
        post_tokens: Optional[Tensor] = None,
        encoder_input: Optional[Tensor] = None,
        encoder_mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> Tensor:
        if encoder_input is not None:
            # add the batch and tile dims the clip encoder expects
            encoder_input = encoder_input.view(1, 1, *encoder_input.shape)
            encoder_output = self.encoder(encoder_input)
            encoder_output = self._encoder_feature_select(encoder_output)
        else:
            encoder_output = None

        decoder_input = self._get_decoder_input(
            tokens, encoder_output=encoder_output, post_tokens=post_tokens
        )

        if input_pos is None:
            input_pos = torch.arange(
                decoder_input.shape[1],
                device=decoder_input.device,
                dtype=torch.int,
            )

        return self.decoder(decoder_input, input_pos=input_pos)

    def setup_caches(self, batch_size, max_seq_len) -> None:
        self.decoder.setup_caches(batch_size, max_seq_len)

    def _encoder_feature_select(self, encoder_output) -> Tensor:
        # take the hidden states captured at out_indices and collapse the
        # leading batch/tile dims added in forward
        selected_image_feature = encoder_output[1][0].view(
            *encoder_output[1][0].shape[2:]
        )

        # drop the CLS token; llava conditions on the patch embeddings only
        selected_image_feature = selected_image_feature[:, 1:]
        return selected_image_feature

    def _get_decoder_input(
        self,
        tokens: Tensor,
        *,
        encoder_output: Optional[Tensor],
        post_tokens: Optional[Tensor],
    ) -> Tensor:
        if encoder_output is None:
            assert post_tokens is None
            return self.tok_embeddings(tokens)
        else:
            pre_img_embed = self.tok_embeddings(tokens)
            image_embeds = self.mm_projector(encoder_output)
            if post_tokens is None:
                return torch.cat((pre_img_embed, image_embeds), dim=1)

            post_img_embed = self.tok_embeddings(post_tokens)
            return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
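
For readers skimming the diff, a minimal standalone sketch of the layout `_get_decoder_input` produces when both an image and post-image tokens are present (toy dimensions, not the real llava sizes):

import torch
import torch.nn as nn

tok_embeddings = nn.Embedding(100, 16)  # stands in for the decoder's tok_embeddings
mm_projector = nn.Sequential(           # stands in for MultiModalProjector
    nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 16)
)

pre_tokens = torch.randint(0, 100, (1, 5))   # text ids before the image
image_feats = torch.randn(1, 10, 8)          # selected encoder features: 10 patches
post_tokens = torch.randint(0, 100, (1, 3))  # text ids after the image

decoder_input = torch.cat(
    (tok_embeddings(pre_tokens), mm_projector(image_feats), tok_embeddings(post_tokens)),
    dim=1,
)
print(decoder_input.shape)  # torch.Size([1, 18, 16]): 5 text + 10 image + 3 text positions

The real module draws `tok_embeddings` from the decoder it wraps and projects 1024-dim clip features up to the decoder's 4096-dim embedding space.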


class ModelType(Enum):
    TextOnly = "text_only"
    Llama3_1 = "llama3_1"
    Flamingo = "flamingo"
    Llava = "llava"


# Type for objects that can generate nn.Module instances
@@ -100,16 +217,30 @@ def _flamingo(cls):
        fusion_class=DeepFusionModel,
    )

    @classmethod
    def _llava(cls):
        return cls(
            model_type=ModelType.Llava,
            modules={
                'encoder': clip_vision_encoder,
                'decoder': Transformer
            },
            fusion_class=ConcateFusion,
Comment on lines +224 to +228
Contributor:
It's really cool to see them working together!
        )

    @classmethod
    def get_recipe(cls, model_type):
        if model_type == ModelType.TextOnly:
            return cls._text_only()
        elif model_type == ModelType.Flamingo:
            return cls._flamingo()
        elif model_type == ModelType.Llama3_1:
            return cls._llama3_1()
        else:
            raise ValueError(f"Can not find the model recipe for {model_type}")
        match model_type:
            case ModelType.TextOnly:
                return cls._text_only()
            case ModelType.Flamingo:
                return cls._flamingo()
            case ModelType.Llama3_1:
                return cls._llama3_1()
            case ModelType.Llava:
                return cls._llava()
            case _:
                raise ValueError(f"Can not find the model recipe for {model_type}")
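
For orientation, fetching the new recipe goes through the match arm above; a sketch, assuming the enclosing class is named `ModelRecipe` (its definition sits outside this hunk):

recipe = ModelRecipe.get_recipe(ModelType.Llava)
# recipe.modules      -> {'encoder': clip_vision_encoder, 'decoder': Transformer}
# recipe.fusion_class -> ConcateFusion, which stitches the two together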


@dataclass
@@ -329,7 +460,14 @@ def build_model(self) -> nn.Module:
            modules[name] = module_class(**config_args)

        return recipe.fusion_class(**modules)


    def _replace_known_params(self, params):
        patterns = {"QuickGELUActivation()": QuickGELUActivation()}
        for key, value in params.items():
            if isinstance(value, Hashable) and value in patterns:
                params[key] = patterns[value]
        return params

    @abstractmethod
    def forward(self, *args, **kwargs):
        raise NotImplementedError("forward method is not implemented")
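
This hook is what lets the JSON file below spell an activation as a plain string. The same substitution, inlined as a sketch over a hypothetical two-key params dict:

from collections.abc import Hashable

patterns = {"QuickGELUActivation()": QuickGELUActivation()}
params = {"intermediate_act": "QuickGELUActivation()", "num_layers": 24}
params = {
    k: patterns[v] if isinstance(v, Hashable) and v in patterns else v
    for k, v in params.items()
}
assert isinstance(params["intermediate_act"], QuickGELUActivation)  # string swapped for a live module
assert params["num_layers"] == 24                                   # everything else passes through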
@@ -414,11 +552,26 @@ def reset_caches(self):
        self.model.reset_caches()


class LlavaModel(Model):
    def forward(
        self,
        tokens: Tensor,
        *,
        encoder_input: Optional[Dict[str, Tensor]] = None,
        post_tokens: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos)

    def setup_caches(self, max_batch_size, max_seq_length):
        self.model.setup_caches(max_batch_size, max_seq_length)


MODEL_TYPE_TO_CLASS = {
    ModelType.TextOnly: TextOnlyModel,
    ModelType.Flamingo: FlamingoModel,
    ModelType.Llama3_1: Llama31Model,
    ModelType.Llava: LlavaModel,
}


25 changes: 25 additions & 0 deletions torchchat/model_params/llava-1.5.json
@@ -0,0 +1,25 @@
{
    "model_type": "llava",
    "use_tiktoken": true,
    "encoder": {
        "tile_size": 336,
        "patch_size": 14,
        "embed_dim": 1024,
        "num_layers": 24,
        "num_heads": 16,
        "out_indices": [
            23
        ],
        "output_cls_projection": false,
        "max_num_tiles": 1,
        "in_channels": 3,
        "intermediate_act": "QuickGELUActivation()"
    },
    "decoder": {
        "n_layers": 32,
        "n_heads": 32,
        "dim": 4096,
        "vocab_size": 32064,
        "max_seq_length": 768
    }
}
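
For cross-checking against the code above, a small hedged sketch of how the two blocks line up (the real pipeline routes this file through the unified model-construction path; reading it directly here is just for illustration):

import json
from pathlib import Path

cfg = json.loads(Path("torchchat/model_params/llava-1.5.json").read_text())

# "encoder" keys are kwargs for torchtune's clip_vision_encoder;
# "decoder" keys configure torchchat's text Transformer.
assert cfg["encoder"]["embed_dim"] == 1024  # matches ConcateFusion's mm_proj_in_channels default
assert cfg["decoder"]["dim"] == 4096        # matches ConcateFusion's mm_proj_out_channels default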