
Commit 5392b2b

cleanup
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 9e975d3 commit 5392b2b

5 files changed: +64 additions, -53 deletions

src/llmcompressor/modeling/fuse.py

Lines changed: 30 additions & 33 deletions
@@ -6,58 +6,55 @@
     get_execution_device,
     update_offload_parameter,
 )
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
-__all__ = ["normalize_embedding", "fuse_norm_linears"]
+__all__ = ["center_embeddings", "fuse_norm_linears"]
 
 
 PRECISION = torch.float64
 
 
-def normalize_embedding(embedding: torch.nn.Module):
+def center_embeddings(embedding: torch.nn.Module):
     """
-    Normalize each embedding to have a mean of zero
+    Shift each embedding to have a mean of zero
 
     :param embedding: embedding module containing embeddings to center
     """
-    if isinstance(embedding, (torch.nn.Embedding)):
-        with align_module_device(embedding):
-            weight_dtype = embedding.weight.dtype
-            weight = embedding.weight.to(PRECISION)
-            new_weight = weight - weight.mean(dim=-1, keepdim=True)
-            new_weight = new_weight.to(weight_dtype)
+    if not hasattr(embedding, "weight"):
+        raise ValueError(f"Cannot fuse norm of type {type(embedding)}")
 
-            update_offload_parameter(embedding, "weight", new_weight)
+    with align_module_device(embedding):
+        weight_dtype = embedding.weight.dtype
+        weight = embedding.weight.to(PRECISION)
+        new_weight = weight - weight.mean(dim=-1, keepdim=True)
+        new_weight = new_weight.to(weight_dtype)
 
-    else:
-        raise ValueError(f"Cannot normalize embedding of type {type(embedding)}")
+    update_offload_parameter(embedding, "weight", new_weight)
 
 
 def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
     """
-    Fuse a norm layer into subsequent linear layers. This useful for ensuring transform
-    invariance between norm and linear layers.
+    Fuse the scaling operation of norm layer into subsequent linear layers.
+    This useful for ensuring transform invariance between norm and linear layers.
 
-    Note that a model cannot be properly trained after its norms have been fused
+    Note that unitary transforms (rotation) commute with normalization, but not scaling
 
     :param norm: norm layer whose weight will be fused into subsequent linears
     :param linears: linear layers which directly follow the norm layer
     """
-    if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm, torch.nn.LayerNorm)):
-        for linear in linears:
-            # NOTE: spinquant does this op in float64
-            exec_device = get_execution_device(norm)
-            with align_module_device(norm, exec_device), align_module_device(
-                linear, exec_device
-            ):
-                weight_dtype = linear.weight.dtype
-                new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
-                new_weight = new_weight.to(weight_dtype)
-
-                update_offload_parameter(linear, "weight", new_weight)
-
-        new_norm_weight = torch.ones_like(norm.weight, device="cpu")
-        update_offload_parameter(norm, "weight", new_norm_weight)
-
-    else:
+    if not hasattr(norm, "weight"):
         raise ValueError(f"Cannot fuse norm of type {type(norm)}")
+
+    for linear in linears:
+        # NOTE: spinquant does this op in float64
+        exec_device = get_execution_device(norm)
+        with align_module_device(norm, exec_device), align_module_device(
+            linear, exec_device
+        ):
+            weight_dtype = linear.weight.dtype
+            new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
+            new_weight = new_weight.to(weight_dtype)
+
+            update_offload_parameter(linear, "weight", new_weight)
+
+    new_norm_weight = torch.ones_like(norm.weight, device="cpu")
+    update_offload_parameter(norm, "weight", new_norm_weight)
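
As a quick illustration of how the renamed helpers are meant to be called, here is a minimal sketch (not part of this commit; torch.nn.RMSNorm stands in for a model norm layer, and the tolerances are illustrative):

import torch

from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears

# center_embeddings shifts every embedding vector to zero mean
embedding = torch.nn.Embedding(10, 16)
center_embeddings(embedding)
assert torch.allclose(embedding.weight.mean(dim=-1), torch.zeros(10), atol=1e-5)

# fuse_norm_linears folds the norm scale into the following linear weight and
# resets the norm weight to ones, so the composed output is preserved
norm = torch.nn.RMSNorm(16)  # requires a recent PyTorch; stands in for a model norm
norm.weight.data.uniform_(0.5, 1.5)  # non-trivial scale so the check is meaningful
linear = torch.nn.Linear(16, 32, bias=False)

x = torch.randn(4, 16)
reference = linear(norm(x))

fuse_norm_linears(norm, [linear])
assert torch.allclose(linear(norm(x)), reference, atol=1e-5)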

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 11 additions & 7 deletions
@@ -1,11 +1,13 @@
 from typing import List, Literal, Optional, Union
 
+import torch
 from compressed_tensors.transform import (
     TransformArgs,
     TransformConfig,
     TransformScheme,
     apply_transform_config,
 )
+from compressed_tensors.utils import TorchDtype
 from pydantic import Field, ValidationInfo, field_validator
 
 from llmcompressor.core import Event, EventType, State
@@ -36,17 +38,19 @@ class QuIPModifier(Modifier):
         `"random-matrix"` has the greatest performance cost, but supports any size
     :param randomize: If true, create distinct transforms for each application
     :param learnable: If true, attach gradients to transform weights for training
+    :param precision: Precision at which all transforms should be applied. This applies
+        to both weight fusing and online rotations
     :param ignore: Modules to ignore when attaching transforms
     :param transform_config: Optional transform config for overriding provided arguments
     """
 
     transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
-        default="hadamard", exclude=True
+        default="random-hadamard"
     )
-    randomize: bool = Field(default=False, exclude=True)
-    learnable: bool = Field(default=False, exclude=True)
-    precision:
-    ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True)
+    randomize: bool = Field(default=False)
+    learnable: bool = Field(default=False)
+    precision: TorchDtype = Field(default=torch.float64)
+    ignore: Union[str, List[str]] = Field(default="lm_head")
 
     # optional override for more fine-grained control
     # also included in recipe serialization
@@ -105,21 +109,20 @@ def _create_config(self) -> TransformConfig:
                         TransformArgs(
                             targets=["Linear"],
                             location="weight_input",
-                            # location="input",
                             inverse=True,
                             ignore=self.ignore,
                         ),
                     ],
                     randomize=self.randomize,
                     requires_grad=self.learnable,
+                    precision=self.precision,
                 ),
                 "u": TransformScheme(
                     type=self.transform_type,
                     apply=[
                         TransformArgs(
                             targets=["Linear"],
                             location="weight_output",
-                            # location="output",
                             ignore=self.ignore,
                         ),
                         TransformArgs(
@@ -131,6 +134,7 @@ def _create_config(self) -> TransformConfig:
                     ],
                     randomize=self.randomize,
                     requires_grad=self.learnable,
+                    precision=self.precision,
                 ),
             }
         )
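
For reference, a minimal usage sketch of the new precision field (assumed usage, not part of this commit): with exclude=True dropped from the Fields, the settings survive recipe serialization, and precision is forwarded into both transform schemes.

import torch

from llmcompressor.modifiers.transform import QuIPModifier

# precision controls the dtype used for weight fusing and online rotations
modifier = QuIPModifier(transform_type="random-hadamard", precision=torch.float64)

# with exclude=True removed, the settings round-trip through serialization
dump = modifier.model_dump()
assert QuIPModifier.model_validate(dump) == modifier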

tests/llmcompressor/modeling/test_fuse.py

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 import pytest
 import torch
 
-from llmcompressor.modeling.fuse import fuse_norm_linears, normalize_embedding
+from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears
 
 
 @pytest.mark.unit
-def test_normalize_embedding():
+def test_center_embeddings():
     embedding = torch.nn.Embedding(10, 10)
-    normalize_embedding(embedding)
+    center_embeddings(embedding)
 
     assert torch.allclose(
         embedding.weight.mean(dim=1), torch.zeros(embedding.num_embeddings), atol=1e-5

tests/llmcompressor/modifiers/transform/test_correctness.py renamed to tests/llmcompressor/modifiers/transform/quip/test_correctness.py

Lines changed: 13 additions & 10 deletions
@@ -1,4 +1,5 @@
 import os
+
 import pytest
 import torch
 from transformers import AutoModelForCausalLM
@@ -9,23 +10,25 @@
 
 
 @requires_gpu
-# @pytest.mark.skipif(
-#     (not os.getenv("HF_TOKEN")),
-#     reason="Skipping tracing tests requiring gated model access",
-# )
+@pytest.mark.skipif(
+    (not os.getenv("HF_TOKEN")),
+    reason="Skipping tracing tests requiring gated model access",
+)
 @pytest.mark.parametrize(
-    "dtype,exp_mse",
+    "model_dtype,precision,exp_mse",
     [
-        (torch.bfloat16, 5e-3),
-        (torch.float32, 5e-11),
+        (torch.bfloat16, torch.bfloat16, 5e-3),  # 0.0019
+        (torch.bfloat16, torch.float32, 5e-3),  # 0.0022
+        (torch.float32, torch.float32, 5e-10),  # 1.0777e-10
+        (torch.float32, torch.float64, 5e-11),  # 2.6632e-11
     ],
 )
-def test_apply_correctness(dtype, exp_mse):
+def test_apply_correctness(model_dtype, precision, exp_mse):
     model = AutoModelForCausalLM.from_pretrained(
-        "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=dtype
+        "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=model_dtype
     )
     state = State(model=model)
-    modifier = QuIPModifier(transform_type="random-hadamard")
+    modifier = QuIPModifier(transform_type="random-hadamard", precision=precision)
 
     input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()}
     with torch.no_grad():
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from llmcompressor.modifiers.transform import QuIPModifier
+
+
+def test_reload():
+    modifier = QuIPModifier(transform_type="hadamard")
+    dump = modifier.model_dump()
+    assert QuIPModifier.model_validate(dump) == modifier
