Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f4bf00b
llava init
Gasoonjia Sep 15, 2024
2cabbe7
2/n llava init
Gasoonjia Sep 15, 2024
353fafe
3/n llava init
Gasoonjia Sep 16, 2024
728fc46
reformat llava
Gasoonjia Sep 16, 2024
215331d
3/n llava
Gasoonjia Sep 16, 2024
23d6504
llava config update
Gasoonjia Sep 16, 2024
22fd2a5
4/n llava init
Gasoonjia Sep 16, 2024
fff8647
unify model construction ppl
Gasoonjia Sep 16, 2024
4b666a7
update transformer config
Gasoonjia Sep 16, 2024
cc8b4d6
update model config for gguf
Gasoonjia Sep 17, 2024
7ec018a
hack PTEModel to have same config hirearchy as Model
Gasoonjia Sep 17, 2024
94e56f1
unify model construction ppl
Gasoonjia Sep 16, 2024
2e3d1dc
Merge branch 'main' into unify-constuct-model
Gasoonjia Sep 17, 2024
cbadc92
merge with unified model contruction pipeline
Gasoonjia Sep 17, 2024
43dfdc7
5/n torchchat init
Gasoonjia Sep 17, 2024
63d76a1
hack PTEModel to support current ppl
Gasoonjia Sep 17, 2024
01bb624
fix a typo
Gasoonjia Sep 17, 2024
319ac86
unify model construction ppl
Gasoonjia Sep 16, 2024
141fea0
rebase and solve comments
Gasoonjia Sep 17, 2024
8cd0936
bring TransformerArgs back to Transformer
Gasoonjia Sep 17, 2024
304fece
rename get_text_transformer_args as text_transformer_args for readibi…
Gasoonjia Sep 17, 2024
1eff939
make text_transformer_args a real attribute
Gasoonjia Sep 17, 2024
a356897
get rid of model.model
Gasoonjia Sep 17, 2024
a190b0f
merge with unified model contruction pipeline
Gasoonjia Sep 17, 2024
cbda879
llava model constuction support
Gasoonjia Sep 17, 2024
6fbb460
1/2 solve cache issue
Gasoonjia Sep 17, 2024
f224da7
solve comments
Gasoonjia Sep 17, 2024
83f8501
Merge branch 'unify-constuct-model' into llava-support
Gasoonjia Sep 17, 2024
128566c
prepare for rebase
Gasoonjia Sep 17, 2024
f3cbd53
merge with main
Gasoonjia Sep 17, 2024
7aab3b4
bring license back
Gasoonjia Sep 17, 2024
7ffec73
solve comments
Gasoonjia Sep 17, 2024
672915a
remove extra arg.
Gasoonjia Sep 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 168 additions & 1 deletion torchchat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from PIL import Image
import requests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure that this is downloaded in install requirements (it probably already is)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they can be removed rn; they should be sth in the 3/n pr haha

import torchvision
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this torchvision import seems unused


from typing import Any, Callable, Dict, Optional, Union
from collections.abc import Hashable

import torch
import torch.nn as nn
Expand All @@ -34,19 +38,147 @@

from torchchat.utils.build_utils import find_multiple, get_precision

from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.models.clip import clip_vision_encoder
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rebase is fun and makes clones


config_path = Path(f"{str(Path(__file__).parent)}/model_params")


