diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 6e533cc1a..6b93d2b62 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -14,6 +14,7 @@
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
+from operator import attrgetter
 from tqdm import tqdm
 
 from llmcompressor.core import Event, EventType, State
@@ -29,7 +30,7 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
+from compressed_tensors import match_named_modules
 
 __all__ = ["AWQModifier"]
 
@@ -304,8 +305,8 @@ def _set_resolved_mappings(self, model: Module) -> None:
         """
         resolved_mappings: list[ResolvedMapping] = []
         for mapping_idx, mapping in enumerate(self.mappings):
-            smooth_layers = get_layers(
-                mapping.smooth_layer, model, exclude_internal_modules=True
+            smooth_layers = dict(
+                match_named_modules(model, [mapping.smooth_layer])
             )
             smooth_names = [
                 smooth_name
@@ -323,12 +324,12 @@ def _set_resolved_mappings(self, model: Module) -> None:
             smooth_layer = smooth_layers[smooth_name]
 
             smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-            smooth_parent = get_layer_by_name(smooth_parent_name, model)
+            smooth_parent = attrgetter(smooth_parent_name)(model) if smooth_parent_name else model
 
             balance_layers, balance_names = [], []
             for balance_regex in mapping.balance_layers:
                 # find the submodules that match the activation layer
-                for balance_suffix, balance_layer in get_layers(
+                for balance_suffix, balance_layer in match_named_modules(
                     balance_regex,
                     smooth_parent,
                     exclude_internal_modules=True,
@@ -765,7 +766,7 @@ def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Mod
     while True:
         if parent_name == "":
             return "", module
-        parent = get_layer_by_name(parent_name, module)
+        parent = attrgetter(parent_name)(module)
         if not isinstance(parent, torch.nn.ModuleList):
             return parent_name, parent
         parent_name = ".".join(parent_name.split(".")[:-1])
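Note on the AWQ changes: the removed helpers are covered by two replacements, compressed-tensors' `match_named_modules`, which is expected (in recent releases) to yield `(name, module)` pairs for targets given as exact names, class names, or `re:` regex patterns, and `operator.attrgetter` for dotted-path parent lookup. Below is a minimal sketch of dict-returning shims under those assumptions; `get_layers_compat` and `get_layer_by_name_compat` are hypothetical names, not part of either library.

```python
from operator import attrgetter

import torch
from compressed_tensors import match_named_modules


def get_layers_compat(module: torch.nn.Module, targets: list) -> dict:
    """Dict-returning shim shaped like the removed get_layers() helper."""
    # assumes match_named_modules(model, targets) yields (name, module) pairs
    return dict(match_named_modules(module, targets))


def get_layer_by_name_compat(name: str, module: torch.nn.Module) -> torch.nn.Module:
    """Dotted-path lookup, standing in for the removed get_layer_by_name()."""
    return attrgetter(name)(module) if name else module
```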
diff --git a/src/llmcompressor/modifiers/distillation/output/base.py b/src/llmcompressor/modifiers/distillation/output/base.py
index 130e2470c..f433510aa 100644
--- a/src/llmcompressor/modifiers/distillation/output/base.py
+++ b/src/llmcompressor/modifiers/distillation/output/base.py
@@ -11,7 +11,7 @@
 )
 from llmcompressor.utils.fsdp.context import summon_full_params_context
 from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped, set_wrapped_model
-from llmcompressor.utils.pytorch.module import get_layers, set_layer
+from compressed_tensors import match_named_modules
 
 __all__ = ["OutputDistillationModifier"]
 
@@ -61,8 +61,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             else:
                 model_target, teacher_target = target, target
 
-            model_layers = get_layers(model_target, state.model)
-            teacher_layers = get_layers(teacher_target, state.teacher_model)
+            model_layers = dict(match_named_modules(state.model, [model_target]))
+            teacher_layers = dict(match_named_modules(state.teacher_model, [teacher_target]))
 
             if len(model_layers) < 1:
                 raise ValueError(f"no model layers found for target {target}")
@@ -85,8 +85,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper, state.model)
-                set_layer(key, teacher_wrapper, state.teacher_model)
+                state.model.set_submodule(key, student_wrapper)
+                state.teacher_model.set_submodule(key, teacher_wrapper)
 
         self.wrapped_kd_model_ = self._create_model_wrapper(
             student_model=maybe_get_wrapped(state.model),
@@ -109,8 +109,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:
 
         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper.layer, state.model)
-                set_layer(key, teacher_wrapper.layer, state.teacher_model)
+                state.model.set_submodule(key, student_wrapper.layer)
+                state.teacher_model.set_submodule(key, teacher_wrapper.layer)
                 del student_wrapper
                 del teacher_wrapper
 
diff --git a/src/llmcompressor/modifiers/obcq/sgpt_base.py b/src/llmcompressor/modifiers/obcq/sgpt_base.py
index ce41273f3..a64fef902 100644
--- a/src/llmcompressor/modifiers/obcq/sgpt_base.py
+++ b/src/llmcompressor/modifiers/obcq/sgpt_base.py
@@ -13,12 +13,24 @@
 from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.pytorch.module import (
-    get_layers,
     get_no_split_params,
-    get_prunable_layers,
-    match_targets,
 )
+from compressed_tensors import is_match, match_named_modules
+
+
+def get_prunable_targets():
+    """Return the list of prunable layer types."""
+    return [
+        "Linear",
+        "Conv1d",
+        "Conv2d",
+        "Conv3d",
+        "QATLinear",
+        "QATConv2d",
+        "QATConv3d",
+        "Conv1D",
+    ]
 
 
 class SparsityModifierBase(Modifier):
     """
@@ -114,8 +115,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
         # infer module and sequential targets
         self.sequential_targets = self._infer_sequential_targets(model)
 
-        layers = get_layers(self.sequential_targets, model)
-        self._target_layers = get_layers(
-            self.targets, model
-        )  # layers containing targets
+        layers = dict(match_named_modules(model, self.sequential_targets))
+        self._target_layers = dict(
+            match_named_modules(model, self.targets)
+        )  # layers containing targets
 
@@ -149,11 +150,11 @@ def on_start(self, state: State, event: Event, **kwargs):
                 layer_sparsity = self.sparsity[index]
             else:
                 layer_sparsity = self.sparsity
-
-            for name, module in get_prunable_layers(layer).items():
+            prunable_targets = get_prunable_targets()
+            for name, module in match_named_modules(layer, prunable_targets):
                 name = f"{layer_name}.{name}"
 
-                if match_targets(name, self.ignore)[0]:
+                if is_match(name, module, self.ignore):
                     continue
 
                 # HACK: previously, embeddings were not quantized because they were not
@@ -210,7 +211,8 @@ def _infer_owl_layer_sparsity(
 
         groups = {}
         for name, layer in layers.items():
-            prunable_layers = get_prunable_layers(layer)
+            prunable_targets = get_prunable_targets()
+            prunable_layers = dict(match_named_modules(layer, prunable_targets))
             z = [
                 m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                 for n, m in prunable_layers.items()
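The distillation and sparsity changes lean on `torch.nn.Module.set_submodule` in place of the removed `set_layer`, and the ignore checks assume compressed-tensors also exposes `is_match(name, module, targets)` returning a bool. Below is a minimal sketch of the wrapper swap on a toy model, using only the documented `get_submodule`/`set_submodule` methods available in recent PyTorch releases.

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

original = model.get_submodule("0")        # fetch a layer by dotted path
model.set_submodule("0", nn.Linear(4, 4))  # swap in a replacement (wrapper) module
assert model.get_submodule("0") is not original

model.set_submodule("0", original)         # restore the original layer on finalize
assert model.get_submodule("0") is original
```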
diff --git a/src/llmcompressor/modifiers/pruning/constant/base.py b/src/llmcompressor/modifiers/pruning/constant/base.py
index 929ee5a5d..d75f6ccea 100644
--- a/src/llmcompressor/modifiers/pruning/constant/base.py
+++ b/src/llmcompressor/modifiers/pruning/constant/base.py
@@ -8,8 +8,7 @@
     LayerParamMasking,
     param_mask_name,
 )
-from llmcompressor.utils.pytorch.module import get_layers_params
-
+from compressed_tensors import match_named_parameters
 
 __all__ = ["ConstantPruningModifier"]
 
@@ -29,7 +28,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if not state.model:
             return False
 
-        self.parameterized_layers_ = get_layers_params(self.targets, state.model)
+        self.parameterized_layers_ = match_named_parameters(state.model, self.targets)
 
         for layer_param_name, parameterized_layer in self.parameterized_layers_.items():
             self.add_mask(
diff --git a/src/llmcompressor/modifiers/pruning/magnitude/base.py b/src/llmcompressor/modifiers/pruning/magnitude/base.py
index 1a218d0e3..328822deb 100644
--- a/src/llmcompressor/modifiers/pruning/magnitude/base.py
+++ b/src/llmcompressor/modifiers/pruning/magnitude/base.py
@@ -16,8 +16,7 @@
     PruningMaskCreatorArgs,
     PruningMaskFactory,
 )
-from llmcompressor.utils.pytorch.module import get_layers_params
-
+from compressed_tensors import match_named_parameters
 
 __all__ = ["MagnitudePruningModifier"]
 
@@ -73,7 +72,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             self.mask_structure
         )
 
-        self.parameterized_layers_ = get_layers_params(state.model)
+        self.parameterized_layers_ = match_named_parameters(state.model, self.targets)
 
         for layer_param_name, parameterized_layer in self.parameterized_layers_.items():
             self.add_mask(
diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py
index c2b4a4ce3..330014678 100644
--- a/src/llmcompressor/modifiers/smoothquant/base.py
+++ b/src/llmcompressor/modifiers/smoothquant/base.py
@@ -14,12 +14,7 @@
     handle_mapping_resolution_errors,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    match_targets,
-)
-
+from compressed_tensors import is_match, match_modules_set, match_named_modules
 
 MINIMUM_SMOOTHING_SCALE = 1e-5
 
@@ -204,13 +200,13 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         """
         resolved_mappings = []
         for to_balance, to_smooth in self.mappings:
-            to_smooth_layers = get_layers(to_smooth, model)
+            to_smooth_layers = dict(match_named_modules(model, [to_smooth]))
             for layer_name, smooth_layer in to_smooth_layers.items():
-                if not match_targets(layer_name, self.ignore)[0]:
+                if not is_match(layer_name, smooth_layer, self.ignore):
                     balance_layers = []
                     for balance_suffix in to_balance:
                         # find the submodule that matches the activation layer
-                        _, balance_layer = get_matching_layer(
+                        _, balance_layer = match_modules_set(
                             balance_suffix, layer_name, model
                         )
                         if balance_layer:
diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
index 51734ed41..4e418ae77 100644
--- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -12,10 +12,10 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.layer_sequential.helpers import (
     capture_first_layer_intermediates,
-    match_modules,
     maybe_inject_pos_embeddings,
     to_next_layer_kwargs,
 )
+from compressed_tensors import match_named_modules
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
     dispatch_for_sequential,
@@ -67,7 +67,7 @@ def __call__(
         # find layers
         modifiers = session.lifecycle.recipe.modifiers
         sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
-        layers = match_modules(model, sequential_targets)
+        layers = [m for _, m in match_named_modules(model, sequential_targets)]
 
         LifecycleCallbacks.calibration_epoch_start()
 
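The pruning and SmoothQuant hunks rely on `match_named_parameters` and `match_modules_set`. The shapes assumed below are that the former yields one `(param_fqn, module, param)` triple per parameter of each matched module, and that the latter yields one tuple of modules per matched set; verify both against the pinned compressed-tensors version. Under that reading, `match_modules_set` groups layers per set rather than returning a single best-matching layer, so it is not a drop-in for the removed `get_matching_layer` lookup in `_resolve_mappings`.

```python
import torch.nn as nn
from compressed_tensors import match_modules_set, match_named_parameters


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.k_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)


model = nn.Sequential(Block(), Block())

# assumed: one (param_fqn, module, param) triple per parameter of each matched module
for fqn, module, param in match_named_parameters(model, ["re:.*q_proj$"]):
    print(fqn, tuple(param.shape))

# assumed: one tuple of modules per matched (q_proj, k_proj, v_proj) group
for module_set in match_modules_set(model, ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"]):
    print([type(m).__name__ for m in module_set])
```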
diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index d2e4988ee..9a88fc9e7 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -24,6 +24,7 @@
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
 from llmcompressor.utils.pytorch.module import get_no_split_params
+from compressed_tensors import match_named_modules
 
 from .ast_helpers import autowrap_forwards
 
@@ -100,7 +101,7 @@ def trace_subgraphs(
     :return: a list of Subgraphs in order of execution
     """
     # find modules
-    targets = match_modules(model, sequential_targets)
+    targets = set(m for _, m in match_named_modules(model, sequential_targets))
     ancestors = get_sequential_ancestors(model, targets)
     offloaded = set(m for m in model.modules() if has_offloaded_params(m))
 
diff --git a/src/llmcompressor/transformers/compression/helpers.py b/src/llmcompressor/transformers/compression/helpers.py
index 2bc6b63f1..6b0dadf7b 100644
--- a/src/llmcompressor/transformers/compression/helpers.py
+++ b/src/llmcompressor/transformers/compression/helpers.py
@@ -8,8 +8,8 @@
 from tqdm import tqdm
 
 from llmcompressor.modifiers import Modifier
-from llmcompressor.pytorch.utils import get_linear_layers
 from llmcompressor.pytorch.utils.helpers import tensor_sparsity
+from compressed_tensors import match_named_modules
 
 __ALL__ = [
     "tensor_follows_mask_structure",
@@ -76,8 +76,8 @@ def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]
     # check for the common sparsity structures
     structures = {"2:4"}
     for sparsity_structure in structures:
-        linear_modules = get_linear_layers(model)
         offloaded_params = get_state_dict_offloaded_model(model)
+        linear_modules = dict(match_named_modules(model, ["Linear"]))
 
         linear_modules_with_sparsity_structure = [
             tensor_follows_mask_structure(offloaded_params[f"{name}.weight"])
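Both the sequential-pipeline helpers and the sparsity-structure check now match by module class name. A self-contained sanity check of that behavior, assuming compressed-tensors accepts class names, fully qualified names, and `re:` patterns as targets:

```python
import torch.nn as nn
from compressed_tensors import match_named_modules

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# children of nn.Sequential are named "0", "1", "2", so only the Linears match
linear_modules = dict(match_named_modules(model, ["Linear"]))
assert set(linear_modules) == {"0", "2"}
```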
diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py
index 72144cc0f..c5a185489 100644
--- a/src/llmcompressor/utils/pytorch/module.py
+++ b/src/llmcompressor/utils/pytorch/module.py
@@ -19,6 +19,7 @@
     fix_fsdp_module_name,
     summon_full_params_context,
 )
+from compressed_tensors import match_named_modules
 
 try:
     quant_err = None
@@ -46,22 +47,9 @@
 
 __all__ = [
-    "match_targets",
-    "get_default_params",
-    "match_layers_params",
-    "get_layers",
-    "get_layer",
-    "set_layer",
-    "get_params",
-    "get_param",
-    "get_terminal_layers",
-    "get_prunable_layers",
-    "get_quantizable_layers",
     "qat_active",
-    "get_layers_params",
     "get_matching_layer",
     "get_no_split_params",
-    "get_layer_by_name",
 ]
 
 
 ALL_TARGET = "__ALL__"
@@ -69,21 +57,6 @@
 ALL_QUANTIZABLE_TARGET = "__ALL_QUANTIZABLE__"
 
 
-def match_targets(name: str, targets: Union[str, List[str]]) -> Tuple[bool, int]:
-    if isinstance(targets, str):
-        targets = [targets]
-
-    for index, target in enumerate(targets):
-        if target[:3] == "re:":
-            pattern = target[3:]
-            if re.match(pattern, name):
-                return True, index
-        elif name == target:
-            return True, index
-
-    return False, -1
-
-
 def match_class(layer: Module, targets: Union[str, List[str]]) -> Tuple[bool, int]:
     if isinstance(targets, str):
         targets = [targets]
@@ -95,186 +68,6 @@ def match_class(layer: Module, targets: Union[str, List[str]]) -> Tuple[bool, in
     return False, -1
 
 
-def get_default_params(layers: Dict[str, Module]) -> Dict[str, Parameter]:
-    params = {}
-    for name, layer in layers.items():
-        for param_name, param in layer.named_parameters():
-            if param_name == "weight":
-                params[name] = param
-                break
-    return params
-
-
-def match_layers_params(
-    targets: Union[str, List[str]], module: Module, params: bool = False
-) -> Dict[str, Union[Module, Parameter]]:
-    if targets == ALL_TARGET:
-        values = get_terminal_layers(module)
-
-        return values if not params else get_default_params(values)
-
-    if targets == ALL_PRUNABLE_TARGET:
-        values = get_prunable_layers(module)
-
-        return values if not params else get_default_params(values)
-
-    if targets == ALL_QUANTIZABLE_TARGET:
-        values = get_quantizable_layers(module)
-
-        return values if not params else get_default_params(values)
-
-    if isinstance(targets, str):
-        targets = [targets]
-
-    resolved = {}
-    targets_found = [False for _ in range(len(targets))]
-
-    for name, layer in module.named_modules():
-        # due to nesting, FSDP may not be the top layer
-        name = fix_fsdp_module_name(name)
-        match, match_index = match_targets(name, targets)
-        if match and not params:
-            targets_found[match_index] = True
-            resolved[name] = layer
-        else:
-            match, match_index = match_class(layer, targets)
-            if match:
-                targets_found[match_index] = True
-                resolved[name] = layer
-
-        for param_name, param in layer.named_parameters():
-            if "." in param_name:  # skip parameters of nested layers
-                continue
-
-            param_match, param_match_index = match_targets(
-                f"{name}.{param_name}", targets
-            )
-            if param_match:
-                targets_found[param_match_index] = True
-                resolved[f"{name}"] = layer if not params else param
-
-    missed = [target for found, target in zip(targets_found, targets) if not found]
-    if len(missed) > 0:
-        raise ValueError(f"Could not find targets {missed} in module {module}")
-
-    return resolved
-
-
-def get_layers(
-    targets: Union[str, List[str]],
-    module: Module,
-    exclude_internal_modules: bool = False,
-) -> Dict[str, Module]:
-    """
-    Get layers (also known as submodules) of module based on targets
-
-    :param targets: names or regexes to search for
-        Can be regex, e.g. "re:.*input_layernorm$" to find all layers
-        in module whose names end in string "input_layernorm"
-    :param module: Parent module in which to search for targets
-    :param exclude_internal_modules: If True, don't include internal
-        modules added by llm-compressor, e.g. Observers and Transforms.
-        Defaults to False to maintain backward compatibility
-
-    :return: dict of {layer name -> module} of all layers in module
-        that match targets
-    """
-    layer_dict = match_layers_params(targets, module)
-    if exclude_internal_modules:
-        layer_dict = {
-            name: layer
-            for name, layer in layer_dict.items()
-            if not isinstance(layer, InternalModule)
-        }
-
-    return layer_dict
-
-
-def get_layer(target: str, module: Module) -> Tuple[str, Module]:
-    layers = get_layers(target, module)
-    if len(layers) != 1:
-        raise ValueError(f"Expected 1 layer for target {target}, found {len(layers)}")
-    name, layer = next(iter(layers.items()))
-
-    return name, layer
-
-
-def set_layer(target: str, layer: Module, module: Module) -> Module:
-    with summon_full_params_context(module):
-        # importing here to avoid circular import
-        from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped
-
-        parent_target = ".".join(target.split(".")[:-1])
-        if parent_target != "":
-            parent_layer = get_layer(parent_target, module)[1]
-        else:
-            parent_layer = maybe_get_wrapped(module)
-        old_layer = getattr(parent_layer, target.split(".")[-1])
-        setattr(parent_layer, target.split(".")[-1], layer)
-
-    return old_layer
-
-
-def get_params(targets: Union[str, List[str]], module: Module) -> Dict[str, Parameter]:
-    return match_layers_params(targets, module, params=True)
-
-
-def get_param(target: str, module: Module) -> Tuple[str, Parameter]:
-    params = get_params(target, module)
-    if len(params) != 1:
-        raise ValueError(
-            f"Expected 1 parameter for target {target}, found {len(params)}"
-        )
-    name, param = next(iter(params.items()))
-
-    return name, param
-
-
-def get_terminal_layers(module: Module) -> Dict[str, Module]:
-    terminal = {}
-
-    for name, layer in module.named_modules():
-        if len(list(layer.named_modules())) > 1:
-            continue
-
-        terminal[name] = layer
-
-    return terminal
-
-
-def get_prunable_layers(module: Module) -> Dict[str, Module]:
-    prunable = {}
-
-    for name, layer in module.named_modules():
-        if (
-            isinstance(layer, Linear)
-            or isinstance(layer, _ConvNd)
-            or (QATLinear and isinstance(layer, QATLinear))
-            or (QATConv2d and isinstance(layer, QATConv2d))
-            or (QATConv3d and isinstance(layer, QATConv3d))
-            or (TransformerConv1D and isinstance(layer, TransformerConv1D))
-        ):
-            prunable[name] = layer
-
-    return prunable
-
-
-def get_quantizable_layers(module: Module) -> Dict[str, Module]:
-    if QATLinear is None:
-        raise ImportError(
-            "PyTorch version is not setup for Quantization. "
-            "Please install a QAT compatible version of PyTorch"
-        )
-
-    quantizable = {}
-
-    for name, layer in module.named_modules():
-        if isinstance(layer, Linear) or isinstance(layer, _ConvNd):
-            quantizable[name] = layer
-
-    return quantizable
-
-
 def qat_active(module: Module) -> bool:
     """
     Determines if any layers in the model have quantization enabled by checking for
@@ -292,22 +85,6 @@ def qat_active(module: Module) -> bool:
     return False
 
 
-def get_layers_params(
-    targets: Union[str, List[str]], module: Module
-) -> Dict[str, ModelParameterizedLayer]:
-    params = get_params(targets, module)
-    layers = get_layers(targets, module)
-
-    parameterized_layers = {}
-    for name, param in params.items():
-        param_layer = ModelParameterizedLayer(
-            layer_name=name, layer=layers[name], param_name=name, param=param
-        )
-        parameterized_layers[name] = param_layer
-
-    return parameterized_layers
-
-
 def get_matching_layer(
     target: str, name_to_match: str, module: Module
 ) -> Optional[Tuple[str, Module]]:
@@ -323,7 +100,7 @@ def get_matching_layer(
     :return: Tuple containing the layer name and module that fits the target
         regex and best matches name_to_match, or None if no match can be found
     """
-    potential_matches = get_layers(target, module)
+    potential_matches = dict(match_named_modules(module, [target]))
     largest_substring = 0
     match = None
     for name, module in potential_matches.items():
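The class-based filters deleted above survive as plain target lists (see the `get_prunable_targets()` helper added in `sgpt_base.py`). A minimal equivalence sketch on a toy model, with the target list trimmed to the always-importable types:

```python
import torch.nn as nn
from compressed_tensors import match_named_modules

PRUNABLE_TARGETS = ["Linear", "Conv1d", "Conv2d", "Conv3d"]

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8, 2))
prunable = dict(match_named_modules(model, PRUNABLE_TARGETS))
assert set(prunable) == {"0", "3"}  # the Conv2d and the Linear
```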
" - "Please install a QAT compatible version of PyTorch" - ) - - quantizable = {} - - for name, layer in module.named_modules(): - if isinstance(layer, Linear) or isinstance(layer, _ConvNd): - quantizable[name] = layer - - return quantizable - - def qat_active(module: Module) -> bool: """ Determines if any layers in the model have quantization enabled by checking for @@ -292,22 +85,6 @@ def qat_active(module: Module) -> bool: return False -def get_layers_params( - targets: Union[str, List[str]], module: Module -) -> Dict[str, ModelParameterizedLayer]: - params = get_params(targets, module) - layers = get_layers(targets, module) - - parameterized_layers = {} - for name, param in params.items(): - param_layer = ModelParameterizedLayer( - layer_name=name, layer=layers[name], param_name=name, param=param - ) - parameterized_layers[name] = param_layer - - return parameterized_layers - - def get_matching_layer( target: str, name_to_match: str, module: Module ) -> Optional[Tuple[str, Module]]: @@ -323,7 +100,7 @@ def get_matching_layer( :return: Tuple containing the layer name and module that fits the target regex and best matches name_to_match, or None if no match can be found """ - potential_matches = get_layers(target, module) + potential_matches = match_named_modules(target, module) largest_substring = 0 match = None for name, module in potential_matches.items(): diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py index a1de2533d..9fd8e732a 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py @@ -6,8 +6,7 @@ from llmcompressor.core.session_functions import create_session from llmcompressor.datasets import format_calibration_data from llmcompressor.modifiers.obcq import SparseGPTModifier -from llmcompressor.utils.pytorch.module import get_layers - +from compressed_tensors import match_named_modules @pytest.mark.integration def test_infer_owl_layer_sparsity(): @@ -29,7 +28,7 @@ def test_infer_owl_layer_sparsity(): dataloader = format_calibration_data(dataset) sequential_targets = modifier._infer_sequential_targets(model) - layers = get_layers(sequential_targets, model) + layers = match_named_modules(sequential_targets, model) sparsities = modifier._infer_owl_layer_sparsity(model, layers, dataloader) assert sparsities.keys() == layers.keys() diff --git a/tests/llmcompressor/transformers/tracing/test_models.py b/tests/llmcompressor/transformers/tracing/test_models.py index 6fd25d9a7..b6e3c14ef 100644 --- a/tests/llmcompressor/transformers/tracing/test_models.py +++ b/tests/llmcompressor/transformers/tracing/test_models.py @@ -14,9 +14,9 @@ WhisperForConditionalGeneration, ) -from llmcompressor.pipelines.sequential.helpers import match_modules from llmcompressor.transformers.tracing.debug import trace from llmcompressor.utils.pytorch.module import get_no_split_params +from compressed_tensors import match_named_modules @pytest.mark.skipif( @@ -148,7 +148,7 @@ def get_target_modules(model, sequential_targets): if isinstance(sequential_targets, str): sequential_targets = [sequential_targets] - return match_modules(model, sequential_targets) + return match_named_modules(model, sequential_targets) def run_subgraphs(model, subgraphs, inputs): diff --git a/tests/llmcompressor/utils/pytorch/test_module.py b/tests/llmcompressor/utils/pytorch/test_module.py index 1ab40aa15..83ae0c802 100644 --- a/tests/llmcompressor/utils/pytorch/test_module.py +++ 
diff --git a/tests/llmcompressor/utils/pytorch/test_module.py b/tests/llmcompressor/utils/pytorch/test_module.py
index 1ab40aa15..83ae0c802 100644
--- a/tests/llmcompressor/utils/pytorch/test_module.py
+++ b/tests/llmcompressor/utils/pytorch/test_module.py
@@ -1,9 +1,6 @@
 import pytest
 import torch.nn as nn
 
-from llmcompressor.utils.pytorch import get_layer_by_name
-
-
 @pytest.fixture
 def example_nested_module() -> str:
     return nn.Sequential(
@@ -14,21 +11,3 @@
     )
 
 
-@pytest.mark.unit
-def test_get_layer_by_name(example_nested_module):
-    # Test getting the parent of a nested layer
-    layer = get_layer_by_name("0", example_nested_module)
-    assert layer == example_nested_module[0]
-
-    layer = get_layer_by_name("1.1", example_nested_module)
-    assert layer == example_nested_module[1][1]
-
-    layer = get_layer_by_name("2.0", example_nested_module)
-    assert layer == example_nested_module[2][0]
-
-    layer = get_layer_by_name("2.1", example_nested_module)
-    assert layer == example_nested_module[2][1]
-
-    # Test getting the parent of a non-existent layer
-    with pytest.raises(AttributeError):
-        get_layer_by_name("non_existent_layer", example_nested_module)
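The deleted `get_layer_by_name` tests could be replaced with coverage of the torch built-in that now fills that role; a possible (hypothetical) test along the same lines as the removed one:

```python
import pytest
import torch.nn as nn


def test_get_submodule():
    model = nn.Sequential(
        nn.Linear(4, 4),
        nn.Sequential(nn.ReLU(), nn.Linear(4, 4)),
    )

    # dotted-path lookups mirror the removed get_layer_by_name assertions
    assert model.get_submodule("0") is model[0]
    assert model.get_submodule("1.1") is model[1][1]

    # a non-existent path raises AttributeError, as before
    with pytest.raises(AttributeError):
        model.get_submodule("non_existent_layer")
```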