|
12 | 12 | from enum import Enum |
13 | 13 | from pathlib import Path |
14 | 14 |
|
| 15 | +import torchvision |
| 16 | + |
15 | 17 | from typing import Any, Callable, Dict, Optional, Union |
| 18 | +from collections.abc import Hashable |
16 | 19 |
|
17 | 20 | import torch |
18 | 21 | import torch.nn as nn |
|
31 | 34 | from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder |
32 | 35 | from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder |
33 | 36 | from torchtune.modules.model_fusion import DeepFusionModel |
| 37 | +from torchtune.models.clip import clip_vision_encoder |
34 | 38 |
|
35 | 39 | from torchchat.utils.build_utils import find_multiple, get_precision |
36 | 40 |
|
37 | 41 | config_path = Path(f"{str(Path(__file__).parent)}/model_params") |
38 | 42 |
|
39 | 43 |
|
| 44 | +class QuickGELUActivation(nn.Module): |
| 45 | + """ |
| 46 | + Applies a GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
| 47 | + """ |
| 48 | + |
| 49 | + def forward(self, input): |
| 50 | + return input * torch.sigmoid(1.702 * input) |
| 51 | + |
| 52 | + |
40 | 53 | def identity(**kwargs): |
41 | 54 | if len(kwargs) != 1: |
42 | 55 | raise ValueError("Only one argument is expected") |
43 | 56 | return list(kwargs.values())[0] |
44 | 57 |
|
45 | 58 |
|
| 59 | + |
| 60 | +class MultiModalProjector(nn.Module): |
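| | + """Two-layer MLP that projects vision-encoder image features into the decoder's embedding space."""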
| 61 | + def __init__(self, in_channels: int, out_channels: int, act: nn.Module): |
| 62 | + super().__init__() |
| 63 | + |
| 64 | + self.linear_1 = nn.Linear(in_channels, out_channels, bias=True) |
| 65 | + self.act = act |
| 66 | + self.linear_2 = nn.Linear(out_channels, out_channels, bias=True) |
| 67 | + |
| 68 | + def forward(self, image_features): |
| 69 | + hidden_states = self.linear_1(image_features) |
| 70 | + hidden_states = self.act(hidden_states) |
| 71 | + hidden_states = self.linear_2(hidden_states) |
| 72 | + return hidden_states |
| 73 | + |
| 74 | + |
| 75 | +class ConcateFusion(nn.Module): |
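| | + """LLaVA-style fusion that concatenates projected image features between pre- and post-image token embeddings before the decoder runs."""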
| 76 | + def __init__( |
| 77 | + self, |
| 78 | + encoder: nn.Module, |
| 79 | + decoder: nn.Module, |
| 80 | + token_embedding_name="tok_embeddings", |
| 81 | + mm_proj_in_channels=1024, |
| 82 | + mm_proj_out_channels=4096, |
| 83 | + mm_proj_activation=nn.GELU(), |
| 84 | + ): |
| 85 | + super().__init__() |
| 86 | + self.encoder = encoder |
| 87 | + self.decoder = decoder |
| 88 | + |
| 89 | + # Hoist the embedding layer out of the decoder: the LLaVA model needs to fuse
| 90 | + # the text and image embeddings together before passing them to the decoder.
| 91 | + self.tok_embeddings = getattr(self.decoder, token_embedding_name) |
| 92 | + |
| 93 | + # Set the decoder's embedding layer to None so the decoder skips its own embedding step.
| 94 | + setattr(self.decoder, token_embedding_name, None)
| 95 | + |
| 96 | + self.mm_projector = MultiModalProjector( |
| 97 | + in_channels=mm_proj_in_channels, |
| 98 | + out_channels=mm_proj_out_channels, |
| 99 | + act=mm_proj_activation, |
| 100 | + ) |
| 101 | + |
| 102 | + def forward( |
| 103 | + self, |
| 104 | + tokens: Tensor, |
| 105 | + *, |
| 106 | + post_tokens: Optional[Tensor] = None, |
| 107 | + encoder_input: Optional[Tensor] = None, |
| 108 | + encoder_mask: Optional[torch.Tensor] = None, |
| 109 | + input_pos: Optional[torch.Tensor] = None, |
| 110 | + ) -> Tensor: |
| 111 | + if encoder_input is not None: |
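| | + # Add the leading batch/tile dims the vision encoder expects for a single image.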
| 112 | + encoder_input = encoder_input.view(1, 1, *encoder_input.shape) |
| 113 | + encoder_output = self.encoder(encoder_input) |
| 114 | + encoder_output = self._encoder_feature_select(encoder_output) |
| 115 | + else: |
| 116 | + encoder_output = None |
| 117 | + |
| 118 | + decoder_input = self._get_decoder_input( |
| 119 | + tokens, encoder_output=encoder_output, post_tokens=post_tokens |
| 120 | + ) |
| 121 | + |
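| | + # Default to sequential positions over the fused (text + image) sequence when none are given.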
| 122 | + if input_pos is None: |
| 123 | + input_pos = torch.arange( |
| 124 | + decoder_input.shape[1], |
| 125 | + device=decoder_input.device, |
| 126 | + dtype=torch.int, |
| 127 | + ) |
| 128 | + |
| 129 | + return self.decoder(decoder_input, input_pos=input_pos) |
| 130 | + |
| 131 | + def setup_caches(self, batch_size, max_seq_len) -> None: |
| 132 | + self.decoder.setup_caches(batch_size, max_seq_len) |
| 133 | + |
| 134 | + def _encoder_feature_select(self, encoder_output) -> Tensor: |
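| | + # Pull the patch features out of the encoder's hidden states and drop the leading class token.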
| 135 | + selected_image_feature = encoder_output[1][0].view( |
| 136 | + *encoder_output[1][0].shape[2:] |
| 137 | + ) |
| 138 | + |
| 139 | + selected_image_feature = selected_image_feature[:, 1:] |
| 140 | + return selected_image_feature |
| 141 | + |
| 142 | + def _get_decoder_input( |
| 143 | + self, |
| 144 | + tokens: Tensor, |
| 145 | + *, |
| 146 | + encoder_output: Optional[Tensor], |
| 147 | + post_tokens: Optional[Tensor], |
| 148 | + ) -> Tensor: |
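| | + # Splice embeddings in prompt order: tokens before the image, projected image features, then tokens after the image.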
| 149 | + if encoder_output is None: |
| 150 | + assert post_tokens is None |
| 151 | + return self.tok_embeddings(tokens) |
| 152 | + else: |
| 153 | + pre_img_embed = self.tok_embeddings(tokens) |
| 154 | + image_embeds = self.mm_projector(encoder_output) |
| 155 | + if post_tokens is None: |
| 156 | + return torch.cat((pre_img_embed, image_embeds), dim=1) |
| 157 | + |
| 158 | + post_img_embed = self.tok_embeddings(post_tokens) |
| 159 | + return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1) |
| 160 | + |
| 161 | + |
46 | 162 | class ModelType(Enum): |
47 | 163 | TextOnly = "text_only" |
48 | 164 | Llama3_1 = "llama3_1" |
49 | 165 | Flamingo = "flamingo" |
| 166 | + Llava = "llava" |
50 | 167 |
|
51 | 168 |
|
52 | 169 | # Type for objects that can generate nn.Module instance |
@@ -100,16 +217,30 @@ def _flamingo(cls): |
100 | 217 | fusion_class=DeepFusionModel, |
101 | 218 | ) |
102 | 219 |
|
| 220 | + @classmethod |
| 221 | + def _llava(cls): |
| 222 | + return cls( |
| 223 | + model_type=ModelType.Llava, |
| 224 | + modules={ |
| 225 | + "encoder": clip_vision_encoder,
| 226 | + "decoder": Transformer,
| 227 | + }, |
| 228 | + fusion_class=ConcateFusion, |
| 229 | + ) |
| 230 | + |
103 | 231 | @classmethod |
104 | 232 | def get_recipe(cls, model_type): |
105 | | - if model_type == ModelType.TextOnly: |
106 | | - return cls._text_only() |
107 | | - elif model_type == ModelType.Flamingo: |
108 | | - return cls._flamingo() |
109 | | - elif model_type == ModelType.Llama3_1: |
110 | | - return cls._llama3_1() |
111 | | - else: |
112 | | - raise ValueError(f"Can not find the model recipe for {model_type}") |
| 233 | + match model_type: |
| 234 | + case ModelType.TextOnly: |
| 235 | + return cls._text_only() |
| 236 | + case ModelType.Flamingo: |
| 237 | + return cls._flamingo() |
| 238 | + case ModelType.Llama3_1: |
| 239 | + return cls._llama3_1() |
| 240 | + case ModelType.Llava: |
| 241 | + return cls._llava() |
| 242 | + case _: |
| 243 | + raise ValueError(f"Cannot find the model recipe for {model_type}")
113 | 244 |
|
114 | 245 |
|
115 | 246 | @dataclass |
@@ -329,7 +460,14 @@ def build_model(self) -> nn.Module: |
329 | 460 | modules[name] = module_class(**config_args) |
330 | 461 |
|
331 | 462 | return recipe.fusion_class(**modules) |
332 | | - |
| 463 | + |
| 464 | + def _replace_known_params(self, params): |
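| | + # Swap known string placeholders from the model-params config for concrete objects, e.g. "QuickGELUActivation()" -> QuickGELUActivation().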
| 465 | + patterns = {"QuickGELUActivation()": QuickGELUActivation()} |
| 466 | + for key, value in params.items(): |
| 467 | + if isinstance(value, Hashable) and value in patterns: |
| 468 | + params[key] = patterns[value] |
| 469 | + return params |
| 470 | + |
333 | 471 | @abstractmethod |
334 | 472 | def forward(self, *args, **kwargs): |
335 | 473 | raise NotImplementedError("forward method is not implemented") |
@@ -414,11 +552,26 @@ def reset_caches(self): |
414 | 552 | self.model.reset_caches() |
415 | 553 |
|
416 | 554 |
|
| 555 | +class LlavaModel(Model): |
| 556 | + def forward( |
| 557 | + self, |
| 558 | + tokens: Tensor, |
| 559 | + *, |
| 560 | + encoder_input: Optional[Dict[str, Tensor]] = None, |
| 561 | + post_tokens: Optional[Tensor] = None, |
| 562 | + input_pos: Optional[Tensor] = None, |
| 563 | + ) -> Tensor: |
| 564 | + return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos) |
| 565 | + |
| 566 | + def setup_caches(self, max_batch_size, max_seq_length): |
| 567 | + self.model.setup_caches(max_batch_size, max_seq_length) |
| 568 | + |
417 | 569 |
|
418 | 570 | MODEL_TYPE_TO_CLASS = { |
419 | 571 | ModelType.TextOnly: TextOnlyModel, |
420 | 572 | ModelType.Flamingo: FlamingoModel, |
421 | 573 | ModelType.Llama3_1: Llama31Model, |
| 574 | + ModelType.Llava: LlavaModel, |
422 | 575 | } |
423 | 576 |
|
424 | 577 |
|
|