
Commit b28300c

Pixtral: Refactor vision model, update example
1 parent 7c876ef commit b28300c

6 files changed: +148 −69 lines changed

exllamav2/generator/dynamic_embeddings.py

Lines changed: 3 additions & 0 deletions

@@ -35,6 +35,8 @@ class ExLlamaV2MMEmbedding:
     first_index: int
     length: int
 
+    metadata: dict
+
     def __init__(
         self,
         model: ExLlamaV2,
@@ -57,6 +59,7 @@ def __init__(
         self.embeddings = embeddings
         self.text_alias = text_alias
         self.model = model
+        self.metadata = {}
 
         self.length = embeddings.shape[0]
         dim = embeddings.shape[1]
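
The new metadata attribute is declared on the class and initialized to an empty dict per instance, so callers can attach bookkeeping to an embedding after constructing it; the vision tower below uses it to record image geometry. A minimal sketch, assuming emb is an already constructed ExLlamaV2MMEmbedding (the keys here are only illustrative):

    # Illustrative only: any keys can be stored; get_image_embeddings (below) records
    # "original_size", "preprocessed_size" and "patches_size".
    emb.metadata["patches_size"] = (24, 32)
    rows, cols = emb.metadata.get("patches_size", (0, 0))
    print(rows, cols, emb.length)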

exllamav2/vlm/preprocessor/pixtral.py renamed to exllamav2/vlm/processor/pixtral.py

Lines changed: 30 additions & 1 deletion

@@ -1,6 +1,7 @@
 import torch
 import numpy as np
 from PIL import Image
+from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.config import ExLlamaV2Config
 from exllamav2.vlm.util import (
     convert_to_rgb,
@@ -41,4 +42,32 @@ def preprocess(
 
     image = image.transpose(2, 0, 1)
     image = torch.from_numpy(image).half()
-    return image
+    return image
+
+def postprocess(
+    model: ExLlamaV2,
+    tokenizer: ExLlamaV2Tokenizer,
+    embeddings: torch.Tensor,
+    features_y: int,
+    features_x: int,
+):
+    """
+    Insert [IMG_BREAK] and [IMG_END] tokens in image feature embeddings
+    """
+
+    assert embeddings.shape[0] == features_y * features_x, \
+        "Invalid shape for embeddings"
+
+    id_break = tokenizer.single_id("[IMG_BREAK]")
+    id_end = tokenizer.single_id("[IMG_END]")
+    img_break = model.modules[0].forward(torch.tensor([id_break], dtype=torch.long)).to("cuda:0")
+    img_end = model.modules[0].forward(torch.tensor([id_end], dtype=torch.long)).to("cuda:0")
+
+    dim = embeddings.shape[-1]
+    embeddings = embeddings.view((features_y, features_x, dim))
+    break_col = img_break.expand(features_y, -1, -1)
+    embeddings = torch.cat((embeddings, break_col), dim = 1)
+    embeddings = embeddings.view((features_y * (features_x + 1)), dim)
+    embeddings = torch.cat((embeddings, img_end), dim = 0)
+
+    return embeddings
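
The reshaping in postprocess appends one [IMG_BREAK] embedding at the end of each row of patch features and a single [IMG_END] embedding after the last row, so features_y * features_x patch rows become features_y * (features_x + 1) + 1 rows. A minimal sketch of that layout, using dummy tensors in place of the real token embeddings:

    import torch

    features_y, features_x, dim = 3, 4, 8                  # toy grid and hidden size
    patches = torch.zeros(features_y * features_x, dim)    # stand-in for projected image features
    img_break = torch.ones(1, dim)                          # stand-in for the [IMG_BREAK] embedding
    img_end = torch.full((1, dim), 2.0)                     # stand-in for the [IMG_END] embedding

    grid = patches.view(features_y, features_x, dim)
    break_col = img_break.expand(features_y, -1, -1)        # expand adds a leading dim: (features_y, 1, dim)
    grid = torch.cat((grid, break_col), dim = 1)            # (features_y, features_x + 1, dim)
    flat = grid.reshape(features_y * (features_x + 1), dim)
    flat = torch.cat((flat, img_end), dim = 0)              # append the single [IMG_END] row

    print(flat.shape)                                       # torch.Size([16, 8]) == 3 * (4 + 1) + 1 rows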

exllamav2/vlm/vision_tower.py

Lines changed: 91 additions & 12 deletions

@@ -4,15 +4,17 @@
 import threading
 
 import torch
-from exllamav2 import ExLlamaV2
+from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.conv2d import ExLlamaV2Conv2D
 from exllamav2.rmsnorm import ExLlamaV2RMSNorm
 from exllamav2.attn import ExLlamaV2Attention
 from exllamav2.mlp import ExLlamaV2MLP
 from exllamav2.config import ExLlamaV2Config
 from exllamav2.module import ExLlamaV2Module
-from exllamav2.vlm.preprocessor import pixtral
+from exllamav2.vlm.processor import pixtral
 from exllamav2.compat import safe_move_tensor
+from exllamav2.generator import ExLlamaV2MMEmbedding
+from typing import Callable
 
 from PIL.Image import Image
 from exllamav2.vlm.util import position_ids_in_meshgrid
@@ -35,7 +37,8 @@ def __init__(
         # Preprocessor
 
         if cfg.vision_model_type == "pixtral":
-            self.preprocessor = pixtral.preprocess
+            self.preprocess_func = pixtral.preprocess
+            self.postprocess_func = pixtral.postprocess
         else:
             raise ValueError(f"Unknown vision model type: {cfg.vision_model_type}")
 
@@ -90,16 +93,34 @@ def __init__(
             mlp = ExLlamaV2MLP(self, layer_key, layer_idx, archparams = self.archparams)
             self.modules += [attn, mlp]
 
+        # Multimodal projection
+
+        mmp = ExLlamaV2MLP(
+            self,
+            cfg.arch.mmp_prefix,
+            0,
+            archparams = cfg.arch.mmp,
+            in_features = cfg.vision_hidden_size,
+            out_features = cfg.hidden_size,
+            interm_features = cfg.hidden_size,
+            has_norm = False,
+            has_residual = False
+        )
+        self.modules += [mmp]
+
 
     def forward(self, **kwargs):
         raise NotImplementedError()
-
-
-    def preprocess(self, image: Image) -> torch.Tensor:
-        """
-        Preprocess image and prepare for vision tower
-        """
-        return self.preprocessor(self.config, image)
+    def forward_chunk(self, **kwargs):
+        raise NotImplementedError()
+    def load_tp(self, **kwargs):
+        raise ValueError("load_tp not supported for vision model")
+    def load_tp_gen(self, **kwargs):
+        raise ValueError("load_tp not supported for vision model")
+    def load_autosplit(self, **kwargs):
+        raise ValueError("load_autosplit not supported for vision model")
+    def load_autosplit_gen(self, **kwargs):
+        raise ValueError("load_autosplit not supported for vision model")
 
 
     def process(
@@ -134,7 +155,7 @@ def process(
             # Onward
 
             n_device = module.device_idx
-            if n_device is not None and n_device != device and n_device >= 0:
+            if idx == 0 or (n_device is not None and n_device != device and n_device >= 0):
                 hidden_states = safe_move_tensor(hidden_states, n_device, non_blocking = True)
 
             if cos.device != hidden_states.device:
@@ -149,4 +170,62 @@ def process(
                 }
             )
 
-        return hidden_states
+        return hidden_states
+
+
+    def get_image_embeddings(
+        self,
+        model: ExLlamaV2,
+        tokenizer: ExLlamaV2Tokenizer,
+        image: Image,
+        text_alias: str,
+    ) -> ExLlamaV2MMEmbedding:
+        """
+        :param model:
+            Text model for which to produce embeddings
+
+        :param tokenizer:
+            Tokenizer
+
+        :param image:
+            Input PIL image
+
+        :param text_alias:
+            Text string to represent this embedding for tokenizing
+
+        :return:
+            ExLlamaV2MMEmbedding
+        """
+
+        width, height = image.size
+        original_size = (height, width)
+
+        image_tensor = self.preprocess_func(self.config, image)
+        image_size = tuple(image_tensor.shape[1:])
+
+        embedding_tensor = self.process(image_tensor)
+
+        features_y = image_size[0] // 16
+        features_x = image_size[1] // 16
+
+        embedding_tensor = self.postprocess_func(
+            model,
+            tokenizer,
+            embedding_tensor[0],
+            features_y,
+            features_x,
+        )
+
+        mme = ExLlamaV2MMEmbedding(
+            model = model,
+            embeddings = embedding_tensor,
+            text_alias = text_alias
+        )
+
+        mme.metadata.update({
+            "original_size": original_size,
+            "preprocessed_size": image_size,
+            "patches_size": (features_y, features_x),
+        })
+
+        return mme
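
With this refactor the whole pipeline (preprocess, vision tower forward pass, multimodal projection, [IMG_BREAK]/[IMG_END] insertion) is wrapped in a single call that returns a ready-to-use ExLlamaV2MMEmbedding. A minimal single-image sketch, assuming model, tokenizer and the vision model are already loaded as in the updated example below:

    from PIL import Image

    image = Image.open("test_image_1.jpg")
    emb = vision_model.get_image_embeddings(
        model = model,
        tokenizer = tokenizer,
        image = image,
        text_alias = "{{IMAGE_1}}",
    )

    # Geometry recorded in the new metadata dict
    print(emb.metadata["original_size"])       # (height, width) of the input image
    print(emb.metadata["preprocessed_size"])   # padded size fed to the vision tower
    print(emb.metadata["patches_size"])        # (features_y, features_x), one patch per 16 px
    print(emb.length)                          # features_y * (features_x + 1) + 1 embedding rows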

experimental/multimodal_pixtral_hf.py

Lines changed: 24 additions & 56 deletions

@@ -8,52 +8,29 @@
     ExLlamaV2Config,
     ExLlamaV2Cache,
     ExLlamaV2Tokenizer,
-    ExLlamaV2MultimodalProjector,
-    ExLlamaV2VisionTower
+    ExLlamaV2VisionTower,
 )
 
 from exllamav2.generator import (
     ExLlamaV2DynamicGenerator,
     ExLlamaV2Sampler,
-    ExLlamaV2MMEmbedding
 )
 
 from PIL import Image
 import requests
 
-# Get an input image
-
-url = "https://pbs.twimg.com/media/BAeuBsnCIAAUITV.jpg:large"
-image = Image.open(requests.get(url, stream = True).raw)
-
 # Unquantized model used for experiment:
 #
 # https://huggingface.co/mistral-community/pixtral-12b/
 
 model_directory = "/mnt/str/models/pixtral-12b"
 config = ExLlamaV2Config(model_directory)
-config.max_seq_len = 32768 # default is 1M
-
-# Load multimodal projector
-
-multimodal_projector = ExLlamaV2MultimodalProjector(config)
-multimodal_projector.load()
-
-# Load vision tower and preprocessor
-
-vision_tower = ExLlamaV2VisionTower(config)
-vision_tower.load(progress = True)
-
-# Preprocess
-
-image_tensor = vision_tower.preprocess(image)
-image_tensor = image_tensor.cuda()
-image_size = tuple(image_tensor.shape[1:])
+config.max_seq_len = 16384 # default is 1M
 
-# Produce embeddings
+# Load vision model and multimodal projector and initialize preprocessor
 
-embeddings = vision_tower.process(image_tensor)
-embeddings = multimodal_projector.forward(embeddings)[0]
+vision_model = ExLlamaV2VisionTower(config)
+vision_model.load(progress = True)
 
 # Load EXL2 model
 
@@ -62,24 +39,6 @@
 model.load_autosplit(cache, progress = True)
 tokenizer = ExLlamaV2Tokenizer(config)
 
-# Insert [IMG_BREAK] and [IMG_END] tokens.
-
-features_x = image_size[1] // 16
-features_y = image_size[0] // 16
-assert image_size == (features_y * 16, features_x * 16) # Image should be padded in preprocessing
-
-id_break = tokenizer.single_id("[IMG_BREAK]")
-id_end = tokenizer.single_id("[IMG_END]")
-img_break = model.modules[0].forward(torch.tensor([id_break], dtype = torch.long)).to("cuda:0")
-img_end = model.modules[0].forward(torch.tensor([id_end], dtype = torch.long)).to("cuda:0")
-
-dim = embeddings.shape[-1]
-embeddings = embeddings.view((features_y, features_x, dim))
-break_col = img_break.expand(features_y, -1, -1)
-embeddings = torch.cat((embeddings, break_col), dim = 1)
-embeddings = embeddings.view((features_y * (features_x + 1)), dim)
-embeddings = torch.cat((embeddings, img_end), dim = 0)
-
 # Create generator
 
 generator = ExLlamaV2DynamicGenerator(
@@ -90,15 +49,24 @@
 
 # Create an MMEmbedding for the image features and a prompt containing the placeholder string
 
-image_tokens_a = ExLlamaV2MMEmbedding(
-    model = model,
-    embeddings = embeddings,
-    text_alias = "{{EMBED_A}}"
-)
-
-prompt = "[INST]{{EMBED_A}}\nDescribe the image.[/INST]"
-
-# Pass embeddings to generator
+image_embeddings = [
+    vision_model.get_image_embeddings(
+        model = model,
+        tokenizer = tokenizer,
+        image = img,
+        text_alias = alias,
+    )
+    for (alias, img) in [
+        ("{{IMAGE_1}}", Image.open("test_image_1.jpg")),
+        ("{{IMAGE_2}}", Image.open("test_image_2.jpg")),
+    ]
+]
+
+prompt = "[INST]{{IMAGE_1}}{{IMAGE_2}}\n" + \
+    "What are the similarities and differences between these two experiments?[/INST]"
+
+# Run prompt through generator, with embeddings. The tokenizer will insert prepared image tokens in place
+# of the aliases
 
 output = generator.generate(
     prompt = prompt,
@@ -108,7 +76,7 @@
     decode_special_tokens = True,
     stop_conditions = [tokenizer.eos_token_id],
     gen_settings = ExLlamaV2Sampler.Settings.greedy(),
-    embeddings = [image_tokens_a],
+    embeddings = image_embeddings,
 )
 
 print(output)

experimental/test_image_1.jpg

84.2 KB

experimental/test_image_2.jpg

39.8 KB

0 commit comments
