@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Iterable, List, Literal, Optional
 
-from compressed_tensors import is_match, match_named_modules
+from compressed_tensors import match_modules_set, match_named_modules
 from compressed_tensors.transform import (
     TransformArgs,
     TransformConfig,
@@ -156,24 +156,10 @@ def _prenormalize_embeddings(self, model: PreTrainedModel):
 
     def _fuse_norms(self, model: PreTrainedModel):
         for mapping in self.norm_mappings:
-            targets = (mapping.norm, *mapping.linears)
-            matches = dict()
-
-            for name, module in model.named_modules():
-                # match until we get a full set
-                for target in targets:
-                    if is_match(name, module, target):
-                        if target in matches:
-                            raise ValueError("Cannot match twice")
-                        matches[target] = module
-
-                # once we have a full set, fuse and reset
-                if all(target in matches for target in targets):
-                    fuse_norm_linears(
-                        matches[mapping.norm],
-                        (matches[target] for target in mapping.linears),
-                    )
-                    matches = dict()
+            for norm, *linears in match_modules_set(
+                model, (mapping.norm, *mapping.linears)
+            ):
+                fuse_norm_linears(norm, linears)
 
     def _create_r1_scheme(self) -> TransformScheme:
         return TransformScheme(
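For reference, a minimal runnable sketch of the `match_modules_set` behavior this change relies on, inferred from its use in the diff above: the helper walks the model's submodules and yields one tuple per complete set of matched targets, ordered the same way as the targets tuple, which is what makes the `norm, *linears` unpacking work. The toy model and the `re:`-prefixed regex targets below are assumptions for illustration, not the repo's actual norm mappings.

```python
import torch

from compressed_tensors import match_modules_set


# Toy stand-in for a stack of decoder layers (module names are hypothetical).
class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.input_norm = torch.nn.LayerNorm(8)
        self.q_proj = torch.nn.Linear(8, 8)
        self.k_proj = torch.nn.Linear(8, 8)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([Block(), Block()])


model = ToyModel()

# Assumed behavior, mirroring the diff: one (norm, linear, linear) tuple is
# yielded per layer once a full set of targets has been matched.
for norm, *linears in match_modules_set(
    model, ("re:.*input_norm$", "re:.*q_proj$", "re:.*k_proj$")
):
    print(type(norm).__name__, [type(m).__name__ for m in linears])
```

Compared with the removed hand-rolled loop, the duplicate-match check ("Cannot match twice") and the set-completion bookkeeping now live inside the library helper rather than in `_fuse_norms` itself.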