77from compressed_tensors .utils import (
88 align_modules ,
99 get_execution_device ,
10+ get_lowest_common_ancestor_name ,
1011 match_modules_set ,
1112 match_named_modules ,
1213 update_offload_parameter ,
13- get_lowest_common_ancestor_name ,
1414)
1515from loguru import logger
1616from pydantic import ConfigDict , PrivateAttr , model_validator
1717from torch .nn import Module
18- from tqdm import tqdm
1918from torch .utils ._pytree import tree_flatten
19+ from tqdm import tqdm
20+
2021from llmcompressor .core import Event , EventType , State
2122from llmcompressor .modifiers import Modifier
2223from llmcompressor .modifiers .awq .mappings import (
3031from llmcompressor .pipelines .cache import IntermediatesCache
3132from llmcompressor .utils .fsdp .helpers import get_fsdp_parent
3233from llmcompressor .utils .helpers import calibration_forward_context
33- from llmcompressor .utils .pytorch .module import get_layer_by_name
34+ from llmcompressor .utils .pytorch .module import (
35+ get_layer_by_name ,
36+ get_module_to_name_dict ,
37+ )
3438
3539__all__ = ["AWQModifier" ]
3640
@@ -321,30 +325,20 @@ def _set_resolved_mappings(self, model: Module) -> None:
321325 repeat for model.layer.1 and so on
322326 """
323327 resolved_mappings : list [ResolvedMapping ] = []
324-
325- module_to_name = {}
326- for name , module in model .named_modules ():
327- if module in module_to_name :
328- logger .info (
329- f"Warning, { name } and { module_to_name [module ]} both "
330- "share the same module the same module, "
331- "may have trouble resolving mappings."
332- )
333- module_to_name [module ] = name
334-
328+ module_to_name = get_module_to_name_dict (model )
335329 for mapping in self .mappings :
336330 for smooth_layers , * nested_balance_layers in match_modules_set (
337331 model , (mapping .smooth_layer , * mapping .balance_layers ), self .ignore
338332 ):
339- assert len (smooth_layers )== 1 , (
340- "AWQ mappings need to match a single smoothlayer for each mapping but got "
341- f"{ [module_to_name .get (smooth_layer ) for smooth_layer in smooth_layers ]} "
342- f"when matching { mapping . smooth_layer } "
333+ assert len (smooth_layers ) == 1 , (
334+ "AWQ mappings need to match a single smoothlayer for each "
335+ f"mapping but got { [module_to_name .get (s ) for s in smooth_layers ]} "
336+ f" for mapping: { mapping } "
343337 )
344338 smooth_layer = smooth_layers [0 ]
345339 smooth_name = module_to_name .get (smooth_layer )
346340
347- #[[b00, b01, b02...], [b10, b11, b12,...], ...] v
341+ # [[b00, b01, b02...], [b10, b11, b12,...], ...] v
348342 # [b00, b01, b02, ..., b10, b11, b12, ...]
349343 balance_layers = tree_flatten (nested_balance_layers )[0 ]
350344 balance_names = [
@@ -371,7 +365,9 @@ def _set_resolved_mappings(self, model: Module) -> None:
371365 else :
372366 # for multiple balance layers, find lowest common parent
373367 ancestor_name = get_lowest_common_ancestor_name (balance_names )
374- ancestor_name , ancestor = get_lowest_non_module_list_ancestor (ancestor_name , model )
368+ ancestor_name , ancestor = get_lowest_non_module_list_ancestor (
369+ ancestor_name , model
370+ )
375371
376372 resolved_mappings .append (
377373 ResolvedMapping (
@@ -807,7 +803,7 @@ def _accumulate_mean(
807803
808804def get_lowest_non_module_list_ancestor (name , module : Module ) -> tuple [str , Module ]:
809805 """
810- Given a name and a model, finds lowest ancestor of
806+ Given a name and a model, finds lowest ancestor of
811807 named module that's not a ModuleList
812808 i.e. module_list.module_dict.module_list -> module_list.module_dict
813809 i.e. module_list.module_dict -> module_list.module_dict
0 commit comments