@@ -225,7 +225,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
225225 # (no offloading by default)
226226 self .offload_device = None
227227
228- self ._set_resolved_mappings (state .model )
228+ # Resolve sequential_targets: prefer oneshot() kwarg, fall back to
229+ # modifier field, then auto-detect from model config.
230+ seq_targets = kwargs .get ("sequential_targets" ) or self .sequential_targets
231+ self ._set_resolved_mappings (state .model , seq_targets )
229232
230233 return True
231234
@@ -320,7 +323,11 @@ def on_finalize(self, state: State, **kwargs) -> bool:
320323
321324 return True
322325
323- def _set_resolved_mappings (self , model : Module ) -> None :
326+ def _set_resolved_mappings (
327+ self ,
328+ model : Module ,
329+ sequential_targets : str | list [str ] | None = None ,
330+ ) -> None :
324331 """
325332 Transforms the list of activations to smooth and their corresponding weights
326333 into ResolvedMapping objects, resolving regular expressions.
@@ -331,9 +338,22 @@ def _set_resolved_mappings(self, model: Module) -> None:
331338 weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
332339 would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
333340 repeat for model.layer.1 and so on
341+
342+ :param model: model to resolve mappings for
343+ :param sequential_targets: optional list of module class names that define
344+ the scope for mapping resolution. When provided, only modules that are
345+ children of these targets participate in matching. This prevents vision
346+ encoder modules from polluting the mapping resolution in multimodal models.
334347 """
335348 resolved_mappings : list [ResolvedMapping ] = []
336349 module_to_name = get_module_to_name_dict (model )
350+
351+ # Build a scoped model view when sequential_targets are available.
352+ # This restricts match_modules_set to only consider modules that live
353+ # under a sequential target (e.g. decoder layers), preventing vision
354+ # encoder modules from breaking the parent-context grouping.
355+ match_model = _build_scoped_model (model , sequential_targets )
356+
337357 # Get names of modules targeted for quantization (excludes ignored)
338358 targeted_names = set (
339359 name
@@ -346,7 +366,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
346366 # so that we can handle layers that need smoothing but not quantization
347367 # we only skip if no layers in mapping are targeted for quantization.
348368 for smooth_layers , * nested_balance_layers in match_modules_set (
349- model , (mapping .smooth_layer , * mapping .balance_layers )
369+ match_model , (mapping .smooth_layer , * mapping .balance_layers )
350370 ):
351371 if len (smooth_layers ) > 1 :
352372 raise ValueError (
@@ -1040,3 +1060,67 @@ def _accumulate_mean(
10401060 new_count = prev_count + num_added
10411061
10421062 return (prev_sum + sum_added ) / new_count , new_count
1063+
1064+
1065+ class _ScopedModuleView (Module ):
1066+ """
1067+ A lightweight wrapper that restricts ``named_modules()`` to only yield
1068+ modules whose names fall under given scope prefixes. Everything else
1069+ (attribute access, ``get_submodule``, etc.) is forwarded to the wrapped
1070+ model so that ``match_modules_set`` can resolve module identities normally.
1071+ """
1072+
1073+ def __init__ (self , model : Module , scope_prefixes : set [str ]):
1074+ # bypass Module.__init__ to avoid registering parameters/buffers
1075+ object .__setattr__ (self , "_model" , model )
1076+ object .__setattr__ (self , "_scope_prefixes" , scope_prefixes )
1077+
1078+ def named_modules (self , * args , ** kwargs ):
1079+ for name , mod in self ._model .named_modules (* args , ** kwargs ):
1080+ if not name : # root module — always include
1081+ yield name , mod
1082+ elif any (
1083+ name == p or name .startswith (p + "." ) for p in self ._scope_prefixes
1084+ ):
1085+ yield name , mod
1086+
1087+ def __getattr__ (self , name : str ):
1088+ return getattr (self ._model , name )
1089+
1090+
1091+ def _build_scoped_model (
1092+ model : Module ,
1093+ sequential_targets : str | list [str ] | None ,
1094+ ) -> Module :
1095+ """
1096+ If *sequential_targets* is provided, return a :class:`_ScopedModuleView`
1097+ that only exposes modules living under instances of those target classes.
1098+ Otherwise return *model* unchanged (no-op for text-only models).
1099+ """
1100+ if not sequential_targets :
1101+ return model
1102+
1103+ if isinstance (sequential_targets , str ):
1104+ sequential_targets = [sequential_targets ]
1105+
1106+ target_classes = set (sequential_targets )
1107+
1108+ scope_prefixes : set [str ] = set ()
1109+ for name , mod in model .named_modules ():
1110+ if type (mod ).__name__ in target_classes :
1111+ scope_prefixes .add (name )
1112+
1113+ if not scope_prefixes :
1114+ logger .warning (
1115+ "sequential_targets %s did not match any modules, "
1116+ "falling back to unscoped mapping resolution" ,
1117+ sequential_targets ,
1118+ )
1119+ return model
1120+
1121+ logger .info (
1122+ "Scoping AWQ mapping resolution to %d sequential targets (%s)" ,
1123+ len (scope_prefixes ),
1124+ sequential_targets ,
1125+ )
1126+ return _ScopedModuleView (model , scope_prefixes )
0 commit comments