
Commit 3b162e2

[llava 2/n] Support Llava Model Construction (#1155)
* llava init
* 2/n llava init
* 3/n llava init
* reformat llava
* 3/n llava
* llava config update
* 4/n llava init
* unify model construction ppl
* update transformer config
* update model config for gguf
* hack PTEModel to have the same config hierarchy as Model
* unify model construction ppl
* 5/n torchchat init
* hack PTEModel to support current ppl
* fix a typo
* unify model construction ppl
* bring TransformerArgs back to Transformer
* rename get_text_transformer_args as text_transformer_args for readability
* make text_transformer_args a real attribute
* get rid of model.model
* llava model construction support
* 1/2 solve cache issue
* solve comments
* prepare for rebase
* bring license back
* solve comments
* remove extra arg
1 parent f730056 commit 3b162e2

File tree

2 files changed: +187 -9 lines changed


torchchat/model.py

Lines changed: 162 additions & 9 deletions
@@ -12,7 +12,10 @@
 from enum import Enum
 from pathlib import Path

+import torchvision
+
 from typing import Any, Callable, Dict, Optional, Union
+from collections.abc import Hashable

 import torch
 import torch.nn as nn
@@ -31,22 +34,136 @@
 from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
 from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
 from torchtune.modules.model_fusion import DeepFusionModel
+from torchtune.models.clip import clip_vision_encoder

 from torchchat.utils.build_utils import find_multiple, get_precision

 config_path = Path(f"{str(Path(__file__).parent)}/model_params")


+class QuickGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def forward(self, input):
+        return input * torch.sigmoid(1.702 * input)
+
+
 def identity(**kwargs):
     if len(kwargs) != 1:
         raise ValueError("Only one argument is expected")
     return list(kwargs.values())[0]


+
+class MultiModalProjector(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, out_channels, bias=True)
+        self.act = act
+        self.linear_2 = nn.Linear(out_channels, out_channels, bias=True)
+
+    def forward(self, image_features):
+        hidden_states = self.linear_1(image_features)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
+
+
+class ConcateFusion(nn.Module):
+    def __init__(
+        self,
+        encoder: nn.Module,
+        decoder: nn.Module,
+        token_embedding_name="tok_embeddings",
+        mm_proj_in_channels=1024,
+        mm_proj_out_channels=4096,
+        mm_proj_activation=nn.GELU(),
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+
+        # Escalate the embedding layer out of the decoder: the llava model needs to fuse
+        # the text and image embeddings together before passing them to the decoder.
+        self.tok_embeddings = getattr(self.decoder, token_embedding_name)
+
+        # Set the embedding layer in the decoder to None so the decoder skips its embedding step.
+        self.decoder.__setattr__(token_embedding_name, None)
+
+        self.mm_projector = MultiModalProjector(
+            in_channels=mm_proj_in_channels,
+            out_channels=mm_proj_out_channels,
+            act=mm_proj_activation,
+        )
+
+    def forward(
+        self,
+        tokens: Tensor,
+        *,
+        post_tokens: Optional[Tensor] = None,
+        encoder_input: Optional[Tensor] = None,
+        encoder_mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
+    ) -> Tensor:
+        if encoder_input is not None:
+            encoder_input = encoder_input.view(1, 1, *encoder_input.shape)
+            encoder_output = self.encoder(encoder_input)
+            encoder_output = self._encoder_feature_select(encoder_output)
+        else:
+            encoder_output = None
+
+        decoder_input = self._get_decoder_input(
+            tokens, encoder_output=encoder_output, post_tokens=post_tokens
+        )
+
+        if input_pos is None:
+            input_pos = torch.arange(
+                decoder_input.shape[1],
+                device=decoder_input.device,
+                dtype=torch.int,
+            )
+
+        return self.decoder(decoder_input, input_pos=input_pos)
+
+    def setup_caches(self, batch_size, max_seq_len) -> None:
+        self.decoder.setup_caches(batch_size, max_seq_len)
+
+    def _encoder_feature_select(self, encoder_output) -> Tensor:
+        selected_image_feature = encoder_output[1][0].view(
+            *encoder_output[1][0].shape[2:]
+        )
+
+        selected_image_feature = selected_image_feature[:, 1:]
+        return selected_image_feature
+
+    def _get_decoder_input(
+        self,
+        tokens: Tensor,
+        *,
+        encoder_output: Optional[Tensor],
+        post_tokens: Optional[Tensor],
+    ) -> Tensor:
+        if encoder_output is None:
+            assert post_tokens is None
+            return self.tok_embeddings(tokens)
+        else:
+            pre_img_embed = self.tok_embeddings(tokens)
+            image_embeds = self.mm_projector(encoder_output)
+            if post_tokens is None:
+                return torch.cat((pre_img_embed, image_embeds), dim=1)
+
+            post_img_embed = self.tok_embeddings(post_tokens)
+            return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
+
+
 class ModelType(Enum):
     TextOnly = "text_only"
     Llama3_1 = "llama3_1"
     Flamingo = "flamingo"
+    Llava = "llava"


 # Type for objects that can generate nn.Module instance
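For intuition, the core of ConcateFusion._get_decoder_input is a concatenation along the sequence dimension: text embeddings before the image, image features passed through the projector, then text embeddings after the image. Below is a minimal, self-contained sketch of that data flow with toy dimensions (the real model projects 1024-dim CLIP features into the 4096-dim decoder space, per the config added in this commit); the stand-in modules are illustrative, not the classes from the diff.

import torch
import torch.nn as nn

# Toy stand-ins: the real model uses the decoder's tok_embeddings and MultiModalProjector.
embed_dim, clip_dim = 8, 4                      # 4096 and 1024 in the real config
tok_embeddings = nn.Embedding(100, embed_dim)
mm_projector = nn.Sequential(                   # same shape flow as MultiModalProjector
    nn.Linear(clip_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim)
)

tokens = torch.tensor([[1, 2, 3]])              # prompt tokens before the image
post_tokens = torch.tensor([[4, 5]])            # prompt tokens after the image
image_features = torch.randn(1, 6, clip_dim)    # selected CLIP patch features

pre_img_embed = tok_embeddings(tokens)          # (1, 3, 8)
image_embeds = mm_projector(image_features)     # (1, 6, 8)
post_img_embed = tok_embeddings(post_tokens)    # (1, 2, 8)

# This is what ConcateFusion feeds to the decoder in place of plain token embeddings.
decoder_input = torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
print(decoder_input.shape)                      # torch.Size([1, 11, 8])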
@@ -100,16 +217,30 @@ def _flamingo(cls):
             fusion_class=DeepFusionModel,
         )

+    @classmethod
+    def _llava(cls):
+        return cls(
+            model_type=ModelType.Llava,
+            modules={
+                'encoder': clip_vision_encoder,
+                'decoder': Transformer
+            },
+            fusion_class=ConcateFusion,
+        )
+
     @classmethod
     def get_recipe(cls, model_type):
-        if model_type == ModelType.TextOnly:
-            return cls._text_only()
-        elif model_type == ModelType.Flamingo:
-            return cls._flamingo()
-        elif model_type == ModelType.Llama3_1:
-            return cls._llama3_1()
-        else:
-            raise ValueError(f"Can not find the model recipe for {model_type}")
+        match model_type:
+            case ModelType.TextOnly:
+                return cls._text_only()
+            case ModelType.Flamingo:
+                return cls._flamingo()
+            case ModelType.Llama3_1:
+                return cls._llama3_1()
+            case ModelType.Llava:
+                return cls._llava()
+            case _:
+                raise ValueError(f"Can not find the model recipe for {model_type}")


 @dataclass
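With the new recipe and match arm in place, LLaVA construction goes through the same lookup as the other model types. A small sketch, assuming the class holding these classmethods is ModelRecipe (its name is not visible in this hunk) and that it is importable from torchchat.model; the torchchat builder code that normally performs this lookup is unchanged by this commit.

from torchchat.model import ModelRecipe, ModelType

recipe = ModelRecipe.get_recipe(ModelType.Llava)
# recipe.modules maps "encoder" -> clip_vision_encoder and "decoder" -> Transformer,
# and recipe.fusion_class is ConcateFusion, so build_model() constructs each module
# from its config section and wraps the pair as ConcateFusion(encoder=..., decoder=...).
print(recipe.fusion_class.__name__)  # ConcateFusion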
@@ -329,7 +460,14 @@ def build_model(self) -> nn.Module:
             modules[name] = module_class(**config_args)

         return recipe.fusion_class(**modules)
-
+
+    def _replace_known_params(self, params):
+        patterns = {"QuickGELUActivation()": QuickGELUActivation()}
+        for key, value in params.items():
+            if isinstance(value, Hashable) and value in patterns:
+                params[key] = patterns[value]
+        return params
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         raise NotImplementedError("forward method is not implemented")
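_replace_known_params is what lets a JSON config spell an activation as a plain string. A minimal illustration of the substitution it performs, using the "intermediate_act" value from the new LLaVA params file below; it assumes QuickGELUActivation is importable from torchchat.model after this commit and simply mirrors the method body.

from collections.abc import Hashable

from torchchat.model import QuickGELUActivation  # added in this commit

params = {"embed_dim": 1024, "intermediate_act": "QuickGELUActivation()"}

# Known string markers are swapped for real module instances before the
# kwargs reach the encoder builder.
patterns = {"QuickGELUActivation()": QuickGELUActivation()}
for key, value in params.items():
    if isinstance(value, Hashable) and value in patterns:
        params[key] = patterns[value]

print(type(params["intermediate_act"]))  # <class 'torchchat.model.QuickGELUActivation'>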
@@ -414,11 +552,26 @@ def reset_caches(self):
         self.model.reset_caches()


+class LlavaModel(Model):
+    def forward(
+        self,
+        tokens: Tensor,
+        *,
+        encoder_input: Optional[Dict[str, Tensor]] = None,
+        post_tokens: Optional[Tensor] = None,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
+        return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos)
+
+    def setup_caches(self, max_batch_size, max_seq_length):
+        self.model.setup_caches(max_batch_size, max_seq_length)
+

 MODEL_TYPE_TO_CLASS = {
     ModelType.TextOnly: TextOnlyModel,
     ModelType.Flamingo: FlamingoModel,
     ModelType.Llama3_1: Llama31Model,
+    ModelType.Llava: LlavaModel,
 }
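The registry entry routes ModelType.Llava to the new wrapper class; LlavaModel itself just forwards to the fused module, so its call signature matches ConcateFusion.forward (pre-image tokens, an optional image tensor, optional post-image tokens). A quick sketch of the dispatch, assuming torchchat.model exports the names shown above.

from torchchat.model import MODEL_TYPE_TO_CLASS, ModelType

llava_class = MODEL_TYPE_TO_CLASS[ModelType.Llava]
print(llava_class.__name__)  # LlavaModel

# Once built, the intended call pattern mirrors ConcateFusion.forward:
#   logits = model(tokens, encoder_input=image_tensor, post_tokens=post_tokens)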
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+{
+    "model_type": "llava",
+    "use_tiktoken": true,
+    "encoder": {
+        "tile_size": 336,
+        "patch_size": 14,
+        "embed_dim": 1024,
+        "num_layers": 24,
+        "num_heads": 16,
+        "out_indices": [
+            23
+        ],
+        "output_cls_projection": false,
+        "max_num_tiles": 1,
+        "in_channels": 3,
+        "intermediate_act": "QuickGELUActivation()"
+    },
+    "decoder": {
+        "n_layers": 32,
+        "n_heads": 32,
+        "dim": 4096,
+        "vocab_size": 32064,
+        "max_seq_length": 768
+    }
+}
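Per the _llava recipe, the "encoder" block above is consumed as keyword arguments for torchtune's clip_vision_encoder (after _replace_known_params turns the "intermediate_act" string into a module), and the "decoder" block becomes the transformer configuration for torchchat's Transformer. A sketch of the encoder half, assuming the JSON keys pass through build_model to clip_vision_encoder unchanged, as the recipe's module mapping implies.

from torchtune.models.clip import clip_vision_encoder

from torchchat.model import QuickGELUActivation

# Kwargs mirror the "encoder" section of the JSON; "QuickGELUActivation()" has
# already been replaced by a module instance via _replace_known_params.
encoder = clip_vision_encoder(
    tile_size=336,
    patch_size=14,
    embed_dim=1024,
    num_layers=24,
    num_heads=16,
    out_indices=[23],
    output_cls_projection=False,
    max_num_tiles=1,
    in_channels=3,
    intermediate_act=QuickGELUActivation(),
)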
