Skip to content

Commit a88ca3c

Browse files
spinquant and quip_online, running but outputting gibberish
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 3207124 commit a88ca3c

File tree

6 files changed

+95
-61
lines changed

6 files changed

+95
-61
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# flake8: noqa
22

3+
from .presets import TRANSFORM_PRESETS
34
from .transform import TransformModifier
4-
from .transform.presets import TRANSFORM_PRESETS
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from .quip import QUIP
1+
from .quip import QUIP, QUIP_ONLINE
22
from .spinquant import LLAMA_SPINQUANT, LLAMA_SPINQUANT_R1R2
33

44
TRANSFORM_PRESETS = {
55
"QUIP": QUIP,
6+
"QUIP_ONLINE": QUIP_ONLINE,
67
"LLAMA_SPINQUANT": LLAMA_SPINQUANT,
78
"LLAMA_SPINQUANT_R1R2": LLAMA_SPINQUANT_R1R2,
89
}

src/llmcompressor/modifiers/transform/presets/quip.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,61 @@
3838
),
3939
}
4040
)
41+
42+
# https://github.com/vllm-project/llm-compressor/blob/b43b27a2f277a5e62be4f8c713b84fd1c7aa116b/weight_transform.py#L24-L105
43+
QUIP_ONLINE = TransformConfig(
44+
config_groups={
45+
"u_transform_q_o_down_proj": TransformScheme(
46+
type="hadamard",
47+
apply=[
48+
TransformArgs(
49+
targets=[
50+
"re:.*.attn.q_proj$",
51+
"re:.*.attn.o_proj$",
52+
"re:.*.mlp.down_proj$",
53+
],
54+
location="weight_input",
55+
)
56+
],
57+
),
58+
"u_transform_k_v_proj": TransformScheme(
59+
type="hadamard",
60+
apply=[
61+
TransformArgs(
62+
targets=["re:.*.attn.k_proj$", "re:.*.attn.v_proj$"],
63+
location="weight_input",
64+
)
65+
],
66+
),
67+
"u_transform_gate_up_proj": TransformScheme(
68+
type="hadamard",
69+
apply=[
70+
TransformArgs(
71+
targets=["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"],
72+
location="weight_input",
73+
)
74+
],
75+
),
76+
"v_transform_linear": TransformScheme(
77+
type="hadamard",
78+
apply=[
79+
TransformArgs(
80+
targets=["Linear"],
81+
location="weight_output",
82+
ignore=["re:.*.mlp.down_proj$", "lm_head"],
83+
inverse=True,
84+
)
85+
],
86+
),
87+
"v_transform_down_proj": TransformScheme(
88+
type="hadamard",
89+
apply=[
90+
TransformArgs(
91+
targets=["re:.*.mlp.down_proj$"],
92+
location="weight_output",
93+
inverse=True,
94+
)
95+
],
96+
),
97+
}
98+
)

src/llmcompressor/modifiers/transform/presets/spinquant.py

Lines changed: 23 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,23 @@
22

33
# Ref: https://arxiv.org/pdf/2405.16406 Fig 1
44

5-
# All rotations
6-
LLAMA_SPINQUANT = TransformConfig(
7-
transform_groups={
5+
# Mergeable rotations R1 and R2 only
6+
LLAMA_SPINQUANT_R1R2 = TransformConfig(
7+
config_groups={
88
"R1": TransformScheme(
99
type="hadamard",
1010
apply=[
1111
TransformArgs(
12-
targets=["embed_tokens", "o_proj", "down_proj"],
12+
targets=["re:.*embed_tokens$", "re:.*o_proj$", "re:.*down_proj$"],
1313
location="weight_output",
1414
),
1515
TransformArgs(
1616
targets=[
17-
"q_proj",
18-
"k_proj",
19-
"v_proj",
20-
"up_proj",
21-
"gate_proj",
17+
"re:.*q_proj$",
18+
"re:.*k_proj$",
19+
"re:.*v_proj$",
20+
"re:.*up_proj$",
21+
"re:.*gate_proj$",
2222
"lm_head",
2323
],
2424
location="weight_input",
@@ -30,23 +30,31 @@
3030
type="hadamard",
3131
apply=[
3232
TransformArgs(
33-
targets=["v_proj"],
33+
targets=["re:.*v_proj$"],
3434
location="weight_output",
3535
),
3636
TransformArgs(
37-
targets=["o_proj"], location="weight_input", inverse=True
37+
targets=["re:.*o_proj$"], location="weight_input", inverse=True
3838
),
3939
],
4040
),
41+
}
42+
)
43+
44+
# All rotations
45+
LLAMA_SPINQUANT = TransformConfig(
46+
config_groups={
47+
"R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"],
48+
"R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"],
4149
"R3": TransformScheme(
4250
type="hadamard",
4351
apply=[
4452
TransformArgs(
45-
targets=["self_attn"],
53+
targets=["re:.*self_attn$"],
4654
location="k_cache",
4755
),
4856
TransformArgs(
49-
targets=["self_attn"],
57+
targets=["re:.*self_attn$"],
5058
location="q_attn",
5159
),
5260
],
@@ -55,51 +63,11 @@
5563
type="hadamard",
5664
apply=[
5765
TransformArgs(
58-
targets=["down_proj"],
66+
targets=["re:.*down_proj$"],
5967
location="input",
6068
),
6169
TransformArgs(
62-
targets=["down_proj"], location="weight_input", inverse=True
63-
),
64-
],
65-
),
66-
}
67-
)
68-
69-
70-
# Mergeable rotations R1 and R2 only
71-
LLAMA_SPINQUANT_R1R2 = TransformConfig(
72-
config_groups={
73-
"R1": TransformScheme(
74-
type="hadamard",
75-
apply=[
76-
TransformArgs(
77-
targets=["embed_tokens", "o_proj", "down_proj"],
78-
location="weight_output",
79-
),
80-
TransformArgs(
81-
targets=[
82-
"q_proj",
83-
"k_proj",
84-
"v_proj",
85-
"up_proj",
86-
"gate_proj",
87-
"lm_head",
88-
],
89-
location="weight_input",
90-
inverse=True,
91-
),
92-
],
93-
),
94-
"R2": TransformScheme(
95-
type="hadamard",
96-
apply=[
97-
TransformArgs(
98-
targets=["v_proj"],
99-
location="weight_output",
100-
),
101-
TransformArgs(
102-
targets=["o_proj"], location="weight_input", inverse=True
70+
targets=["re:.*down_proj$"], location="weight_input", inverse=True
10371
),
10472
],
10573
),

src/llmcompressor/modifiers/transform/transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def validate_model_after(model: "TransformModifier") -> "TransformModifier":
2626
)
2727
model.config = TRANSFORM_PRESETS[model.preset_config]
2828

29+
return model
30+
2931
def on_initialize(self, state: State, **kwargs) -> bool:
3032
apply_transform_config(state.model, self.config)
3133

tests/llmcompressor/modifiers/transform/test_correctness.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@
77

88

99
@pytest.mark.parametrize(
10-
"dtype,exp_max,exp_mse", [
11-
(torch.bfloat16, 1.1, 0.012), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501
12-
(torch.float32, 4e-4, 2e-9)
13-
]
10+
"dtype,exp_max,exp_mse",
11+
[
12+
(
13+
torch.bfloat16,
14+
1.1,
15+
0.012,
16+
), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501
17+
(torch.float32, 4e-4, 2e-9),
18+
],
1419
)
1520
def test_apply_correctness(dtype, exp_max, exp_mse):
1621
model = AutoModelForCausalLM.from_pretrained(

0 commit comments

Comments (0)