Commit 193a6b2

Pixtral: Add vision tower and preprocessor

1 parent 9504b51

File tree

6 files changed, +295 -6 lines

    exllamav2/__init__.py
    exllamav2/vlm/__init__.py
    exllamav2/vlm/mmprojector.py
    exllamav2/vlm/preprocessor/pixtral.py
    exllamav2/vlm/util.py
    exllamav2/vlm/vision_tower.py
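
The commit adds a Pixtral vision tower and its image preprocessor and exports them from the package root. As orientation, the pieces are intended to be combined roughly as follows; this is a hypothetical sketch inferred from the classes added below (weight loading is not shown in this commit and is omitted here, and the call sequence is an assumption rather than documented usage):

    # Hypothetical sketch (not part of the commit): how the new pieces fit together.
    from PIL import Image
    from exllamav2 import ExLlamaV2Config, ExLlamaV2MultimodalProjector, ExLlamaV2VisionTower

    config = ExLlamaV2Config()
    config.model_dir = "/path/to/pixtral-model"     # assumed checkpoint directory
    config.prepare()

    vision_tower = ExLlamaV2VisionTower(config)
    projector = ExLlamaV2MultimodalProjector(config)
    # (weight loading omitted; the commit does not add or show a load call)

    image = Image.open("example.jpg")                 # any RGB(A) test image
    pixel_values = vision_tower.preprocess(image)     # (3, H', W') float16 tensor
    patch_embeds = vision_tower.process(pixel_values) # per-patch hidden states
    text_embeds = projector.forward(patch_embeds)     # projected toward the LM embedding space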

exllamav2/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -1,7 +1,6 @@
 from exllamav2.version import __version__

 from exllamav2.model import ExLlamaV2
-from exllamav2.vlm import ExLlamaV2MultimodalProjector
 from exllamav2.cache import ExLlamaV2CacheBase
 from exllamav2.cache import ExLlamaV2Cache
 from exllamav2.cache import ExLlamaV2Cache_Q4
@@ -15,3 +14,6 @@
 from exllamav2.util import SeqTensor
 from exllamav2.util import Timer
 from exllamav2.module import Intervention
+
+from exllamav2.vlm.mmprojector import ExLlamaV2MultimodalProjector
+from exllamav2.vlm.vision_tower import ExLlamaV2VisionTower

exllamav2/vlm/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
 from exllamav2.version import __version__

-from exllamav2.vlm.mmprojector import ExLlamaV2MultimodalProjector
+from exllamav2.vlm.mmprojector import ExLlamaV2MultimodalProjector
+from exllamav2.vlm.vision_tower import ExLlamaV2VisionTower

exllamav2/vlm/mmprojector.py

Lines changed: 11 additions & 4 deletions

@@ -1,18 +1,15 @@

-import torch.nn as nn
-import torch.nn.functional as F
-
 from exllamav2 import ExLlamaV2
 from exllamav2.config import ExLlamaV2Config
 from exllamav2.module import ExLlamaV2Module
 from exllamav2.mlp import ExLlamaV2MLP
-from typing import Callable

 class ExLlamaV2MultimodalProjector(ExLlamaV2):

     config: ExLlamaV2Config
     modules: list[ExLlamaV2Module]

+    # noinspection PyMissingConstructor
     def __init__(
         self,
         config: ExLlamaV2Config
@@ -35,8 +32,18 @@ def __init__(
             )
         ]

+    # noinspection PyMethodOverriding
     def forward(self, x):

         for m in self.modules:
             x = m.forward(x)
         return x
+
+    def load_tp(self, **kwargs):
+        raise ValueError("load_tp not supported for multimodal projector")
+    def load_tp_gen(self, **kwargs):
+        raise ValueError("load_tp not supported for multimodal projector")
+    def load_autosplit(self, **kwargs):
+        raise ValueError("load_autosplit not supported for multimodal projector")
+    def load_autosplit_gen(self, **kwargs):
+        raise ValueError("load_autosplit not supported for multimodal projector")

exllamav2/vlm/preprocessor/pixtral.py

Lines changed: 44 additions & 0 deletions (new file)

import torch
import numpy as np
from PIL import Image
from exllamav2.config import ExLlamaV2Config
from exllamav2.vlm.util import (
    convert_to_rgb,
    size_to_longest_edge_and_patch_size,
    normalize_image
)

def preprocess(
    config: ExLlamaV2Config,
    image: Image
) -> torch.Tensor:

    assert "longest_edge" in config.vision_size, \
        "preprocessing size must specify longest_edge"

    patch_size = tuple(config.vision_patch_size[d] for d in ["height", "width"])
    longest_edge = config.vision_size["longest_edge"]
    resample = Image.Resampling(config.vision_resample)
    image_mean = tuple(config.vision_image_mean)
    image_std = tuple(config.vision_image_std)
    rescale_factor = config.vision_rescale_factor

    # Convert to RGB and resize as necessary

    image = convert_to_rgb(image)
    old_size = image.size
    new_size = size_to_longest_edge_and_patch_size(image.size, (longest_edge, longest_edge), patch_size)
    if old_size != new_size:
        image = image.resize(new_size, resample = resample)

    # Convert to numpy array and normalize

    image = np.array(image).astype(np.float32)
    image = image * rescale_factor
    image = normalize_image(image, image_mean, image_std)

    # Convert to tensor, shape (3, resized_height, resized_width)

    image = image.transpose(2, 0, 1)
    image = torch.from_numpy(image).half()
    return image
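
A hedged usage sketch of the preprocessor. The config attributes are exactly the ones read above, but the concrete values (longest edge, patch size, CLIP-style mean/std, 1/255 rescale factor) are placeholders of the kind a Pixtral preprocessor config typically supplies, not values taken from this commit:

    # Hypothetical usage (not part of the commit): a duck-typed stand-in for
    # ExLlamaV2Config carrying the attributes preprocess() reads.
    from types import SimpleNamespace
    from PIL import Image
    from exllamav2.vlm.preprocessor.pixtral import preprocess

    config = SimpleNamespace(
        vision_size = {"longest_edge": 1024},
        vision_patch_size = {"height": 16, "width": 16},
        vision_resample = 3,                                  # Image.Resampling.BICUBIC
        vision_image_mean = [0.48145466, 0.4578275, 0.40821073],
        vision_image_std = [0.26862954, 0.26130258, 0.27577711],
        vision_rescale_factor = 1 / 255,
    )

    image = Image.open("example.jpg")                         # any local test image
    pixel_values = preprocess(config, image)
    print(pixel_values.shape, pixel_values.dtype)             # e.g. torch.Size([3, 768, 1008]) torch.float16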

exllamav2/vlm/util.py

Lines changed: 83 additions & 0 deletions (new file)

import torch
import numpy as np
from PIL import Image
from typing import Tuple

def convert_to_rgb(image: Image) -> Image:
    """
    Converts an image to RGB format and ensures any transparent regions are converted to white
    """
    if image.mode == "RGB":
        return image

    image = image.convert("RGBA")

    new_image = Image.new("RGBA", image.size, "WHITE")
    new_image.paste(image, (0, 0), image)
    new_image = new_image.convert("RGB")
    return new_image


def size_to_longest_edge_and_patch_size(
    input_size: tuple,
    max_size: tuple,
    patch_size: tuple,
) -> tuple:
    """
    Compute the output size for resizing an image while maintaining aspect ratio and constraining to a
    maximum bounding box, while keeping each dimension a multiple of the corresponding patch dimension.
    """

    assert all(p % d == 0 for p, d in zip(max_size, patch_size)), \
        "max_size must be a multiple of patch_size"

    # Reduce to bounding box

    ratio = max(input_size[0] / max_size[0], input_size[1] / max_size[1])
    if ratio > 1:
        output_size = tuple(int(np.ceil(d / ratio)) for d in input_size)
    else:
        output_size = input_size

    # Align size to patch grid

    output_size = tuple((((d + p - 1) // p) * p) for d, p in zip(output_size, patch_size))
    return output_size

def normalize_image(
    image: np.ndarray,
    mean: tuple,
    std: tuple,
) -> np.ndarray:
    """
    Normalizes RGB image in numpy format using the mean and standard deviation specified by `mean` and `std`:
    image = (image - mean) / std
    """

    assert len(mean) == 3 and len(std) == 3, \
        "mean and std arguments must be 3D"

    # Upcast image to float32 if it's not already a float type

    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    mean = np.array(mean, dtype = image.dtype)
    std = np.array(std, dtype = image.dtype)
    image = (image - mean) / std
    return image


def position_ids_in_meshgrid(
    height: int,
    width: int,
    max_width: int
):
    """
    Create flat position IDs tensor for grid of patches: id(row, col) = row * max_width + col
    """

    row_indices = torch.arange(height).unsqueeze(1) * max_width
    col_indices = torch.arange(width).unsqueeze(0)
    ids = row_indices + col_indices
    return ids.flatten().unsqueeze(0)
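
To make the two geometry helpers concrete, here is a small worked illustration (hypothetical, not part of the commit; the 1024-pixel longest edge and 16-pixel patches are example values):

    # Illustration only: expected behavior of the helpers above.
    from exllamav2.vlm.util import size_to_longest_edge_and_patch_size, position_ids_in_meshgrid

    # A 2048x1536 image bounded by a 1024x1024 box with 16x16 patches:
    # ratio = max(2048/1024, 1536/1024) = 2.0, so the image shrinks to 1024x768,
    # which is already a multiple of 16 in both dimensions.
    print(size_to_longest_edge_and_patch_size((2048, 1536), (1024, 1024), (16, 16)))  # (1024, 768)

    # A 1000x747 image already fits, so it is only rounded up to the patch grid:
    # ceil(1000/16)*16 = 1008, ceil(747/16)*16 = 752.
    print(size_to_longest_edge_and_patch_size((1000, 747), (1024, 1024), (16, 16)))   # (1008, 752)

    # Position IDs for a 2x3 patch grid with a maximum row stride of 64:
    # row 0 -> 0, 1, 2 and row 1 -> 64, 65, 66, flattened into one sequence.
    print(position_ids_in_meshgrid(2, 3, 64))  # tensor([[ 0,  1,  2, 64, 65, 66]])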

exllamav2/vlm/vision_tower.py

Lines changed: 152 additions & 0 deletions (new file)

from __future__ import annotations
import os, sys

import threading

import torch
from exllamav2 import ExLlamaV2
from exllamav2.conv2d import ExLlamaV2Conv2D
from exllamav2.rmsnorm import ExLlamaV2RMSNorm
from exllamav2.attn import ExLlamaV2Attention
from exllamav2.mlp import ExLlamaV2MLP
from exllamav2.config import ExLlamaV2Config
from exllamav2.module import ExLlamaV2Module
from exllamav2.vlm.preprocessor import pixtral
from exllamav2.compat import safe_move_tensor

from PIL.Image import Image
from exllamav2.vlm.util import position_ids_in_meshgrid

class ExLlamaV2VisionTower(ExLlamaV2):

    config: ExLlamaV2Config
    modules: list[ExLlamaV2Module]

    # noinspection PyMissingConstructor
    def __init__(
        self,
        config: ExLlamaV2Config
    ):
        self.config = config
        cfg = self.config
        self.archparams = cfg.arch.vt
        self.modules = []

        # Preprocessor

        if cfg.vision_model_type == "pixtral":
            self.preprocessor = pixtral.preprocess
        else:
            raise ValueError(f"Unknown vision model type: {cfg.vision_model_type}")

        # Position embeddings

        self.p_maxedge = cfg.vision_size["longest_edge"] // cfg.vision_patch_size["width"]
        freqs = 1.0 / (cfg.vision_rope_theta ** (torch.arange(0, cfg.vision_head_dim, 2).float() / cfg.vision_head_dim))
        h = torch.arange(self.p_maxedge, device=freqs.device)
        w = torch.arange(self.p_maxedge, device=freqs.device)
        freqs_h = torch.outer(h, freqs[::2]).float()
        freqs_w = torch.outer(w, freqs[1::2]).float()
        inv_freq = torch.cat(
            [
                freqs_h[:, None, :].repeat(1, self.p_maxedge, 1),
                freqs_w[None, :, :].repeat(self.p_maxedge, 1, 1),
            ],
            dim=-1,
        ).reshape(-1, cfg.vision_head_dim // 2)
        inv_freq = torch.cat((inv_freq, inv_freq), dim = -1)

        self.rope_cos = inv_freq.cos().half()
        self.rope_sin = inv_freq.sin().half()

        # Patch embeddings

        patch_size = tuple(config.vision_patch_size[x] for x in ["height", "width"])
        patch_conv = ExLlamaV2Conv2D(
            model = self,
            key = cfg.arch.vt_prefix + "patch_conv",
            in_channels = self.config.vision_num_channels,
            out_channels = self.config.vision_hidden_size,
            kernel_size = patch_size,
            has_bias = self.archparams.patch_conv_bias,
            archparams = self.archparams,
        )
        self.modules += [patch_conv]

        # Input norm

        norm = ExLlamaV2RMSNorm(
            model = self,
            key = cfg.arch.vt_prefix + "ln_pre",
            archparams = self.archparams,
        )
        self.modules += [norm]

        # Decoder layers

        for layer_idx in range(self.config.vision_num_layers):
            layer_key = cfg.arch.vt_prefix + f"transformer.layers.{layer_idx}"
            attn = ExLlamaV2Attention(self, layer_key, layer_idx, archparams = self.archparams)
            mlp = ExLlamaV2MLP(self, layer_key, layer_idx, archparams = self.archparams)
            self.modules += [attn, mlp]


    def forward(self, **kwargs):
        raise NotImplementedError()


    def preprocess(self, image: Image) -> torch.Tensor:
        """
        Preprocess image and prepare for vision tower
        """
        return self.preprocessor(self.config, image)


    def process(
        self,
        hidden_states: torch.Tensor,
        abort_event: threading.Event | None = None,
        **kwargs
    ):
        cfg = self.config

        if len(hidden_states.shape) == 3:
            hidden_states = hidden_states.unsqueeze(0)

        bsz, channels, height, width = hidden_states.shape

        p_height = height // cfg.vision_patch_size["height"]
        p_width = width // cfg.vision_patch_size["width"]
        position_ids = position_ids_in_meshgrid(p_height, p_width, self.p_maxedge)

        cos = self.rope_cos[position_ids]
        sin = self.rope_sin[position_ids]
        attn_params = ExLlamaV2Attention.Params(non_causal_attn = True)

        device = self.modules[0].device_idx
        for idx, module in enumerate(self.modules):

            # Respect abort signal

            if abort_event and abort_event.is_set():
                return None, None

            # Onward

            n_device = module.device_idx
            if n_device is not None and n_device != device and n_device >= 0:
                hidden_states = safe_move_tensor(hidden_states, n_device, non_blocking = True)

            if cos.device != hidden_states.device:
                cos = safe_move_tensor(cos, hidden_states.device)
                sin = safe_move_tensor(sin, hidden_states.device)

            hidden_states = module.forward(
                hidden_states,
                attn_params = attn_params,
                **kwargs | {
                    "alt_rope_embedding": (cos, sin)
                }
            )

        return hidden_states
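
The constructor above precomputes one rotary row per cell of the maximum p_maxedge x p_maxedge patch grid, interleaving row and column frequencies, and process() then selects the rows that correspond to the actual patch grid via position_ids_in_meshgrid. A small sketch of that lookup with made-up sizes (hypothetical, not part of the commit):

    # Illustration only: shape of the rotary lookup performed in process().
    import torch
    from exllamav2.vlm.util import position_ids_in_meshgrid

    p_maxedge = 64      # e.g. longest_edge 1024 // patch width 16 (assumed values)
    head_dim = 64       # stand-in for cfg.vision_head_dim

    # Stand-in for self.rope_cos: one row per (row, col) cell of the maximum grid.
    rope_cos = torch.randn(p_maxedge * p_maxedge, head_dim).half()

    # A 768x1008 preprocessed image with 16x16 patches gives a 48x63 patch grid.
    position_ids = position_ids_in_meshgrid(48, 63, p_maxedge)     # shape (1, 3024)
    cos = rope_cos[position_ids]                                   # shape (1, 3024, 64)
    print(position_ids.shape, cos.shape)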

0 commit comments