6 | 6 | from dataclasses import dataclass |
7 | 7 | from enum import Enum |
8 | 8 | from pathlib import Path |
| 9 | +from PIL import Image |
| 10 | +import requests |
| 11 | +import torchvision |
9 | 12 |
10 | 13 | from typing import Any, Callable, Dict, Optional, Union |
| 14 | +from collections.abc import Hashable |
11 | 15 |
12 | 16 | import torch |
13 | 17 | import torch.nn as nn |
@@ -48,6 +52,13 @@ def identity(**kwargs): |
48 | 52 | return list(kwargs.values())[0] |
49 | 53 |
50 | 54 |
| 55 | +@dataclass |
| 56 | +class ProjectorArgs: |
| 57 | + in_channels: int = 1024 |
| 58 | + out_channels: int = 4096 |
| 59 | + activation: nn.Module = nn.GELU() |
| 60 | + |
| 61 | + |
51 | 62 | class MultiModalProjector(nn.Module): |
52 | 63 | def __init__(self, args: ProjectorArgs): |
53 | 64 | super().__init__() |
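The hunk shows only the head of MultiModalProjector, so as a reading aid here is a sketch of a LLaVA-1.5-style projector body consistent with ProjectorArgs (two linear layers with the configured activation in between); the class and attribute names below are assumptions, not the file's actual code.

class MultiModalProjectorSketch(nn.Module):
    # Assumed shape: CLIP features (in_channels=1024) -> LLM embeddings (out_channels=4096).
    def __init__(self, args: ProjectorArgs):
        super().__init__()
        self.linear_1 = nn.Linear(args.in_channels, args.out_channels, bias=True)
        self.act = args.activation
        self.linear_2 = nn.Linear(args.out_channels, args.out_channels, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # [batch, num_patches, in_channels] -> [batch, num_patches, out_channels]
        return self.linear_2(self.act(self.linear_1(image_features)))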
@@ -105,13 +116,22 @@ def forward( |
105 | 116 | encoder_output = self.encoder( |
106 | 117 | encoder_input, |
107 | 118 | ) |
| 119 | + encoder_output = self._encoder_feature_select(encoder_output) |
108 | 120 | else: |
109 | 121 | encoder_output = None |
110 | 122 |
111 | 123 | decoder_input = self._get_decoder_input( |
112 | 124 | tokens, encoder_output=encoder_output, post_tokens=post_tokens |
113 | 125 | ) |
114 | | - return self.decoder(decoder_input) |
| 126 | + return self.decoder(decoder_input, input_pos=input_pos) |
| 127 | + |
| 128 | + def _encoder_feature_select(self, encoder_output): |
| 129 | + selected_image_feature = encoder_output[1][0].view( |
| 130 | + *encoder_output[1][0].shape[2:] |
| 131 | + ) |
| 132 | + |
| 133 | + selected_image_feature = selected_image_feature[:, 1:] |
| 134 | + return selected_image_feature |
115 | 135 |
116 | 136 | def _get_decoder_input( |
117 | 137 | self, |
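The new _encoder_feature_select helper flattens one of the CLIP encoder's hidden states and then drops the leading CLS token with the [:, 1:] slice, so only patch embeddings reach the projector. A toy shape walk-through, assuming a 336x336 input to a ViT-L/14 encoder (24 x 24 = 576 patches plus one CLS token; the hidden-state layout is hypothetical):

hidden = torch.randn(1, 1, 1, 577, 1024)   # stand-in for encoder_output[1][0]
patches = hidden.view(*hidden.shape[2:])   # -> [1, 577, 1024]
patches = patches[:, 1:]                   # drop the CLS token -> [1, 576, 1024]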
@@ -197,8 +217,8 @@ def _llava(cls): |
197 | 217 | return cls( |
198 | 218 | model_type=ModelType.Llava, |
199 | 219 | modules={ |
200 | | - 'te': flamingo_vision_encoder, |
201 | | - 'decoder': llama3_1_builder |
| 220 | + 'encoder': clip_vision_encoder, |
| 221 | + 'decoder': Transformer |
202 | 222 | }, |
203 | 223 | fusion_class=DeepFusionModel, |
204 | 224 | ) |
@@ -414,7 +434,7 @@ def build_model(self) -> nn.Module: |
414 | 434 | def _replace_know_params(self, params): |
415 | 435 | patterns = {"QuickGELUActivation()": QuickGELUActivation(), "False": False, "True": True} |
416 | 436 | for key, value in params.items(): |
417 | | - if value in patterns: |
| 437 | + if isinstance(value, Hashable) and value in patterns: |
418 | 438 | params[key] = patterns[value] |
419 | 439 | return params |
420 | 440 |
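The isinstance(value, Hashable) guard matters because checking value in patterns hashes value for the dict lookup, so an unhashable param value (for example a list) would raise TypeError. A small illustration with hypothetical param values:

from collections.abc import Hashable

patterns = {"False": False, "True": True}
for value in ["False", [336, 336]]:            # hypothetical param values
    if isinstance(value, Hashable) and value in patterns:
        print(value, "->", patterns[value])    # "False" is replaced by the bool False
    else:
        print(value, "left as-is")             # the list is skipped instead of raising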
@@ -496,10 +516,26 @@ def reset_caches(self): |
496 | 516 | self.model.reset_caches() |
497 | 517 |
498 | 518 |
| 519 | +class LlavaModel(Model): |
| 520 | + def forward( |
| 521 | + self, |
| 522 | + tokens: Tensor, |
| 523 | + *, |
| 524 | + encoder_input: Optional[Dict[str, Tensor]] = None, |
| 525 | + post_tokens: Optional[Tensor] = None, |
| 526 | + input_pos: Optional[Tensor] = None, |
| 527 | + ) -> Tensor: |
| 528 | + return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos) |
| 529 | + |
| 530 | + def setup_caches(self, max_batch_size, max_seq_length): |
| 531 | + self.model.setup_caches(max_batch_size, max_seq_length) |
| 532 | + |
| 533 | + |
499 | 534 | MODEL_TYPE_TO_CLASS = { |
500 | 535 | ModelType.TextOnly: TextOnlyModel, |
501 | 536 | ModelType.Flamingo: FlamingoModel, |
502 | 537 | ModelType.Llama3_1: Llama31Model, |
| 538 | + ModelType.Llava: LlavaModel, |
503 | 539 | } |
504 | 540 |
505 | 541 | class Transformer(nn.Module): |
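Registering ModelType.Llava in MODEL_TYPE_TO_CLASS presumably lets the loader pick the wrapper class from a config's model_type; a minimal dispatch sketch under that assumption (constructor arguments omitted since they are not shown here):

wrapper_cls = MODEL_TYPE_TO_CLASS[ModelType.Llava]
assert wrapper_cls is LlavaModel   # LlavaModel forwards encoder_input, post_tokens and input_pos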
@@ -882,3 +918,46 @@ def setup_caches(self, max_batch_size, max_seq_length): |
882 | 918 |
883 | 919 | except: |
884 | 920 | pass |
| 921 | + |
| 922 | + |
| 923 | +if __name__ == "__main__": |
| 924 | + def prepare_image(target_h: int, target_w: int) -> torch.Tensor: |
| 925 | + """Read image into a tensor and resize the image so that it fits in |
| 926 | + a target_h x target_w canvas. |
| 927 | +
| 928 | + Args: |
| 929 | + target_h (int): Target height. |
| 930 | + target_w (int): Target width. |
| 931 | + (The image itself is downloaded inside the function from the LLaVA demo URL.) |
| 932 | +
| 933 | + Returns: |
| 934 | + torch.Tensor: resized image tensor. |
| 935 | + """ |
| 936 | + image = Image.open( |
| 937 | + requests.get( |
| 938 | + "https://llava-vl.github.io/static/images/view.jpg", stream=True |
| 939 | + ).raw) |
| 940 | + |
| 941 | + img = torchvision.transforms.functional.pil_to_tensor(image) |
| 942 | + # height ratio |
| 943 | + ratio_h = img.shape[1] / target_h |
| 944 | + # width ratio |
| 945 | + ratio_w = img.shape[2] / target_w |
| 946 | + # resize the image so that it fits in a target_h x target_w canvas |
| 947 | + ratio = max(ratio_h, ratio_w) |
| 948 | + output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) |
| 949 | + img = torchvision.transforms.Resize(size=output_size)(img) |
| 950 | + return img |
| 951 | + |
| 952 | + pre_tokens = torch.tensor([[ 1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, |
| 953 | + 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, |
| 954 | + 322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155, |
| 955 | + 29889, 3148, 1001, 29901, 29871]]) |
| 956 | + img = prepare_image(336, 336) |
| 957 | + post_tokens = torch.tensor([[29871, 13, 462, 9651, 1724, 526, 278, 2712, 306, 881, |
| 958 | + 367, 274, 1300, 2738, 1048, 746, 306, 6493, 1244, 29973, |
| 959 | + 319, 1799, 9047, 13566, 29901]]) |
| 960 | + |
| 961 | + llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json") |
| 962 | + |
| 963 | + llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens) |
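The demo call returns whatever the fused decoder produces; assuming that is a [batch, seq_len, vocab_size] logits tensor, a minimal greedy continuation of the prompt would look like the sketch below (a usage illustration, not part of the commit):

with torch.no_grad():
    logits = llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens)
next_token = logits[:, -1, :].argmax(dim=-1)   # greedy pick of the next token id
print(next_token)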