Skip to content

Commit 0dc048d

Browse files
authored
Merge branch 'main' into fix-qparams-decompression
2 parents a86c657 + 1ec8bb6 commit 0dc048d

File tree

2 files changed

+317
-47
lines changed

2 files changed

+317
-47
lines changed

src/compressed_tensors/utils/match.py

Lines changed: 164 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import os
1617
import re
18+
from collections import defaultdict
1719
from collections.abc import Generator
1820
from typing import Iterable, List, Mapping, Optional, Tuple, Union
1921

@@ -29,6 +31,7 @@
2931
"match_named_parameters",
3032
"match_targets",
3133
"match_modules_set",
34+
"get_lowest_common_ancestor_name",
3235
"is_match",
3336
"is_narrow_match",
3437
]
@@ -157,68 +160,194 @@ def match_targets(
157160
return matched_targets
158161

159162

163+
def get_lowest_common_ancestor_name(names: List[Optional[str]]) -> str:
    """
    Return the dotted name of the deepest module that is an ancestor of (or
    equal to) every name in `names`, ignoring `None` entries.

    Implementation is a small alteration of os.path.commonprefix
    https://docs.python.org/3/library/os.path.html#os.path.commonprefix

    Each name is wrapped in dots before taking the character-wise common
    prefix, so only whole dot-separated components can survive:

    | names              | wrapped prefix | result |
    |--------------------|----------------|--------|
    | ["abc.a", "abc.b"] | ".abc."        | "abc"  |
    | ["abc"]            | ".abc."        | "abc"  |
    | ["abc", "ab"]      | ".ab"          | ""     |
    | ["ab", "ab.a"]     | ".ab."         | "ab"   |

    :param names: dotted module names; `None` entries are skipped
    :return: the lowest common ancestor name, or "" if `names` holds no
        non-None entry or the names share no common ancestor component
    """
    present = [name for name in names if name is not None]
    if not present:
        return ""

    # 1) find longest shared prefix. The lexicographic min and max bound every
    #    other name, so their common prefix is the common prefix of all names.
    lowest = "." + min(present) + "."
    highest = "." + max(present) + "."
    common_prefix = os.path.commonprefix([lowest, highest])

    # 2) throw away right most dot and any partial name fragment after it,
    #    throw away the leftmost dot: ".keep.thro" -> "keep", "." -> ""
    return common_prefix[1 : common_prefix.rfind(".")]
187+
188+
160189
def match_modules_set(
    model: torch.nn.Module,
    targets: Optional[Iterable[str]],
    ignore: Optional[Iterable[str]] = None,
    error_on_module_rematch: bool = True,
) -> Generator[List[List[torch.nn.Module]], None, None]:
    """
    Yields modules grouped by parent context.

    We group by parent context so that we can return ALL matches of a
    specific target that can be paired with another target. This is most
    relevant in the case of MoE modules with multiple modules for each
    expert i.e. post_attention_layernorm <-> mlp.expert.N.gate_proj,
    mlp.expert.N.up_proj for all N. The parent context will differ from
    one layer to another while being the same for one expert to another.

    Each yielded group is a list (of lists) with the same size and order
    as `targets`. The matches within each target's list, and the groups
    themselves, follow the order of `model.named_modules()`.

    E.g. the following targets would yield modules belonging to the following layers:
    ```python3
    match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
        [
            [`layers.0.self_attn.q_proj`],
            [`layers.0.self_attn.k_proj`],
            [`layers.0.self_attn.v_proj`],
        ],
        [
            [`layers.1.self_attn.q_proj`],
            [`layers.1.self_attn.k_proj`],
            [`layers.1.self_attn.v_proj`],
        ],
        ...
    )
    ```

    This can be used to match layers to their corresponding downstream counterparts.
    For example, matching layer norms to their subsequent linear layers
    ```python3
    for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
        fuse_norm_linears(*norm, [*q, *k, *v])
    ```

    Alternatively for MoE you would get multiple matches
    per target per group, E.g.

    ```python3

    targets = [
        "post_attention_layernorm",
        "up_proj",
        "down_proj"
    ]
    match_modules_set(model, targets) == (
        [
            [`layers.0.post_attention_layernorm`],
            [
                `layers.0.mlp.experts.0.up_proj`,
                `layers.0.mlp.experts.1.up_proj`,
                ...
            ],
            [
                `layers.0.mlp.experts.0.down_proj`,
                `layers.0.mlp.experts.1.down_proj`,
                ...
            ]
        ], # <- first yield
        [
            [`layers.1.post_attention_layernorm`],
            [
                `layers.1.mlp.experts.0.up_proj`,
                `layers.1.mlp.experts.1.up_proj`,
                ...
            ],
            [
                `layers.1.mlp.experts.0.down_proj`,
                `layers.1.mlp.experts.1.down_proj`,
                ...
            ]
        ],
        ...
    )
    ```

    :param model: model containing modules to match against
    :param targets: target strings, potentially containing "re:" prefixes
    :param ignore: targets to ignore, potentially containing "re:" prefixes
    :param error_on_module_rematch: if True, errors when a module gets
        matched to multiple targets, if False, no error. (Defaults to True)
    :raises ValueError: if a single module matches more than one target while
        `error_on_module_rematch` is True, or if the modules run out while
        the final set has matches for some targets but not for others
    """
    # `targets` is iterated many times below (once per module, plus every
    # `set(targets)` and every yield), so a one-shot iterable such as a
    # generator must be materialized up front or it would be exhausted
    # after the first pass and silently match nothing.
    targets = list(targets or [])
    # `ignore` is handed to `is_match` once per (module, target) pair;
    # materialize it too so a one-shot iterable cannot be drained by the
    # first call. NOTE(review): assumes `is_match` only iterates `ignore`.
    ignore = list(ignore or [])

    # as we iterate through modules and try to match them with targets,
    # the algorithm can be in 2 possible states:
    # 0) len(unmatched_targets) > 0, i.e. some of the targets haven't been
    #    matched. Keep matching until all targets have at least one match
    # 1) len(unmatched_targets) == 0, i.e. we have at least one match for
    #    each target. At this point we are unsure if we have a full set or
    #    if we need to add more matches.
    # There are 3 things that can happen once we're in state 1:
    # A) found a new match with the same parent_context,
    #    (add it to matches and keep going)
    # B) found a new match with a different parent_context, i.e. adding it
    #    would change the common ancestor of the current set; this
    #    indicates that this match should be part of a new set.
    #    (yield current set [not including newest match], go back to state 0)
    # C) ran out of modules, we will always yield the final remaining set
    #    when we've iterated through all the modules in the model.
    #    (yield final set then exit.)
    # Note: it's possible to iterate through all the modules in the model
    # while not having a full matched set if the user specified a bad
    # matching, in that case something has gone wrong and we error
    matches = defaultdict(list)
    parent_context = None
    unmatched_targets = set(targets)

    for name, module in model.named_modules():
        matched_targets_for_cur_module = set()
        for target in targets:
            if not is_match(name, module, target, ignore):
                continue

            # common ancestor of the current set if this module joined it
            # (parent_context is None for a fresh set, which is ignored)
            new_parent_context = get_lowest_common_ancestor_name(
                [name, parent_context]
            )

            # code for (B)
            if not unmatched_targets and new_parent_context != parent_context:
                yield [matches[tgt] for tgt in targets]
                matches = defaultdict(list)
                # the new set starts with only this module in it
                new_parent_context = name
                unmatched_targets = set(targets)

            matches[target].append(module)
            parent_context = new_parent_context
            unmatched_targets.discard(target)
            matched_targets_for_cur_module.add(target)

        if len(matched_targets_for_cur_module) > 1 and error_on_module_rematch:
            raise ValueError(
                f"module: {name} was matched with multiple targets: "
                f"{matched_targets_for_cur_module} which is unexpected "
                "disable this check by setting `error_on_module_rematch = False`"
            )

    # never found anything (also covers an empty `targets`)
    if unmatched_targets == set(targets):
        return

    # code for (C)
    if not unmatched_targets:  # have a full matching
        yield [matches[tgt] for tgt in targets]
        return

    raise ValueError(
        f"Found a final incomplete set with matches found for keys: "
        f"{set(targets) - unmatched_targets} "
        f"but no matches found for keys: {unmatched_targets}"
    )
222351

223352

224353
def is_match(

0 commit comments

Comments
 (0)