
Commit bcf1ee3

Support GPTQ models with gptq_v2 checkpoint_format
1 parent: 07f9775

File tree

3 files changed (+16, -3 lines)

exllamav2/config.py
exllamav2/ext.py
exllamav2/linear.py

exllamav2/config.py

Lines changed: 6 additions & 0 deletions
@@ -107,6 +107,7 @@ class ExLlamaV2Config:
     norm_head: int | None

     checkpoint_fused_mlp: bool
+    checkpoint_offset_qzeros: bool


     def __init__(self,
@@ -287,6 +288,11 @@ def prepare(self, no_tensors: bool = False):
         # if scaling_type == "yarn":
         #     self.scale_alpha_value = factor

+        # Checkpoint format (for GPTQ models)
+
+        checkpoint_format = read(read_config, str, ["quantization_config->checkpoint_format"], None)
+        self.checkpoint_offset_qzeros = (checkpoint_format == "gptq_v2")
+
         # Create map of model tensors

         if no_tensors: return

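For reference, the checkpoint_format key sits inside the quantization_config block of the model's config.json. Below is a minimal sketch of the detection this hunk adds; the quantization_config contents are hypothetical, not taken from any particular model:

# Hypothetical quantization_config block from a gptq_v2 checkpoint's
# config.json (values are illustrative only)
quantization_config = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "checkpoint_format": "gptq_v2",
}

# Mirrors the logic added to prepare(): a missing key (plain v1 GPTQ)
# or any other value leaves the offset flag disabled
checkpoint_format = quantization_config.get("checkpoint_format")
checkpoint_offset_qzeros = (checkpoint_format == "gptq_v2")
print(checkpoint_offset_qzeros)  # True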
exllamav2/ext.py

Lines changed: 5 additions & 1 deletion
@@ -320,7 +320,8 @@ def make_q_matrix(w: dict,
                   temp_dq: torch.Tensor,
                   key: str = None,
                   prescale: float = 1,
-                  max_dq_rows = 0):
+                  max_dq_rows = 0,
+                  offset_qzeros: bool = False):

     # EXL2

@@ -354,6 +355,9 @@ def make_q_matrix(w: dict,
         if prescale != 1: w["scales"] *= prescale
         if w["scales"].dtype == torch.float: w["scales"] = w["scales"].half()

+        if offset_qzeros:
+            w["qzeros"] -= 0b00010001000100010001000100010001
+
         # GPTQ with g_idx (act_order)

         if "g_idx" in w and not (w["g_idx"] == 0).all().item():

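The magic constant here is 0x11111111: for 4-bit quantization each int32 in qzeros packs eight 4-bit zero points, so a single subtraction decrements every nibble by one, shifting gptq_v2 zero points down to the offset-by-one encoding the existing GPTQ path expects. A rough sketch of the effect on one packed word (illustrative values; assumes no zero point is 0, so no borrow crosses a nibble boundary):

import torch

# One packed int32 of qzeros: eight 4-bit zero points per word
# (value chosen for illustration only)
packed = torch.tensor([0x78888888], dtype=torch.int32)

def unpack_nibbles(x: torch.Tensor) -> list:
    # Split a single int32 into its eight 4-bit fields, LSB first
    word = int(x.item()) & 0xFFFFFFFF
    return [(word >> (4 * i)) & 0xF for i in range(8)]

print(unpack_nibbles(packed))   # [8, 8, 8, 8, 8, 8, 8, 7]

# The conversion added in make_q_matrix: one subtraction per word
# decrements all eight packed zero points at once
packed -= 0b00010001000100010001000100010001
print(unpack_nibbles(packed))   # [7, 7, 7, 7, 7, 7, 7, 6]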
exllamav2/linear.py

Lines changed: 5 additions & 2 deletions
@@ -98,13 +98,15 @@ def load(self,
              w: dict | nn.Parameter | tuple | None = None,
              device_tensors: bool = True):

+        cfg = self.model.config
+
         if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv)
         if w is None: w = self.load_weight()

         # Load quantized linear layer from dictionary

         if isinstance(w, dict):
-            assert not self.model.config.load_in_q4, "Can't load quantized layer in Q4 mode"
+            assert not cfg.load_in_q4, "Can't load quantized layer in Q4 mode"
             if self.has_bias:
                 assert "bias" in w, self.key + " has no bias but bias expected"
             else:
@@ -119,7 +121,8 @@ def load(self,
             self.q_handle = ext.make_q_matrix(w,
                                               self.temp_dq,
                                               prescale = self.prescale,
-                                              max_dq_rows = self.model.config.max_dq_size // self.out_features)
+                                              max_dq_rows = cfg.max_dq_size // self.out_features,
+                                              offset_qzeros = cfg.checkpoint_offset_qzeros)
             self.prev_prescale = self.prescale
             self.prescale = 1

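With all three pieces in place, a gptq_v2 checkpoint should load through the normal ExLlamaV2 path with no extra arguments: detection happens in ExLlamaV2Config.prepare() and the flag is threaded down to make_q_matrix as each quantized layer is built. A minimal sketch, assuming a hypothetical local model directory:

from exllamav2 import ExLlamaV2, ExLlamaV2Config

# Hypothetical path to a GPTQ model saved with checkpoint_format = "gptq_v2"
config = ExLlamaV2Config()
config.model_dir = "/models/my-model-gptq-v2"
config.prepare()

# Set during prepare() from quantization_config->checkpoint_format
print(config.checkpoint_offset_qzeros)

model = ExLlamaV2(config)
model.load()  # qzeros are shifted by the offset described above while layers load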