This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 353fafe: "3/n llava init"
1 parent: 2cabbe7

2 files changed: +58, -11 lines
torchchat/model.py

Lines changed: 35 additions & 11 deletions
@@ -42,40 +42,64 @@ def identity(**kwargs):
     return list(kwargs.values())[0]
 
 
+class MultiModalProjector(nn.Module):
+    def __init__(self, args: ProjectorArgs):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True)
+        self.act = args.activation
+        self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True)
+
+    def forward(self, image_features):
+        # project vision-encoder features into the decoder's embedding space
+        hidden_states = self.linear_1(image_features)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
+
+
 class ConcateFusion(nn.Module):
-    def __init__(self, encoder: nn.Module, decoder: nn.Module):
+    def __init__(
+        self,
+        encoder: nn.Module,
+        decoder: nn.Module,
+        token_embedding_name: str = "tok_embeddings",
+        mm_proj_in_channels: int = 1024,
+        mm_proj_out_channels: int = 4096,
+        mm_proj_activation: nn.Module = nn.GELU(),
+    ):
         super().__init__()
         self.encoder = encoder
         self.decoder = decoder
 
+        # Lift the token embedding layer out of the decoder: the llava model needs to
+        # fuse the text and image embeddings before they are passed to the decoder.
+        self.tok_embeddings = getattr(self.decoder, token_embedding_name)
+
+        # Set the decoder's embedding layer to None so the decoder skips its own
+        # embedding step and consumes the fused embeddings directly.
+        setattr(self.decoder, token_embedding_name, None)
+
+        self.mm_projector = MultiModalProjector(
+            ProjectorArgs(
+                in_channels=mm_proj_in_channels,
+                out_channels=mm_proj_out_channels,
+                activation=mm_proj_activation,
+            )
+        )
+
     def forward(self,
         tokens: Tensor,
         *,
         post_tokens: Optional[Tensor] = None,
-        mask: Optional[torch.Tensor] = None,
         encoder_input: Optional[Tensor] = None,
         encoder_mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,) -> Tensor:
-        # split prompt from img tag into before img and after img
-        # concate before img, image result and after img into a large prompt
-        # forward that to text transformer
-        # resturn result
-
-        if encoder_input:
+        if encoder_input is not None:
             encoder_output = self.encoder(
                 encoder_input,
             )
+        else:
+            encoder_output = None
+
+        decoder_input = self._get_decoder_input(
+            tokens, encoder_output=encoder_output, post_tokens=post_tokens
+        )
+        return self.decoder(decoder_input)
 
-    def _gen_mm_embedding(self, tokens: Tensor, *, encoder_input: Optional[Tensor], post_tokens: Optional[Tensor]):
-        assert bool(encoder_input) == bool(post_tokens), "encoder_input and post_tokens must be both None or not None"
-        if encoder_input is None:
-            return tokens
+    def _get_decoder_input(self, tokens: Tensor, *, encoder_output: Optional[Tensor], post_tokens: Optional[Tensor]):
+        assert (encoder_output is None) == (post_tokens is None), "encoder_output and post_tokens must be both None or not None"
+        if encoder_output is None:
+            return self.tok_embeddings(tokens)
+        else:
+            # tokens: prompt before the image tag; post_tokens: prompt after the image tag.
+            pre_img_embed = self.tok_embeddings(tokens)
+            image_embeds = self.mm_projector(encoder_output)
+            post_img_embed = self.tok_embeddings(post_tokens)
+            return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
 
 
-
-
 class ModelType(Enum):
     TextOnly = "text_only"
     Llama3_1 = "llama3_1"
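
For context, the fusion path introduced above amounts to: embed the prompt tokens that come before the image tag, project the vision-encoder output into the decoder's embedding space, embed the tokens that come after the image tag, and concatenate everything along the sequence dimension. Below is a minimal standalone sketch of that data flow with dummy shapes and stand-in modules, not the torchchat classes themselves; the vocabulary size and the 576-patch count (336 tile size / 14 patch size, squared) are illustrative assumptions.

    import torch
    import torch.nn as nn

    # Dummy dimensions, loosely matching the defaults above (1024-d vision
    # features projected into a 4096-d decoder embedding space).
    vocab_size, embed_dim, vision_dim = 32000, 4096, 1024
    tok_embeddings = nn.Embedding(vocab_size, embed_dim)
    mm_projector = nn.Sequential(  # stand-in for MultiModalProjector
        nn.Linear(vision_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim)
    )

    tokens = torch.randint(0, vocab_size, (1, 8))       # prompt before the image tag
    post_tokens = torch.randint(0, vocab_size, (1, 4))  # prompt after the image tag
    encoder_output = torch.randn(1, 576, vision_dim)    # vision-encoder patch features

    decoder_input = torch.cat(
        (
            tok_embeddings(tokens),        # (1, 8, 4096)
            mm_projector(encoder_output),  # (1, 576, 4096)
            tok_embeddings(post_tokens),   # (1, 4, 4096)
        ),
        dim=1,
    )
    print(decoder_input.shape)  # (1, 588, 4096): fused sequence handed to the decoder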
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+@dataclass
+class VisionArgs:
+    tile_size: int = 336
+    patch_size: int = 14
+    embed_dim: int = 1024
+    num_layers: int = 24
+    num_heads: int = 16
+    out_indices: List[int] = field(default_factory=list)
+    output_cls_projection: bool = False
+    max_num_tiles: int = 1
+    in_channels: int = 3
+    intermediate_act: nn.Module = QuickGELUActivation()
+
+    def __post_init__(self):
+        if not self.out_indices:
+            self.out_indices = [self.num_layers - 1]
+
+
+@dataclass
+class ProjectorArgs:
+    in_channels: int = 1024
+    out_channels: int = 4096
+    activation: nn.Module = nn.GELU()
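
A note on the defaults: ProjectorArgs.in_channels matches VisionArgs.embed_dim (1024), so the projector maps the vision features into the 4096-wide decoder embedding space, and VisionArgs falls back to the last vision layer when no out_indices are given. A trimmed, hypothetical copy of that __post_init__ behaviour (VisionArgsDemo is not part of the repo) looks like:

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class VisionArgsDemo:  # trimmed, illustrative copy of VisionArgs
        num_layers: int = 24
        out_indices: List[int] = field(default_factory=list)

        def __post_init__(self):
            # when no output layers are requested, take features from the final layer
            if not self.out_indices:
                self.out_indices = [self.num_layers - 1]


    print(VisionArgsDemo().out_indices)                      # [23]
    print(VisionArgsDemo(out_indices=[11, 23]).out_indices)  # [11, 23]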