class QuickGELUActivation(nn.Module):
"""
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""

def forward(self, input):
return input * torch.sigmoid(1.702 * input)


def identity(**kwargs):
if len(kwargs) != 1:
raise ValueError("Only one argument is expected")
return list(kwargs.values())[0]


@dataclass
class ProjectorArgs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add docstring for each of these dataclasses since they not the usuall Llama classes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I can remove it rn. We can further take it back when we design arg class for different modules.

in_channels: int = 1024
out_channels: int = 4096
activation: nn.Module = nn.GELU()


class MultiModalProjector(nn.Module):
def __init__(self, args: ProjectorArgs):
super().__init__()

self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True)
self.act = args.activation
self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True)

def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states


class ConcateFusion(nn.Module):
def __init__(
self,
encoder: nn.Module,
decoder: nn.Module,
token_embedding_name="tok_embeddings",
mm_proj_in_channels=1024,
mm_proj_out_channels=4096,
mm_proj_activation=nn.GELU(),
):
super().__init__()
self.encoder = encoder
self.decoder = decoder

# esclate the embedding layer outside decoder llava model need to fuse
# the text and image embedding together before passing to decoder.
self.tok_embeddings = getattr(self.decoder, token_embedding_name)

# set the embedding layer in decoder to None to jump the embedding layer over in decoder
self.decoder.__setattr__(token_embedding_name, None)

self.mm_projector = MultiModalProjector(
ProjectorArgs(
in_channels=mm_proj_in_channels,
out_channels=mm_proj_out_channels,
activation=mm_proj_activation,
)
)

def forward(
self,
tokens: Tensor,
*,
post_tokens: Optional[Tensor] = None,
encoder_input: Optional[Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
) -> Tensor:
if encoder_input is not None:
encoder_input = encoder_input.view(1, 1, *encoder_input.shape)
encoder_output = self.encoder(
encoder_input,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
encoder_output = self.encoder(
encoder_input,
)
encoder_output = self.encoder(encoder_input)

encoder_output = self._encoder_feature_select(encoder_output)
else:
encoder_output = None

decoder_input = self._get_decoder_input(
tokens, encoder_output=encoder_output, post_tokens=post_tokens
)

if input_pos is None:
input_pos = torch.arange(
decoder_input.shape[1],
device=decoder_input.device,
dtype=torch.int,
)

return self.decoder(decoder_input, input_pos=input_pos)

def setup_caches(self, batch_size, max_seq_len):
self.decoder.setup_caches(batch_size, max_seq_len)

def _encoder_feature_select(self, encoder_output):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

REturn type

selected_image_feature = encoder_output[1][0].view(
*encoder_output[1][0].shape[2:]
)

selected_image_feature = selected_image_feature[:, 1:]
return selected_image_feature

def _get_decoder_input(
self,
tokens: Tensor,
*,
encoder_output: Optional[Tensor],
post_tokens: Optional[Tensor],
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return type

if encoder_output is None:
assert post_tokens is None
return self.tok_embeddings(tokens)
else:
pre_img_embed = self.tok_embeddings(tokens)
image_embeds = self.mm_projector(encoder_output)
if post_tokens is None:
return torch.cat((pre_img_embed, image_embeds), dim=1)

post_img_embed = self.tok_embeddings(post_tokens)
return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)


class ModelType(Enum):
TextOnly = "text_only"
Llama3_1 = "llama3_1"
Flamingo = "flamingo"
Llava = "llava"


# Type for objects that can generate nn.Module instance
Expand Down Expand Up @@ -100,6 +232,17 @@ def _flamingo(cls):
fusion_class=DeepFusionModel,
)

@classmethod
def _llava(cls):
return cls(
model_type=ModelType.Llava,
modules={
'encoder': clip_vision_encoder,
'decoder': Transformer
},
fusion_class=ConcateFusion,
Comment on lines +224 to +228
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's really cool to see them working together!

)

@classmethod
def get_recipe(cls, model_type):
if model_type == ModelType.TextOnly:
Expand All @@ -108,6 +251,8 @@ def get_recipe(cls, model_type):
return cls._flamingo()
elif model_type == ModelType.Llama3_1:
return cls._llama3_1()
elif model_type == ModelType.Llava:
return cls._llava()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

match model_type:
    case ModelType.TextOnly:
        return cls._text_only()
    case ModelType.Flamingo:
        return cls.flamingo()
...

Copy link
Contributor Author

@Gasoonjia Gasoonjia Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh YEAH we are in 3.10, it is a good timing to switch to match case statement!

else:
raise ValueError(f"Can not find the model recipe for {model_type}")

Expand Down Expand Up @@ -329,7 +474,14 @@ def build_model(self) -> nn.Module:
modules[name] = module_class(**config_args)

return recipe.fusion_class(**modules)


def _replace_know_params(self, params):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _replace_know_params(self, params):
def _replace_known_params(self, params):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stupid grammar issue

patterns = {"QuickGELUActivation()": QuickGELUActivation()}
for key, value in params.items():
if isinstance(value, Hashable) and value in patterns:
params[key] = patterns[value]
return params

@abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError("forward method is not implemented")
Expand Down Expand Up @@ -414,11 +566,26 @@ def reset_caches(self):
self.model.reset_caches()


class LlavaModel(Model):
def forward(
self,
tokens: Tensor,
*,
encoder_input: Optional[Dict[str, Tensor]] = None,
post_tokens: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
) -> Tensor:
return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos)

def setup_caches(self, max_batch_size, max_seq_length):
self.model.setup_caches(max_batch_size, max_seq_length)


MODEL_TYPE_TO_CLASS = {
ModelType.TextOnly: TextOnlyModel,
ModelType.Flamingo: FlamingoModel,
ModelType.Llama3_1: Llama31Model,
ModelType.Llava: LlavaModel,
}


Expand Down
25 changes: 25 additions & 0 deletions torchchat/model_params/llava-1.5.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"model_type": "llava",
"use_tiktoken": true,
"encoder": {
"tile_size": 336,
"patch_size": 14,
"embed_dim": 1024,
"num_layers": 24,
"num_heads": 16,
"out_indices": [
23
],
"output_cls_projection": false,
"max_num_tiles": 1,
"in_channels": 3,
"intermediate_act": "QuickGELUActivation()"
},
"decoder": {
"n_layers": 32,
"n_heads": 32,
"dim": 4096,
"vocab_size": 32064,
"max_seq_length": 768
}
}