Skip to content

Commit cec2914

Browse files
committed
use match_modules_set
Signed-off-by: Kyle Sayers <[email protected]>
1 parent f18d0e8 commit cec2914

File tree

1 file changed

+5
-19
lines changed
  • src/llmcompressor/modifiers/transform/spinquant

1 file changed

+5
-19
lines changed

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22
from typing import Iterable, List, Literal, Optional
33

4-
from compressed_tensors import is_match, match_named_modules
4+
from compressed_tensors import match_modules_set, match_named_modules
55
from compressed_tensors.transform import (
66
TransformArgs,
77
TransformConfig,
@@ -156,24 +156,10 @@ def _prenormalize_embeddings(self, model: PreTrainedModel):
156156

157157
def _fuse_norms(self, model: PreTrainedModel):
158158
for mapping in self.norm_mappings:
159-
targets = (mapping.norm, *mapping.linears)
160-
matches = dict()
161-
162-
for name, module in model.named_modules():
163-
# match until we get a full set
164-
for target in targets:
165-
if is_match(name, module, target):
166-
if target in matches:
167-
raise ValueError("Cannot match twice")
168-
matches[target] = module
169-
170-
# once we have a full set, fuse and reset
171-
if all(target in matches for target in targets):
172-
fuse_norm_linears(
173-
matches[mapping.norm],
174-
(matches[target] for target in mapping.linears),
175-
)
176-
matches = dict()
159+
for norm, *linears in match_modules_set(
160+
model, (mapping.norm, *mapping.linears)
161+
):
162+
fuse_norm_linears(norm, linears)
177163

178164
def _create_r1_scheme(self) -> TransformScheme:
179165
return TransformScheme(

0 commit comments

Comments
 (0)