Commit 9244003

Add support for Mistral 3.1 VLM

Parent: 68f7461

6 files changed: +91 -13 lines

examples/multimodal.py

Lines changed: 12 additions & 7 deletions

@@ -26,6 +26,8 @@
 # Pixtral:
 # https://huggingface.co/mistral-community/pixtral-12b/
 # https://huggingface.co/turboderp/pixtral-12b-exl2
+# Mistral-Small 3.1:
+# https://huggingface.co/prince-canuma/Mistral-Small-3.1-24B-Instruct-2503
 # Qwen2-VL:
 # https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
 # https://huggingface.co/turboderp/Qwen2-VL-7B-Instruct-exl2
@@ -34,18 +36,21 @@
 # https://huggingface.co/turboderp/gemma-3-27b-it-exl2
 
 # mode = "pixtral"
+mode = "mistral3"
 # mode = "qwen2"
-mode = "gemma3"
+# mode = "gemma3"
 
 streaming = True
 greedy = True
 
 if mode == "pixtral":
     model_directory = "/mnt/str/models/pixtral-12b-exl2/6.0bpw"
 elif mode == "qwen2":
-    model_directory = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/6.0bpw"
+    model_directory = "/mnt/str/models/qwen2.5-vl-7b-instruct-exl2/5.0bpw"
 elif mode == "gemma3":
-    model_directory = "/mnt/str/models/gemma3-27b-it-exl2/5.0bpw"
+    model_directory = "/mnt/str/models/gemma3-12b-it-exl2/6.0bpw"
+elif mode == "mistral3":
+    model_directory = "/mnt/str/models/mistral-small-3.1-24b-instruct/exl2/4.5bpw"
 
 images = [
     # {"file": "media/test_image_1.jpg"},
@@ -62,7 +67,7 @@
 # Initialize model
 
 config = ExLlamaV2Config(model_directory)
-config.max_seq_len = 16384  # Pixtral default is 1M
+config.max_seq_len = 8192  # Pixtral default is 1M
 
 # Load vision model and multimodal projector and initialize preprocessor
 
@@ -72,8 +77,8 @@
 # Load EXL2 model
 
 model = ExLlamaV2(config)
-cache = ExLlamaV2Cache(model, lazy = True, max_seq_len = 16384)
-model.load_autosplit(cache, progress = True)
+cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True)
+model.load_autosplit(progress = True, cache = cache)
 tokenizer = ExLlamaV2Tokenizer(config)
 
 # Create generator
@@ -121,7 +126,7 @@ def get_image(file = None, url = None):
 # Image token IDs are assigned sequentially, however, so two ExLlamaV2Embedding objects created from the same
 # source image will not be recognized as the same image for purposes of prompt caching etc.
 
-if mode == "pixtral":
+if mode in ["pixtral", "mistral3"]:
     prompt = (
         "[INST]" +
         placeholders +
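Taken on its own, the loading path the example exercises for the new mistral3 mode is the standard EXL2 flow shown in the hunks above. A minimal sketch for reference (the model directory is a placeholder, and the vision tower / multimodal projector loading handled elsewhere in the example is omitted):

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer

# Placeholder path to a local EXL2 quant of Mistral-Small 3.1
model_directory = "/path/to/mistral-small-3.1-24b-instruct-exl2/4.5bpw"

config = ExLlamaV2Config(model_directory)
config.max_seq_len = 8192  # cap the context; the model's default is much larger

model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True)
model.load_autosplit(progress = True, cache = cache)  # split layers across available GPUs
tokenizer = ExLlamaV2Tokenizer(config)
```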

exllamav2/architecture.py

Lines changed: 48 additions & 0 deletions

@@ -241,6 +241,8 @@ class Params:
         rope_freq_half: bool = False
         learned_emb: bool = False
         output_norm: bool = False
+        mlp_merger: bool = False
+        mlp_patch_merger: bool = False
 
         # Component models
         self.lm_prefix = ""
@@ -340,6 +342,52 @@ class Params:
             self.mmp.mlp_act_func = "gelu"
             self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
 
+        # Mistral 3 multimodal
+
+        if (
+            arch_string == "Mistral3ForConditionalGeneration" and
+            "vision_config" in read_config and
+            read_config["vision_config"].get("model_type") == "pixtral"
+        ):
+            arch_recognized = True
+            self.lm_prefix = "language_model."
+            self.lm.layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.lm.expect_keys += \
+                expect_keys_llama
+
+            self.vt_prefix = "vision_tower."
+            self.vt.keys.update({
+                "attn_q": ".attention.q_proj",
+                "attn_k": ".attention.k_proj",
+                "attn_v": ".attention.v_proj",
+                "attn_o": ".attention.o_proj",
+                "mlp_gate": ".feed_forward.gate_proj",
+                "mlp_up": ".feed_forward.up_proj",
+                "mlp_down": ".feed_forward.down_proj",
+                "norm_1": ".attention_norm",
+                "norm_2": ".ffn_norm",
+                "layers": "transformer.layers",
+                "ln_pre": "ln_pre",
+            })
+            self.vt.mlp_merger = True
+            self.vt.mlp_patch_merger = True
+
+            self.mmp_prefix = "multi_modal_projector."
+            self.mmp.keys.update({
+                "norm_2": "norm",
+                "mlp_gate": None,
+                "mlp_up": "linear_1",
+                "mlp_down": "linear_2",
+                "patch_merger": "patch_merger.merging_layer",
+            })
+            self.mmp.mlp_patch_merger = True
+            self.mmp.mlp_gate = False
+            self.mmp.mlp_act_func = "gelu"
+            self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
+
         # Yi
 
         if arch_string == "YiForCausalLM":
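For orientation, the new branch only fires for checkpoints whose config.json declares the Mistral3ForConditionalGeneration architecture with a Pixtral-type vision_config. A small standalone sketch of the same check (the file path is a placeholder; this mirrors the condition above and is not an exllamav2 API):

```python
import json

# Placeholder path to a downloaded Mistral-Small 3.1 checkpoint
with open("/path/to/checkpoint/config.json") as f:
    cfg = json.load(f)

arch_string = cfg["architectures"][0]
is_mistral3_vlm = (
    arch_string == "Mistral3ForConditionalGeneration" and
    "vision_config" in cfg and
    cfg["vision_config"].get("model_type") == "pixtral"
)
print("recognized as Mistral 3 multimodal:", is_mistral3_vlm)
```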

exllamav2/config.py

Lines changed: 5 additions & 3 deletions

@@ -552,18 +552,20 @@ def check_keys(archparams, prefix):
         self.vision_merger_intermediate_size = self.vision_intermediate_size
 
         image_processor_type = read(read_prep_config, str, ["image_processor_type"], no_default)
-        assert image_processor_type == "PixtralImageProcessor", \
+        assert image_processor_type == "PixtralImageProcessor" or image_processor_type == "PixtralImageProcessorFast", \
            f"Wrong image processor type: {image_processor_type}"
         self.vision_image_mean = read(read_prep_config, list, ["image_mean"], no_default)
         self.vision_image_std = read(read_prep_config, list, ["image_std"], no_default)
-        self.vision_patch_size = read(read_prep_config, dict, ["patch_size"], no_default)
+        self.vision_patch_size = read(read_prep_config, object, ["patch_size"], no_default)
+        if isinstance(self.vision_patch_size, int):
+            self.vision_patch_size = {"width": self.vision_patch_size, "height": self.vision_patch_size}
         assert all(self.vision_patch_size.get(x) == patch_size for x in ["width", "height"]), \
             "Patch size inconsistency between config.json and preprocessor_config.json"
         self.vision_resample = read(read_prep_config, int, ["resample"], no_default)
         self.vision_rescale_factor = read(read_prep_config, float, ["rescale_factor"], no_default)
         self.vision_size = read(read_prep_config, dict, ["size"], no_default)
         self.vision_num_channels = 3
-        self.vision_spatial_merge_size = 1
+        self.vision_spatial_merge_size = read(read_config, int, ["spatial_merge_size"], 1)
         self.vision_max_size = 16384
         self.vision_window_size = None

exllamav2/mlp.py

Lines changed: 22 additions & 3 deletions

@@ -115,6 +115,11 @@ def __init__(
         else:
             self.gate_proj = None
 
+        if merge and ap.mlp_patch_merger:
+            self.patch_merger_proj = ExLlamaV2Linear(model, key + km["patch_merger"], in_features * merge**2, in_features, ap.mlp_bias)
+            self.submodules += [self.patch_merger_proj]
+        else:
+            self.patch_merger_proj = None
 
     def numel(self) -> int:
 
@@ -158,6 +163,9 @@ def load(
         if self.gate_proj is not None: self.gate_proj.load(device_context = device_context, output_map = down_map)
         self.up_proj.load(device_context = device_context, output_map = down_map)
 
+        if self.patch_merger_proj is not None:
+            self.patch_merger_proj.load()
+
         if self.up_proj.is_quant():
             assert self.gate_proj is None or self.gate_proj.is_quant()
             assert self.up_proj.is_quant(), "Partially quantized MLP layer"
@@ -302,6 +310,8 @@ def set_device_idx(self, idx: int | None):
         if self.gate_proj is not None: self.gate_proj.set_device_idx(idx)
         self.up_proj.set_device_idx(idx)
         self.down_proj.set_device_idx(idx)
+        if self.patch_merger_proj is not None:
+            self.patch_merger_proj.set_device_idx(idx)
 
 
     # @profile
@@ -458,9 +468,18 @@ def forward_torch(
             if self.pre_layernorm else hidden_states
 
         if self.merge:
-            bd = post_norm.shape[:-2]
-            l, d = post_norm.shape[-2:]
-            post_norm = post_norm.view(*bd, l // self.merge, d * self.merge)
+            if self.archparams.mlp_patch_merger:
+                bsz = hidden_states.shape[0]
+                assert bsz == 1
+                (h, w), d = kwargs["patch_size"], hidden_states.shape[-1]
+                image_grid = post_norm.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
+                grid = F.unfold(image_grid, kernel_size = int(self.merge ** 0.5), stride = int(self.merge ** 0.5))
+                grid = grid.view(bsz, d * self.merge, -1).transpose(1, 2)
+                post_norm = self.patch_merger_proj.forward(grid)
+            else:
+                bd = post_norm.shape[:-2]
+                l, d = post_norm.shape[-2:]
+                post_norm = post_norm.view(*bd, l // self.merge, d * self.merge)
 
         if self.gate_proj is not None:
             gate = self.gate_proj.forward(post_norm, loras = loras)
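The new patch-merger branch in forward_torch rebuilds the flat sequence of patch embeddings into its 2-D grid, gathers each spatial neighbourhood into one row with F.unfold, and projects it back down with the merging layer. A toy, self-contained rerun of that reshaping (made-up dimensions; a random nn.Linear stands in for the trained patch_merger.merging_layer):

```python
import torch
import torch.nn.functional as F

h, w, d = 6, 8, 16          # toy patch grid and hidden size (illustrative numbers)
merge = 4                   # number of patches merged per token, i.e. a 2x2 neighbourhood
post_norm = torch.randn(h * w, d)              # flat sequence of patch embeddings
merging_layer = torch.nn.Linear(d * merge, d)  # stand-in for the trained merging layer

# Rebuild the 2-D grid, then gather each 2x2 block of patches into one row
image_grid = post_norm.view(h, w, d).permute(2, 0, 1).unsqueeze(0)  # (1, d, h, w)
grid = F.unfold(image_grid, kernel_size = 2, stride = 2)            # (1, d * 4, h*w / 4)
grid = grid.view(1, d * merge, -1).transpose(1, 2)                  # (1, h*w / 4, d * 4)
merged = merging_layer(grid)                                        # (1, h*w / 4, d)

print(post_norm.shape, "->", merged.shape)  # torch.Size([48, 16]) -> torch.Size([1, 12, 16])
```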

exllamav2/vlm/processor/pixtral.py

Lines changed: 3 additions & 0 deletions

@@ -57,6 +57,9 @@ def postprocess(
     Insert [IMG_BREAK] and [IMG_END] tokens in image feature embeddings
     """
 
+    features_x //= model.config.vision_spatial_merge_size
+    features_y //= model.config.vision_spatial_merge_size
+
     assert embeddings.shape[0] == features_y * features_x, \
         "Invalid shape for embeddings"
 
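Per the Pixtral image-token layout, postprocess lays the image features out row by row, with [IMG_BREAK] after each row and [IMG_END] at the end, so the row and column counts have to be expressed in merged tokens rather than raw patches. A quick illustration with made-up numbers (16-pixel patches, spatial_merge_size 2, a 512x256 input):

```python
patch = 16
spatial_merge_size = 2

features_x, features_y = 512 // patch, 256 // patch      # 32 x 16 raw patch features
features_x //= spatial_merge_size                        # 16 merged tokens per row
features_y //= spatial_merge_size                        # 8 rows
print(features_x, features_y, features_x * features_y)   # 16 8 128 -> 128 embeddings expected
```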

exllamav2/vlm/vision_tower.py

Lines changed: 1 addition & 0 deletions

@@ -309,6 +309,7 @@ def process(
             attn_params = attn_params,
             **kwargs | ({
                 "alt_rope_embedding": (cos, sin),
+                "patch_size": (p_height, p_width),
             } if cos is not None else {})
         )
