
Commit cbda879

llava model construction support

1 parent a190b0f

File tree

1 file changed: +3, -45 lines


torchchat/model.py

Lines changed: 3 additions & 45 deletions
@@ -116,7 +116,8 @@ def forward(
         encoder_mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
     ) -> Tensor:
-        if encoder_input:
+        if encoder_input is not None:
+            encoder_input = encoder_input.view(1, 1, *encoder_input.shape)
             encoder_output = self.encoder(
                 encoder_input,
             )
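Why this fix matters: truth-testing a multi-element torch.Tensor raises a RuntimeError, so the old guard "if encoder_input:" would crash on any real image input. The explicit None check avoids that, and the added view() gives the encoder the leading batch and sequence dims it expects. A minimal standalone sketch of both behaviors (the 3x336x336 shape is illustrative, not taken from the model config):

import torch

img = torch.rand(3, 336, 336)  # stand-in for a CHW image tensor

# Old guard: bool() of a multi-element tensor is ambiguous and raises.
try:
    if img:
        pass
except RuntimeError as e:
    print(e)  # Boolean value of Tensor with more than one element is ambiguous

# New guard: explicit None check, then add batch and sequence dims,
# mirroring encoder_input.view(1, 1, *encoder_input.shape) in the diff.
if img is not None:
    img = img.view(1, 1, *img.shape)
print(img.shape)  # torch.Size([1, 1, 3, 336, 336])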
@@ -223,7 +224,7 @@ def _llava(cls):
                 'encoder': clip_vision_encoder,
                 'decoder': Transformer
             },
-            fusion_class=DeepFusionModel,
+            fusion_class=ConcateFusion,
         )

     @classmethod
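The fusion swap is the substance of the commit: torchtune's DeepFusionModel fuses modalities through cross-attending fusion layers in the decoder, while LLaVA simply splices projected image embeddings into the token-embedding sequence, which is what a concatenation-based wrapper like ConcateFusion provides. A rough sketch of that concatenation, with hypothetical tok_embeddings and projection modules standing in for the real CLIP/Llama components (all shapes and sizes here are assumptions, not values read from the recipe):

import torch
import torch.nn as nn

# Hypothetical stand-ins; real dims come from the CLIP encoder and Llama decoder.
tok_embeddings = nn.Embedding(32064, 4096)  # decoder token embeddings
projection = nn.Linear(1024, 4096)          # maps encoder features to decoder dim

pre_tokens = torch.randint(0, 32000, (1, 35))   # prompt tokens before the image
post_tokens = torch.randint(0, 32000, (1, 25))  # prompt tokens after the image
image_feats = torch.rand(1, 576, 1024)          # per-patch encoder output

# LLaVA-style fusion: splice projected image embeddings between the
# pre- and post-image token embeddings along the sequence dimension.
fused = torch.cat(
    [tok_embeddings(pre_tokens), projection(image_feats), tok_embeddings(post_tokens)],
    dim=1,
)
print(fused.shape)  # torch.Size([1, 636, 4096]), fed to the decoder as embeddings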
@@ -968,46 +969,3 @@ def setup_caches(self, max_batch_size, max_seq_length):
 
     except:
         pass
-
-
-if __name__ == "__main__":
-    def prepare_image(target_h: int, target_w: int) -> torch.Tensor:
-        """Read image into a tensor and resize the image so that it fits in
-        a target_h x target_w canvas.
-
-        Args:
-            image (Image): An Image object.
-            target_h (int): Target height.
-            target_w (int): Target width.
-
-        Returns:
-            torch.Tensor: resized image tensor.
-        """
-        image = Image.open(
-            requests.get(
-                "https://llava-vl.github.io/static/images/view.jpg", stream=True
-            ).raw)
-
-        img = torchvision.transforms.functional.pil_to_tensor(image)
-        # height ratio
-        ratio_h = img.shape[1] / target_h
-        # width ratio
-        ratio_w = img.shape[2] / target_w
-        # resize the image so that it fits in a target_h x target_w canvas
-        ratio = max(ratio_h, ratio_w)
-        output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
-        img = torchvision.transforms.Resize(size=output_size)(img)
-        return img
-
-    pre_tokens = torch.tensor([[    1,   319, 13563,  1546,   263, 12758,  5199,   322,   385, 23116,
-                                21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
-                                  322,  1248,   568,  6089,   304,   278,  5199, 29915, 29879,  5155,
-                                29889,  3148,  1001, 29901, 29871]])
-    img = prepare_image(336, 336)
-    post_tokens = torch.tensor([[29871,    13,   462,  9651,  1724,   526,   278,  2712,   306,   881,
-                                   367,   274,  1300,  2738,  1048,   746,   306,  6493,  1244, 29973,
-                                   319,  1799,  9047, 13566, 29901]])
-
-    llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json")
-
-    llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens)
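For reference, the deleted __main__ block doubled as a construction smoke test. The same check can live in a standalone script instead of model.py; a sketch assuming a torchchat checkout where Model.from_params accepts the llava-1.5 param file (the placeholder token IDs and random image below stand in for the full chat prompt and the downloaded-and-resized photo the removed test used):

import torch
from torchchat.model import Model

# Param file path relative to a torchchat checkout (the removed test used an
# absolute path under the author's home directory).
llava_model = Model.from_params("torchchat/model_params/llava-1.5.json")

# Placeholder prompt IDs split around the image; the removed test used a full
# llava-1.5 chat template tokenization here.
pre_tokens = torch.tensor([[1, 319, 13563]])
post_tokens = torch.tensor([[319, 1799, 9047]])

img = torch.rand(3, 336, 336)  # stand-in for the resized CHW image tensor

logits = llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens)
print(logits.shape)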
