Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 215331d

Browse files
committed
3/n llava
1 parent 728fc46 commit 215331d

File tree

2 files changed

+49
-29
lines changed

2 files changed

+49
-29
lines changed

torchchat/model.py

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,20 @@
2828
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
2929
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
3030
from torchtune.modules.model_fusion import DeepFusionModel
31+
from torchtune.models.clip import clip_vision_encoder
3132

3233
config_path = Path(f"{str(Path(__file__).parent)}/model_params")
3334

3435

36+
class QuickGELUActivation(nn.Module):
37+
"""
38+
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
39+
"""
40+
41+
def forward(self, input):
42+
return input * torch.sigmoid(1.702 * input)
43+
44+
3545
def identity(**kwargs):
3646
if len(kwargs) != 1:
3747
raise ValueError("Only one argument is expected")
@@ -99,24 +109,25 @@ def forward(
99109
encoder_output = None
100110

101111
decoder_input = self._get_decoder_input(
102-
tokens, encoder_input=encoder_input, post_tokens=post_tokens
112+
tokens, encoder_output=encoder_output, post_tokens=post_tokens
103113
)
104114
return self.decoder(decoder_input)
105115

106116
def _get_decoder_input(
107117
self,
108118
tokens: Tensor,
109119
*,
110-
encoder_input: Optional[Tensor],
120+
encoder_output: Optional[Tensor],
111121
post_tokens: Optional[Tensor],
112122
):
113-
assert bool(encoder_input) == bool(
123+
assert bool(encoder_output) == bool(
114124
post_tokens
115125
), "encoder_input and post_tokens must be both None or not None"
116-
if encoder_input is None:
126+
if encoder_output is None:
117127
return self.tok_embeddings(tokens)
118128
else:
119129
pre_img_embed = self.tok_embeddings(tokens)
130+
image_embeds = self.mm_projector(encoder_output)
120131
post_img_embed = self.tok_embeddings(post_tokens)
121132
return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
122133

@@ -261,7 +272,7 @@ class ModelArgs:
261272

262273
def __init__(
263274
self,
264-
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
275+
transformer_args: Union[TransformerArgs, Dict[str, Dict[str, Any]]],
265276
model_type: ModelType = ModelType.TextOnly,
266277
) -> None:
267278
self._sanity_check(transformer_args, model_type)
@@ -275,7 +286,7 @@ def __init__(
275286

276287
def _sanity_check(
277288
self,
278-
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
289+
transformer_args: Union[TransformerArgs, Dict[str, Dict[str, Any]]],
279290
model_type: ModelType,
280291
) -> None:
281292
assert isinstance(model_type, ModelType)
@@ -393,12 +404,20 @@ def build_model(self) -> nn.Module:
393404
modules = {}
394405
for name, module_class in recipe.modules.items():
395406
if isinstance(config_args := self.config.transformer_args[name], dict):
407+
config_args = self._replace_know_params(config_args)
396408
modules[name] = module_class(**config_args)
397409
else:
398410
modules[name] = module_class(config_args)
399411

400412
return recipe.fusion_class(**modules)
401413

414+
def _replace_know_params(self, params):
415+
patterns = {"QuickGELUActivation()": QuickGELUActivation(), "False": False, "True": True}
416+
for key, value in params.items():
417+
if value in patterns:
418+
params[key] = patterns[value]
419+
return params
420+
402421
@abstractmethod
403422
def forward(self, *args, **kwargs):
404423
raise NotImplementedError("forward method is not implemented")
Lines changed: 24 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,24 @@
1-
@dataclass
2-
class VisionArgs:
3-
tile_size: int = 336
4-
patch_size: int = 14
5-
embed_dim: int = 1024
6-
num_layers: int = 24
7-
num_heads: int = 16
8-
out_indices: List[int] = field(default_factory=list)
9-
output_cls_projection: bool = False
10-
max_num_tiles: int = 1
11-
in_channels: int = 3
12-
intermediate_act: nn.Module = QuickGELUActivation()
13-
14-
def __post_init__(self):
15-
if not self.out_indices:
16-
self.out_indices = [self.num_layers - 1]
17-
18-
19-
@dataclass
20-
class ProjectorArgs:
21-
in_channels: int = 1024
22-
out_channels: int = 4096
23-
activation: nn.Module = nn.GELU()
1+
{
2+
"model_type": "llava",
3+
"encoder": {
4+
"tile_size": 336,
5+
"patch_size": 14,
6+
"embed_dim": 1024,
7+
"num_layers": 24,
8+
"num_heads": 16,
9+
"out_indices": [
10+
23
11+
],
12+
"output_cls_projection": False,
13+
"max_num_tiles": 1,
14+
"in_channels": 3,
15+
"intermediate_act": QuickGELUActivation()
16+
},
17+
"decoder": {
18+
"n_layers": 32,
19+
"n_heads": 32,
20+
"dim": 4096,
21+
"vocab_size": 32064,
22+
"max_seq_length": 768
23+
}
24+
}

0 commit comments

Comments (0)