Skip to content

Commit bbcdc8c

Browse files
committed
implement fuse_norm_linears
Signed-off-by: Kyle Sayers <[email protected]>
1 parent f28a9d5 commit bbcdc8c

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

src/llmcompressor/modeling/fuse.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Iterable
2+
3+
import torch
4+
from compressed_tensors import update_offload_parameter
5+
6+
__all__ = ["fuse_norm_linears"]
7+
8+
9+
def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
    """
    Fuse a norm layer into subsequent linear layers. This is useful for ensuring
    transform invariance between norm and linear layers.

    Note that a model cannot be properly trained after its norms have been fused

    :param norm: norm layer whose weight will be fused into subsequent linears
    :param linears: linear layers which directly follow the norm layer
    :raises ValueError: if the norm layer type is not supported for fusing
    """
    if isinstance(norm, torch.nn.RMSNorm):
        for linear in linears:
            # perform the fusing multiply in float64, as spinquant does, to avoid
            # precision loss when weights are stored in half precision; cast back
            # to the linear weight's original dtype before storing
            exec_dtype = torch.float64
            new_weight = linear.weight.to(exec_dtype) * norm.weight.to(exec_dtype)
            update_offload_parameter(linear, "weight", new_weight.to(linear.weight.dtype))

        # the norm's scale is now folded into the linears; reset it to identity
        update_offload_parameter(norm, "weight", torch.ones_like(norm.weight))

    else:
        raise ValueError(f"Cannot fuse norm of type {type(norm)}")

0 commit comments

Comments
 (0)