This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 22fd2a5

4/n llava init
1 parent 23d6504 commit 22fd2a5


torchchat/model.py

Lines changed: 83 additions & 4 deletions

@@ -6,8 +6,12 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
+from PIL import Image
+import requests
+import torchvision
 
 from typing import Any, Callable, Dict, Optional, Union
+from collections.abc import Hashable
 
 import torch
 import torch.nn as nn

@@ -48,6 +52,13 @@ def identity(**kwargs):
     return list(kwargs.values())[0]
 
 
+@dataclass
+class ProjectorArgs:
+    in_channels: int = 1024
+    out_channels: int = 4096
+    activation: nn.Module = nn.GELU()
+
+
 class MultiModalProjector(nn.Module):
     def __init__(self, args: ProjectorArgs):
         super().__init__()
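
The hunk above only introduces ProjectorArgs; the body of MultiModalProjector sits outside the diff context. For orientation, here is a self-contained sketch of how a LLaVA-1.5-style projector is commonly wired from these arguments (two linear layers with the configured activation in between). The class name MultiModalProjectorSketch and the exact layer layout are assumptions for illustration, not the file's actual implementation.

import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class ProjectorArgs:
    in_channels: int = 1024    # CLIP ViT-L feature width
    out_channels: int = 4096   # LLM hidden width
    activation: nn.Module = nn.GELU()

class MultiModalProjectorSketch(nn.Module):
    """Hypothetical projector: maps per-patch CLIP features to the LLM embedding width."""
    def __init__(self, args: ProjectorArgs):
        super().__init__()
        self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True)
        self.act = args.activation
        self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        return self.linear_2(self.act(self.linear_1(image_features)))

# e.g. 576 patch tokens of width 1024 -> 576 tokens of width 4096
proj = MultiModalProjectorSketch(ProjectorArgs())
print(proj(torch.randn(1, 576, 1024)).shape)  # torch.Size([1, 576, 4096])
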

@@ -105,13 +116,22 @@ def forward(
             encoder_output = self.encoder(
                 encoder_input,
             )
+            encoder_output = self._encoder_feature_select(encoder_output)
         else:
             encoder_output = None
 
         decoder_input = self._get_decoder_input(
             tokens, encoder_output=encoder_output, post_tokens=post_tokens
         )
-        return self.decoder(decoder_input)
+        return self.decoder(decoder_input, input_pos=input_pos)
+
+    def _encoder_feature_select(self, encoder_output):
+        selected_image_feature = encoder_output[1][0].view(
+            *encoder_output[1][0].shape[2:]
+        )
+
+        selected_image_feature = selected_image_feature[:, 1:]
+        return selected_image_feature
 
     def _get_decoder_input(
         self,
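
The new _encoder_feature_select reshapes the encoder output and then drops the first token with the [:, 1:] slice. Assuming the vision tower is a CLIP ViT-L/14 at 336 px (one class token plus 24 x 24 patch tokens of width 1024, so 577 tokens per image), the effect of that slice can be reproduced with a dummy tensor; the exact structure indexed by encoder_output[1][0] is specific to this encoder and not shown here.

import torch

# Toy stand-in for CLIP-style hidden states: 1 image, 577 tokens (1 CLS + 576 patches), width 1024.
hidden_states = torch.randn(1, 577, 1024)

# Dropping index 0 removes the class token, keeping only patch features
# for the projector, which is what the [:, 1:] slice above does.
patch_features = hidden_states[:, 1:]
print(patch_features.shape)  # torch.Size([1, 576, 1024])
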

@@ -197,8 +217,8 @@ def _llava(cls):
         return cls(
             model_type=ModelType.Llava,
             modules={
-                'te': flamingo_vision_encoder,
-                'decoder': llama3_1_builder
+                'encoder': clip_vision_encoder,
+                'decoder': Transformer
             },
             fusion_class=DeepFusionModel,
         )
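
The rename from 'te' to 'encoder' matters because the recipe's module names are the handles the fusion class is built with. A toy illustration of that pattern follows; ToyRecipe and ToyFusion are stand-ins for illustration only, not torchchat's ModelRecipe or torchtune's DeepFusionModel.

from dataclasses import dataclass
from typing import Callable, Dict
import torch.nn as nn

class ToyFusion(nn.Module):
    """Stand-in fusion module: expects 'encoder' and 'decoder' keyword arguments."""
    def __init__(self, *, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

@dataclass
class ToyRecipe:
    modules: Dict[str, Callable[[], nn.Module]]
    fusion_class: type

    def build(self) -> nn.Module:
        # The dict keys must match the fusion class's keyword arguments.
        return self.fusion_class(**{name: ctor() for name, ctor in self.modules.items()})

recipe = ToyRecipe(
    modules={"encoder": lambda: nn.Identity(), "decoder": lambda: nn.Identity()},
    fusion_class=ToyFusion,
)
print(recipe.build())
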

@@ -414,7 +434,7 @@ def build_model(self) -> nn.Module:
     def _replace_know_params(self, params):
         patterns = {"QuickGELUActivation()": QuickGELUActivation(), "False": False, "True": True}
         for key, value in params.items():
-            if value in patterns:
+            if isinstance(value, Hashable) and value in patterns:
                 params[key] = patterns[value]
         return params
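
The isinstance(value, Hashable) guard is needed because values loaded from a model-params JSON can be lists or dicts, and testing an unhashable value with `in` against a dict raises TypeError. A minimal reproduction with a hypothetical params dict:

from collections.abc import Hashable

patterns = {"False": False, "True": True}
params = {"use_tiktoken": "True", "hidden_sizes": [1024, 4096]}  # hypothetical params

for key, value in params.items():
    # Without the Hashable check, `[1024, 4096] in patterns` raises
    # TypeError: unhashable type: 'list'. The guard simply skips such values.
    if isinstance(value, Hashable) and value in patterns:
        params[key] = patterns[value]

print(params)  # {'use_tiktoken': True, 'hidden_sizes': [1024, 4096]}
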


@@ -496,10 +516,26 @@ def reset_caches(self):
         self.model.reset_caches()
 
 
+class LlavaModel(Model):
+    def forward(
+        self,
+        tokens: Tensor,
+        *,
+        encoder_input: Optional[Dict[str, Tensor]] = None,
+        post_tokens: Optional[Tensor] = None,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
+        return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos)
+
+    def setup_caches(self, max_batch_size, max_seq_length):
+        self.model.setup_caches(max_batch_size, max_seq_length)
+
+
 MODEL_TYPE_TO_CLASS = {
     ModelType.TextOnly: TextOnlyModel,
     ModelType.Flamingo: FlamingoModel,
     ModelType.Llama3_1: Llama31Model,
+    ModelType.Llava: LlavaModel,
 }
 
 class Transformer(nn.Module):
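
Registering LlavaModel under ModelType.Llava lets a config's model_type select the wrapper that knows which multimodal keyword arguments (encoder_input, post_tokens, input_pos) to forward to self.model. A toy version of that dispatch, with stand-in classes rather than torchchat's real ones:

from enum import Enum

class ModelType(Enum):
    TextOnly = "text_only"
    Llava = "llava"

class TextOnlyModel:
    pass

class LlavaModel:
    pass

MODEL_TYPE_TO_CLASS = {
    ModelType.TextOnly: TextOnlyModel,
    ModelType.Llava: LlavaModel,
}

# The model_type field in a config picks the wrapper class to instantiate.
wrapper_cls = MODEL_TYPE_TO_CLASS[ModelType.Llava]
print(wrapper_cls.__name__)  # LlavaModel
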

@@ -882,3 +918,46 @@ def setup_caches(self, max_batch_size, max_seq_length):
 
 except:
     pass
+
+
+if __name__ == "__main__":
+    def prepare_image(target_h: int, target_w: int) -> torch.Tensor:
+        """Read image into a tensor and resize the image so that it fits in
+        a target_h x target_w canvas.
+
+        Args:
+            image (Image): An Image object.
+            target_h (int): Target height.
+            target_w (int): Target width.
+
+        Returns:
+            torch.Tensor: resized image tensor.
+        """
+        image = Image.open(
+            requests.get(
+                "https://llava-vl.github.io/static/images/view.jpg", stream=True
+            ).raw)
+
+        img = torchvision.transforms.functional.pil_to_tensor(image)
+        # height ratio
+        ratio_h = img.shape[1] / target_h
+        # width ratio
+        ratio_w = img.shape[2] / target_w
+        # resize the image so that it fits in a target_h x target_w canvas
+        ratio = max(ratio_h, ratio_w)
+        output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
+        img = torchvision.transforms.Resize(size=output_size)(img)
+        return img
+
+    pre_tokens = torch.tensor([[    1,   319, 13563,  1546,   263, 12758,  5199,   322,   385, 23116,
+                                21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
+                                  322,  1248,   568,  6089,   304,   278,  5199, 29915, 29879,  5155,
+                                29889,  3148,  1001, 29901, 29871]])
+    img = prepare_image(336, 336)
+    post_tokens = torch.tensor([[29871,    13,   462,  9651,  1724,   526,   278,  2712,   306,   881,
+                                   367,   274,  1300,  2738,  1048,   746,   306,  6493,  1244, 29973,
+                                   319,  1799,  9047, 13566, 29901]])
+
+    llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json")
+
+    llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens)
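
prepare_image divides both sides by the larger of the height and width ratios, so the longer side exactly fills the 336-pixel canvas while the aspect ratio is preserved. A worked example of that arithmetic for a hypothetical 1008 x 1344 (H x W) source image:

# Worked example of the resize arithmetic in prepare_image, for a
# hypothetical 1008x1344 (H x W) image and a 336x336 target canvas.
h, w, target_h, target_w = 1008, 1344, 336, 336
ratio = max(h / target_h, w / target_w)          # max(3.0, 4.0) = 4.0
output_size = (int(h / ratio), int(w / ratio))   # (252, 336)
print(output_size)  # the longer side fills the canvas; aspect ratio is preserved
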
