Commit 9089596

Fix quantization for Pixtral, copy vision tower tensors to quantized model

1 parent: d37cf7e

File tree: 6 files changed (+97, -73 lines)

exllamav2/conversion/compile.py

Lines changed: 77 additions & 54 deletions
@@ -59,86 +59,106 @@ def get_q_module(job, module):
 @torch.inference_mode()
 def compile_model(job, save_fn, model):

+    cfg = model.config
     out_dict = {}
     current_size = 0
     file_index = 1
     index = 0
     shard_bytes = job["shard_size"] * 1024 ** 2

-    while index < len(model.modules):
+    extra_tensors = []
+    if cfg.arch.mmp_prefix:
+        extra_tensors += [k for k in cfg.tensor_file_map.keys() if k.startswith(cfg.arch.mmp_prefix)]
+    if cfg.arch.vt_prefix:
+        extra_tensors += [k for k in cfg.tensor_file_map.keys() if k.startswith(cfg.arch.vt_prefix)]
+    extra_tensors_size = 0

-        module = model.modules[index]
+    while index < len(model.modules) or len(extra_tensors):

-        if isinstance(module, ExLlamaV2Embedding):
+        if index < len(model.modules):

-            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)
+            module = model.modules[index]

-        if isinstance(module, ExLlamaV2PosEmbedding):
+            if isinstance(module, ExLlamaV2Embedding):

-            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2Attention):
+            if isinstance(module, ExLlamaV2PosEmbedding):

-            d = get_f_module(job, module.pre_layernorm)
-            if d: out_dict.update(d); current_size += _dsize(d)
-            d = get_f_module(job, module.post_layernorm)
-            if d: out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2MLP):
+            if isinstance(module, ExLlamaV2Attention):

-            has_gate = model.config.arch.lm.mlp_gate
-            d = get_f_module(job, module.pre_layernorm)
-            if d: out_dict.update(d); current_size += _dsize(d)
-            d = get_f_module(job, module.post_layernorm)
-            if d: out_dict.update(d); current_size += _dsize(d)
-            if has_gate: d = get_q_module(job, module.gate_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.up_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.down_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.pre_layernorm)
+                if d: out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.post_layernorm)
+                if d: out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2MoEMLP):
+            if isinstance(module, ExLlamaV2MLP):

-            d = get_f_module(job, module.post_attention_layernorm); out_dict.update(d); current_size += _dsize(d)
-            d = get_f_module(job, module.gate); out_dict.update(d); current_size += _dsize(d)
-            for i in range(model.config.num_experts):
-                d = get_q_module(job, module.w1[i]); out_dict.update(d); current_size += _dsize(d)
-                d = get_q_module(job, module.w3[i]); out_dict.update(d); current_size += _dsize(d)
-                d = get_q_module(job, module.w2[i]); out_dict.update(d); current_size += _dsize(d)
+                has_gate = model.config.arch.lm.mlp_gate
+                d = get_f_module(job, module.pre_layernorm)
+                if d: out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.post_layernorm)
+                if d: out_dict.update(d); current_size += _dsize(d)
+                if has_gate: d = get_q_module(job, module.gate_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.up_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.down_proj); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2ParallelDecoder):
+            if isinstance(module, ExLlamaV2MoEMLP):

-            has_gate = model.config.arch.lm.mlp_gate
-            has_qk_norm = model.config.use_qk_norm
-            d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.attn.q_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.attn.k_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.attn.v_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.attn.o_proj); out_dict.update(d); current_size += _dsize(d)
-            if has_qk_norm:
-                d = get_f_module(job, module.attn.q_norm); out_dict.update(d); current_size += _dsize(d)
-                d = get_f_module(job, module.attn.k_norm); out_dict.update(d); current_size += _dsize(d)
-            if has_gate:
-                d = get_q_module(job, module.mlp.gate_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.mlp.up_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.mlp.down_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.post_attention_layernorm); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.gate); out_dict.update(d); current_size += _dsize(d)
+                for i in range(model.config.num_experts):
+                    d = get_q_module(job, module.w1[i]); out_dict.update(d); current_size += _dsize(d)
+                    d = get_q_module(job, module.w3[i]); out_dict.update(d); current_size += _dsize(d)
+                    d = get_q_module(job, module.w2[i]); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):
+            if isinstance(module, ExLlamaV2ParallelDecoder):

-            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)
+                has_gate = model.config.arch.lm.mlp_gate
+                has_qk_norm = model.config.use_qk_norm
+                d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.attn.q_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.attn.k_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.attn.v_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.attn.o_proj); out_dict.update(d); current_size += _dsize(d)
+                if has_qk_norm:
+                    d = get_f_module(job, module.attn.q_norm); out_dict.update(d); current_size += _dsize(d)
+                    d = get_f_module(job, module.attn.k_norm); out_dict.update(d); current_size += _dsize(d)
+                if has_gate:
+                    d = get_q_module(job, module.mlp.gate_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.mlp.up_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.mlp.down_proj); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2Linear):
+            if isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):

-            assert module.key == "lm_head"
-            d = get_q_module(job, module); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)

-        index += 1
+            if isinstance(module, ExLlamaV2Linear):
+
+                assert module.key == cfg.arch.lm_prefix + "lm_head"
+                d = get_q_module(job, module); out_dict.update(d); current_size += _dsize(d)
+
+            index += 1
+
+        else:
+
+            key = extra_tensors[0]
+            extra_tensors = extra_tensors[1:]
+            file = cfg.tensor_file_map[key]
+            with safe_open(file, framework = "pt") as f:
+                tensor = f.get_tensor(key)
+            out_dict.update({key: tensor})
+            extra_tensors_size += _tsize(tensor)

         # Save shard

-        if current_size > shard_bytes or index == len(model.modules):
+        if current_size > shard_bytes or (index == len(model.modules) and len(extra_tensors) == 0):

             print_stage(job, "Compiling", index, len(model.modules))

@@ -175,7 +195,7 @@ def compile_model(job, save_fn, model):

                 out_dict = dont_save_dict

-            if index == len(model.modules) and len(out_dict) > 0:
+            if index == len(model.modules) and len(extra_tensors) == 0 and len(out_dict) > 0:
                 save_dict = dont_save_dict
                 dont_save_dict = {}
                 continue

@@ -203,6 +223,9 @@ def compile_model(job, save_fn, model):
             filesize = os.path.getsize(final_filename) // (1024 ** 2)
             print(f" -- {final_filename} ({filesize:,} MB)")

+            if extra_tensors_size:
+                print(f" -- Tensors copied (MM components): {extra_tensors_size // (1024 ** 2):,} MB")
+
     # Copy all non-tensor files from the model's directory if compiling a full model

     if job["compile_full"] is not None:

exllamav2/conversion/convert_exl2.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def save_job():
     sys.exit()

 if job["progress"] == "finished":
-    print(" !! Job is already finished")
+    print(f" !! Job is already finished. Clear the working directory, or run this script with -nr/--no_resume to clear it automatically.")
     sys.exit()

 # Feedback
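
The new message points the user at the -nr/--no_resume switch, which discards the stale job directory and restarts the conversion. A hedged example invocation, assuming the repo-root convert.py entry point and its usual -i/-o directory arguments (the paths are placeholders):

python convert.py -i /mnt/models/pixtral-12b -o /mnt/temp/exl2_job -nr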

exllamav2/conversion/optimize.py

Lines changed: 16 additions & 15 deletions
@@ -8,11 +8,12 @@
 def optimize(job, save_fn, model):

     cfg = model.config
+    km = cfg.arch.lm.keys

     has_gate = cfg.arch.lm.mlp_gate
-    if has_gate: mlp_key_gate = cfg.arch.mlp_key_gate
-    mlp_key_up = cfg.arch.mlp_key_up
-    mlp_key_down = cfg.arch.mlp_key_down
+    if has_gate: mlp_key_gate = km["mlp_gate"]
+    mlp_key_up = km["mlp_up"]
+    mlp_key_down = km["mlp_down"]

     norm_interval = (1.5, 3.5)
     norm_2ndstage = 0.15

@@ -24,19 +25,19 @@ def optimize(job, save_fn, model):
     anneal_stages = 3

     first_q_layer = 0
-    while not model.modules[first_q_layer].key.startswith("model.layers"):
+    while not model.modules[first_q_layer].key.startswith(cfg.arch.lm_prefix + "model.layers"):
         first_q_layer += 1

     # max_step_size = 2
     # first_layer_bias = 4
     # bias_layers = 2
     # bias_iter = 0

-    key = "model.layers.0"
-    key_q = key + ".self_attn.q_proj"
-    key_k = key + ".self_attn.k_proj"
-    key_v = key + ".self_attn.v_proj"
-    key_o = key + ".self_attn.o_proj"
+    key = cfg.arch.lm_prefix + "model.layers.0"
+    key_q = key + km["attn_q"]
+    key_k = key + km["attn_k"]
+    key_v = key + km["attn_v"]
+    key_o = key + km["attn_o"]

     if not cfg.arch.lm.is_moe:
         if has_gate: key_g = key + mlp_key_gate

@@ -84,11 +85,11 @@ def optimize(job, save_fn, model):

     for i in range(num_layers):
         if cfg.arch.lm.parallel_decoder_blocks:
-            m1 = measurement["model.layers." + str(i) + ".parallel_decoder"]["attn"]
-            m2 = measurement["model.layers." + str(i) + ".parallel_decoder"]["mlp"]
+            m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".parallel_decoder"]["attn"]
+            m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".parallel_decoder"]["mlp"]
         else:
-            m1 = measurement["model.layers." + str(i) + ".self_attn"]
-            m2 = measurement["model.layers." + str(i) + "." + mlp_mode]
+            m1 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + ".self_attn"]
+            m2 = measurement[cfg.arch.lm_prefix + "model.layers." + str(i) + "." + mlp_mode]
         for m in [m1, m2]:
             slot = []
             param = []

@@ -154,8 +155,8 @@ def optimize(job, save_fn, model):
     job["strategy"] = {}
     for layer_ in range(num_layers):

-        k1 = "model.layers." + str(layer_) + ".self_attn"
-        k2 = "model.layers." + str(layer_) + "." + mlp_mode
+        k1 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + ".self_attn"
+        k2 = cfg.arch.lm_prefix + "model.layers." + str(layer_) + "." + mlp_mode
         p1 = params[layer_ * 2][solution_idx[layer_ * 2]]
         p2 = params[layer_ * 2 + 1][solution_idx[layer_ * 2 + 1]]
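
To make the key changes above concrete, here is a small illustration of how a per-layer tensor key is assembled once the language-model prefix is factored in. The prefix value and the keys dictionary are assumptions about a Pixtral-style checkpoint layout rather than values read from architecture.py; for plain text-only models the prefix would simply be empty.

# Illustration only - hypothetical values, not the actual contents of cfg.arch:
lm_prefix = "language_model."             # assumed cfg.arch.lm_prefix for a multimodal model
km = {"attn_q": ".self_attn.q_proj"}      # assumed subset of cfg.arch.lm.keys
key_q = lm_prefix + "model.layers.0" + km["attn_q"]
print(key_q)                              # language_model.model.layers.0.self_attn.q_proj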

exllamav2/conversion/quantize.py

Lines changed: 1 addition & 1 deletion
@@ -326,7 +326,7 @@ def quant(job, save_fn, model):

         elif isinstance(module, ExLlamaV2Linear):
             mode = "linear"
-            assert module.key == "lm_head"
+            assert module.key == model.config.arch.lm_prefix + "lm_head"
             quantizers["lm_head"] = AdaptiveGPTQ(module.linear)

         elif isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):

exllamav2/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def numel(self) -> int:
         numel = self.up_proj.numel() + \
                 self.down_proj.numel()

-        if self.archparams.arch.mlp_gate:
+        if self.archparams.mlp_gate:
             numel += self.gate_proj.numel()

         if self.pre_layernorm is not None:

experimental/multimodal_pixtral_hf.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 #
 # https://huggingface.co/mistral-community/pixtral-12b/

-model_directory = "/mnt/str/models/pixtral-12b"
+model_directory = "/mnt/str/models/pixtral-12b-exl2/5.0bpw"
 config = ExLlamaV2Config(model_directory)
 config.max_seq_len = 16384  # default is 1M
