File tree Expand file tree Collapse file tree 1 file changed +28
-0
lines changed
src/llmcompressor/modeling Expand file tree Collapse file tree 1 file changed +28
-0
lines changed Original file line number Diff line number Diff line change
1
+ from typing import Iterable
2
+
3
+ import torch
4
+ from compressed_tensors import update_offload_parameter
5
+
6
+ __all__ = ["fuse_norm_linears" ]
7
+
8
+
9
def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) -> None:
    """
    Fuse a norm layer into subsequent linear layers. This is useful for ensuring
    transform invariance between norm and linear layers.

    Note that a model cannot be properly trained after its norms have been fused.

    :param norm: norm layer whose weight will be fused into subsequent linears
    :param linears: linear layers which directly follow the norm layer
    :raises ValueError: if ``norm`` is not a supported norm type (currently only
        ``torch.nn.RMSNorm`` is supported)
    """
    if isinstance(norm, torch.nn.RMSNorm):
        for linear in linears:
            # spinquant performs this op in float64 to minimize rounding error;
            # compute in float64, then cast back to the linear's original dtype.
            # norm.weight has shape (in_features,), so it broadcasts across the
            # (out_features, in_features) linear weight column-wise
            weight_dtype = linear.weight.dtype
            new_weight = linear.weight.to(torch.float64) * norm.weight.to(
                torch.float64
            )
            update_offload_parameter(linear, "weight", new_weight.to(weight_dtype))

        # the norm's scale is now folded into the linears; reset it to identity
        # so applying the norm afterwards is a pure (unscaled) normalization
        update_offload_parameter(norm, "weight", torch.ones_like(norm.weight))

    else:
        raise ValueError(f"Cannot fuse norm of type {type(norm)}")
You can’t perform that action at this time.
0 commit comments