Skip to content

Commit b2dd5a7

Browse files
authored
Modify handling for Pixtral Large model params (#701)
* Modify handling for Pixtral Large model params. * Fix multimodal_projector_bias to default to True if not in model config.json
1 parent cf7fcd1 commit b2dd5a7

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

exllamav2/architecture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ class Params:
312312
})
313313
self.mmp.mlp_gate = False
314314
self.mmp.mlp_act_func = "gelu"
315-
self.mmp.mlp_bias = True
315+
self.mmp.mlp_bias = bool(read_config.get("multimodal_projector_bias", True))
316316

317317
# Yi
318318

exllamav2/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,10 @@ def check_keys(archparams, prefix):
476476
self.vision_num_attention_heads = read(read_config, int, ["vision_config->num_attention_heads"], no_default)
477477
self.vision_num_key_value_heads = read(read_config, int, ["vision_config->num_key_value_heads"], self.vision_num_attention_heads)
478478
self.vision_num_key_value_groups = self.vision_num_attention_heads // self.vision_num_key_value_heads
479+
self.multimodal_projector_bias = read(read_config, bool, ["multimodal_projector_bias"], True)
479480

480481
self.vision_hidden_act = read(read_config, str, ["vision_config->hidden_act"], no_default)
481-
self.vision_hidden_size = read(read_config, int, ["vision_config->image_size"], no_default)
482+
self.vision_hidden_size = read(read_config, int, ["vision_config->hidden_size"], 1024)
482483
patch_size = read(read_config, int, ["vision_config->patch_size"], no_default)
483484
self.vision_rope_theta = read(read_config, int, ["vision_config->rope_theta"], no_default)
484485
self.vision_feature_layer = read(read_config, int, ["vision_feature_layer"], no_default)

exllamav2/vlm/mmprojector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
in_features = cfg.vision_hidden_size,
2828
out_features = cfg.hidden_size,
2929
interm_features = cfg.hidden_size,
30+
has_bias=cfg.multimodal_projector_bias,
3031
has_norm = False,
3132
has_residual = False,
3233
)

0 commit comments

Comments
 (0)