Commit a2dd4a1
[transforms] SpinQuantModifier & QuIPModifier transform_block_size field (#1806)
SUMMARY: Resolves `INFERENG-1882`.

The research community [has pointed out](https://github.com/IST-DASLab/FP-Quant?tab=readme-ov-file#fp-format-quantization-harness) that the rotation/transform block size is important when performing transforms:

> Key to efficiency is that the Hadamard block size matches the microscaling format group size (16 or 32)

This exposes a new field on SpinQuantModifier and QuIPModifier that allows the user to set the block size to an arbitrary value, as long as the model's hidden_size and head_dim are both evenly divisible by it.

- [x] Add to SpinQuantModifier. An option to allow different `transform_block_size`s for R1 vs. R2 can be added at a future time.
- [x] Add to QuIPModifier. An option to allow different `transform_block_size`s for U vs. V can be added at a future time.

Merge in conjunction with:

* neuralmagic/compressed-tensors#466

TEST PLAN: `transform_block_size` added to parameterized `tests/llmcompressor/modifiers/transform/(test_correctness.py|test_serialization.py)`

---------

Signed-off-by: Brian Dellabetta <[email protected]>
1 parent f53071b commit a2dd4a1

File tree

4 files changed: +48 -14 lines changed
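For orientation, here is a minimal usage sketch of the new field, assuming the modifiers are importable from `llmcompressor.modifiers.transform` as in the tests below; the block-size values are illustrative:

```python
from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier

# transform_block_size must evenly divide the model's hidden_size
# (and, for SpinQuant, its head_dim). None keeps the previous behavior:
# hidden_size for QuIP's U/V and SpinQuant's R1/R3/R4, head_dim for R2.
quip = QuIPModifier(transform_type="random-hadamard", transform_block_size=32)
spinquant = SpinQuantModifier(transform_type="hadamard", transform_block_size=16)
```

Matching the block size to the microscaling group size (16 or 32) is exactly what the FP-Quant note quoted above recommends for efficiency.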

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -54,6 +54,11 @@ class QuIPModifier(Modifier):
     :param learnable: If true, attach gradients to transform weights for training
     :param precision: Precision at which all transforms should be applied. This applies
         to both weight fusing and online rotations
+    :param transform_block_size: Block size to use for rotation matrices. The model's
+        hidden_size must be evenly divisible by transform_block_size.
+        Layers will be transformed by a block-diagonal matrix where each block is a
+        matrix of this size.
+        If None is provided, model's hidden_size will be used
     :param ignore: Modules to ignore when attaching transforms
     :param transform_config: Optional transform config for overriding provided arguments
     """  # noqa: E501
@@ -66,6 +71,7 @@ class QuIPModifier(Modifier):
     randomize: bool = Field(default=False)
     learnable: bool = Field(default=False)
     precision: TorchDtype = Field(default=torch.float64)
+    transform_block_size: Optional[int] = Field(default=None)
     ignore: Union[str, List[str]] = Field(default="lm_head")
 
     # optional override for more fine-grained control
@@ -129,6 +135,7 @@ def _create_config(self) -> TransformConfig:
     def _create_v_scheme(self) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
+            block_size=self.transform_block_size,
             apply=[
                 TransformArgs(
                     targets=self.targets,
@@ -150,6 +157,7 @@ def _create_v_scheme(self) -> TransformScheme:
     def _create_u_scheme(self) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
+            block_size=self.transform_block_size,
             apply=[
                 TransformArgs(
                     targets=self.targets,
```
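To make the docstring's block-diagonal description concrete, here is an illustrative sketch (not the compressed-tensors implementation) of assembling a rotation whose diagonal blocks are normalized Hadamard matrices; it assumes `block_size` is a power of two so the Sylvester construction applies:

```python
import torch

def block_diagonal_hadamard(hidden_size: int, block_size: int) -> torch.Tensor:
    """Illustrative only: a hidden_size x hidden_size orthogonal matrix
    with Hadamard blocks of size block_size along the diagonal."""
    assert hidden_size % block_size == 0, "hidden_size must be divisible by block_size"
    h = torch.tensor([[1.0]])
    while h.shape[0] < block_size:  # Sylvester doubling: H_2n = [[H, H], [H, -H]]
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    h = h / block_size**0.5  # normalize so columns are orthonormal
    return torch.block_diag(*([h] * (hidden_size // block_size)))

R = block_diagonal_hadamard(2048, 32)
assert torch.allclose(R @ R.T, torch.eye(2048), atol=1e-6)  # rotation property
```

Because the matrix is block-diagonal, applying it only mixes channels within each `block_size`-wide group, which is why matching the block size to the microscaling group size matters.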

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -69,6 +69,12 @@ class SpinQuantModifier(Modifier, use_enum_values=True):
     :param learnable: if True, attach gradients to transform weights for training
     :param precision: Precision at which all transforms should be applied. This applies
         to both weight fusing and online rotations
+    :param transform_block_size: Block size to use for rotation matrices. The model's
+        hidden_size and head_dim must be evenly divisible by transform_block_size.
+        Layers will be transformed by a block-diagonal matrix where each block is a
+        matrix of this size.
+        If None is provided, model's hidden_size will be used for R1, R3, and R4
+        and model's head_dim will be used for R2
     :param mappings: Specifies layers within a model to target for transforms.
         A mapping will be inferred if None is provided
     :param norm_mappings: Specifies layers within a model to target for norm fusing.
@@ -83,6 +89,7 @@ class SpinQuantModifier(Modifier, use_enum_values=True):
     randomize: bool = Field(default=False)
     learnable: bool = Field(default=False)
     precision: TorchDtype = Field(default=torch.float64)
+    transform_block_size: Optional[int] = Field(default=None)
 
     # norm mappings separate from spinquant mappings to allow users to
     # override spinquant mappings with transform_config without overriding norms
@@ -186,6 +193,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             randomize=self.randomize,
             requires_grad=self.learnable,
             precision=self.precision,
+            block_size=self.transform_block_size,
             apply=[
                 TransformArgs(
                     targets=[
@@ -219,12 +227,20 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
         else:
             raise NotImplementedError()
 
+        if self.transform_block_size:
+            if head_dim % self.transform_block_size != 0:
+                raise ValueError(
+                    f"transform_block_size {self.transform_block_size} must be set "
+                    f"such that model's head_dim {head_dim} is evenly divisible by it"
+                )
+            head_dim = self.transform_block_size
+
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
             requires_grad=self.learnable,
             precision=self.precision,
-            head_dim=head_dim,
+            block_size=head_dim,
             apply=[
                 TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
                 TransformArgs(
@@ -246,6 +262,7 @@ def _create_r4_scheme(self) -> TransformScheme:
             randomize=self.randomize,
             requires_grad=self.learnable,
             precision=self.precision,
+            block_size=self.transform_block_size,
             apply=[
                 TransformArgs(
                     targets=[*self.mappings.mlp_out],
```
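The R2 guard above can be read in isolation; here is a minimal sketch mirroring the diff's logic (the standalone function name is hypothetical):

```python
from typing import Optional

def resolve_r2_block_size(head_dim: int, transform_block_size: Optional[int]) -> int:
    # Mirrors the guard in _create_r2_scheme: fall back to head_dim when unset,
    # otherwise require head_dim to be evenly divisible by the block size.
    if transform_block_size:
        if head_dim % transform_block_size != 0:
            raise ValueError(
                f"transform_block_size {transform_block_size} must be set "
                f"such that model's head_dim {head_dim} is evenly divisible by it"
            )
        return transform_block_size
    return head_dim

assert resolve_r2_block_size(64, 16) == 16    # 64 % 16 == 0, accepted
assert resolve_r2_block_size(64, None) == 64  # default: use head_dim
```

With head_dim 64, block sizes of 16, 32, or 64 pass this check, while a value such as 48 raises.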

tests/llmcompressor/modifiers/transform/test_correctness.py

Lines changed: 17 additions & 11 deletions
```diff
@@ -18,26 +18,32 @@
     reason="Skipping correctness tests requiring gated model access",
 )
 @pytest.mark.parametrize(
-    "modifier,model_dtype,precision,exp_mse",
+    "modifier,model_dtype,precision,transform_block_size,exp_mse",
     [
-        (QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3),  # 0.0019
-        (QuIPModifier, torch.bfloat16, torch.float32, 5e-3),  # 0.0022
-        (QuIPModifier, torch.float32, torch.float32, 5e-10),  # 1.0e-10
-        (QuIPModifier, torch.float32, torch.float64, 5e-11),  # 2.7e-11
-        (SpinQuantModifier, torch.bfloat16, torch.bfloat16, 5e-3),  # 0.0030
-        (SpinQuantModifier, torch.bfloat16, torch.float32, 5e-3),  # 0.0029
-        (SpinQuantModifier, torch.float32, torch.float32, 5e-4),  # 4e-4
-        (SpinQuantModifier, torch.float32, torch.float64, 5e-4),  # 4e-4
+        (QuIPModifier, torch.bfloat16, torch.bfloat16, None, 5e-3),  # 0.0019
+        (QuIPModifier, torch.bfloat16, torch.float32, 16, 5e-3),  # 0.0022
+        (QuIPModifier, torch.float32, torch.float32, 32, 5e-10),  # 1.0e-10
+        (QuIPModifier, torch.float32, torch.float64, 64, 5e-11),  # 2.7e-11
+        (SpinQuantModifier, torch.bfloat16, torch.bfloat16, None, 5e-3),  # 0.0030
+        (SpinQuantModifier, torch.bfloat16, torch.float32, 16, 5e-3),  # 0.0029
+        (SpinQuantModifier, torch.float32, torch.float32, 32, 5e-4),  # 4e-4
+        (SpinQuantModifier, torch.float32, torch.float64, 64, 5e-4),  # 4e-4
     ],
 )
-def test_apply_correctness(modifier, model_dtype, precision, exp_mse):
+def test_apply_correctness(
+    modifier, model_dtype, precision, transform_block_size, exp_mse
+):
     model = AutoModelForCausalLM.from_pretrained(
         "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=model_dtype
     )
     untie_word_embeddings(model)
 
     state = State(model=model)
-    modifier = modifier(transform_type="random-hadamard", precision=precision)
+    modifier = modifier(
+        transform_type="random-hadamard",
+        precision=precision,
+        transform_block_size=transform_block_size,
+    )
 
     input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()}
     with torch.no_grad():
```

tests/llmcompressor/modifiers/transform/test_serialization.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -4,7 +4,10 @@
 
 
 @pytest.mark.parametrize("modifier", [SpinQuantModifier, QuIPModifier])
-def test_reload(modifier):
-    instance = modifier(transform_type="hadamard")
+@pytest.mark.parametrize("transform_block_size", [16, 32])
+def test_reload(modifier, transform_block_size):
+    instance = modifier(
+        transform_type="hadamard", transform_block_size=transform_block_size
+    )
     dump = instance.model_dump()
     assert modifier.model_validate(dump) == instance
```
