@@ -76,6 +76,10 @@ class AWQModifier(Modifier, QuantizationMixin):
7676 balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
7777 - smooth_layer: "re:.*final_layer_norm"
7878 balance_layers: ["re:.*fc1"]
79+ # activation_hook_target specifies which submodule of the parent to hook
80+ # for activation caching.
81+ # This change is only useful for MoE models with parallel transformer blocks,
82+ # and one should use the default value (None) in most cases.
7983 ignore: ["lm_head"]
8084 config_groups:
8185 group_0:
@@ -122,6 +126,11 @@ class AWQModifier(Modifier, QuantizationMixin):
122126 to smoothed) and the second entry is the layer whose output is scaled to
123127 achieve the smoothing.
124128 If regex is used, it matches layers with the largest overlap in module name.
129+ Each mapping may also include an ``activation_hook_target``: a dotted
130+ attribute path relative to the parent module (lowest common ancestor)
131+ specifying which submodule to hook for activation caching. This is useful
132+ for parallel transformer blocks where the default (hooking
133+ ``balance_layers[0]``) would capture the wrong activations.
125134 :param ignore: list of layers to ignore during quantization (not smoothed).
126135 It should match the name of layers whose outputs are scaled to achieve
127136 smoothing (the second entry of the mappings list).
@@ -389,6 +398,17 @@ def _set_resolved_mappings(self, model: Module) -> None:
389398 balance_names, model, torch.nn.ModuleList
390399 )
391400
401+ activation_hook_target = None
402+ if mapping.activation_hook_target:
403+ activation_hook_target = getattr_chain(
404+ ancestor, mapping.activation_hook_target
405+ )
406+ if activation_hook_target is None:
407+ raise ValueError(
408+ f"activation_hook_target '{mapping.activation_hook_target}'"
409+ f" not found on parent module '{ancestor_name}'"
410+ )
411+
392412 resolved_mappings.append(
393413 ResolvedMapping(
394414 smooth_name,
@@ -397,6 +417,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
397417 balance_names=balance_names,
398418 parent=ancestor,
399419 parent_name=ancestor_name,
420+ activation_hook_target=activation_hook_target,
400421 )
401422 )
402423 self._resolved_mappings = resolved_mappings
@@ -468,16 +489,14 @@ def cache_smooth_activations_hook(
468489 # input activations to balance layers needed for loss function
469490 # storing inputs to first balance layer is sufficient
470491 # other balance layers get the same input
471-
472- # The line below is useful for models that use parallel transformer block,
473- # such as gemma 3, command A. Need a better way to integrate it to the code.
474- # layer_to_hook = (
475- # mapping.parent.mlp
476- # if hasattr(mapping.parent, 'mlp')
477- # else mapping.balance_layers[0]
478- # )
492+ #
493+ # For parallel transformer blocks (e.g. Command A, Gemma 3) the first
494+ # balance layer may not receive the right activations. When
495+ # activation_hook_target is set on the mapping, hook that module
496+ # instead of balance_layers[0].
497 layer_to_hook = mapping.activation_hook_target or mapping.balance_layers[0]
498 self.register_hook(
480- mapping.balance_layers[0],
499 layer_to_hook,
481500 create_cache_smooth_activations_hook_fn(mapping.smooth_name),
482501 "forward",
483502 )
0 commit comments