Skip to content

Commit dc5c30c

Browse files
committed
add r2, increase precision
Signed-off-by: Kyle Sayers <[email protected]>
1 parent f5c2150 commit dc5c30c

File tree

3 files changed

+63
-15
lines changed

3 files changed

+63
-15
lines changed

examples/transform/llama3_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def tokenize(sample):
5959
# Configure the quantization algorithm to run.
6060
# * quantize the weights to 4 bits with GPTQ, using a group size of 128
6161
recipe = [
62-
# TODO preset_config="LLAMA_SPINQUANT_R1R2" outputs gibberish
6362
# TODO preset_config="QUIP_ONLINE" outputs gibberish
6463
# preset_config="QUIP" outputs sensible text, but cannot load the saved
6564
# checkpoint or run evals (~4hrs to run)

src/llmcompressor/modeling/fuse.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear])
2323
# NOTE: spinquant does this op in float64
2424
exec_device = get_execution_device(norm)
2525
with align_module_device(norm, exec_device), align_module_device(linear, exec_device):
26-
new_weight = linear.weight * norm.weight
26+
27+
weight_dtype = linear.weight.dtype
28+
29+
new_weight = linear.weight.to(torch.float64) * norm.weight.to(torch.float64)
30+
31+
new_weight = new_weight.to(weight_dtype)
2732

2833
update_offload_parameter(linear, "weight", new_weight)
2934

src/llmcompressor/modifiers/transform/presets/spinquant.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,69 @@
3636
),
3737
],
3838
),
39-
# "R2": TransformScheme(
40-
# type="hadamard",
41-
# # TODO infer head_dim from config.json in SpinQuantModifier
42-
# head_dim=128,
43-
# apply=[
44-
# TransformArgs(targets=["re:.*v_proj$"], location="weight_output"),
45-
# TransformArgs(
46-
# targets=["re:.*o_proj$"],
47-
# location="weight_input",
48-
# inverse=True,
49-
# ),
50-
# ],
51-
# ),
39+
"R2": TransformScheme(
40+
type="hadamard",
41+
# TODO infer head_dim from config.json in SpinQuantModifier
42+
head_dim=128,
43+
apply=[
44+
TransformArgs(targets=["re:.*v_proj$"], location="weight_output"),
45+
TransformArgs(
46+
targets=["re:.*o_proj$"],
47+
location="weight_input",
48+
inverse=True,
49+
),
50+
],
51+
),
5252
}
5353
)
5454

5555
# All rotations
5656
LLAMA_SPINQUANT = TransformConfig(
5757
config_groups={
58+
"R1": TransformScheme(
59+
type="hadamard",
60+
apply=[
61+
TransformArgs(
62+
targets=[
63+
# outermost rotation
64+
"re:.*embed_tokens$",
65+
# attention rotations
66+
"re:.*o_proj$",
67+
# mlp rotations
68+
"re:.*down_proj$",
69+
],
70+
location="weight_output",
71+
),
72+
TransformArgs(
73+
targets=[
74+
# outermost rotation
75+
"lm_head",
76+
# attention rotations
77+
"re:.*q_proj$",
78+
"re:.*k_proj$",
79+
"re:.*v_proj$",
80+
# mlp rotations
81+
"re:.*up_proj$",
82+
"re:.*gate_proj$",
83+
],
84+
location="weight_input",
85+
inverse=True,
86+
),
87+
],
88+
),
89+
"R2": TransformScheme(
90+
type="hadamard",
91+
# TODO infer head_dim from config.json in SpinQuantModifier
92+
head_dim=128,
93+
apply=[
94+
TransformArgs(targets=["re:.*v_proj$"], location="weight_output"),
95+
TransformArgs(
96+
targets=["re:.*o_proj$"],
97+
location="weight_input",
98+
inverse=True,
99+
),
100+
],
101+
),
58102
# "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"],
59103
# "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"],
60104
"R3": TransformScheme(

0 commit comments

Comments
 (0)