1686 Logic matching refactor (#1687)


Open · wants to merge 13 commits into base: main
13 changes: 7 additions & 6 deletions src/llmcompressor/modifiers/awq/base.py
@@ -14,6 +14,7 @@
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
+from operator import attrgetter
 from tqdm import tqdm
 
 from llmcompressor.core import Event, EventType, State
@@ -29,7 +30,7 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
+from compressed_tensors import match_named_modules
 
 __all__ = ["AWQModifier"]
 
@@ -304,8 +305,8 @@ def _set_resolved_mappings(self, model: Module) -> None:
         """
         resolved_mappings: list[ResolvedMapping] = []
         for mapping_idx, mapping in enumerate(self.mappings):
-            smooth_layers = get_layers(
-                mapping.smooth_layer, model, exclude_internal_modules=True
+            smooth_layers = match_named_modules(
+                model, [mapping.smooth_layer]
             )
             smooth_names = [
                 smooth_name
@@ -323,12 +324,12 @@ def _set_resolved_mappings(self, model: Module) -> None:
                 smooth_layer = smooth_layers[smooth_name]
 
                 smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-                smooth_parent = get_layer_by_name(smooth_parent_name, model)
+                smooth_parent = attrgetter(smooth_parent_name)(model) if smooth_parent_name else model
 
                 balance_layers, balance_names = [], []
                 for balance_regex in mapping.balance_layers:
                     # find the submodules that match the activation layer
-                    for balance_suffix, balance_layer in get_layers(
+                    for balance_suffix, balance_layer in match_named_modules(
                         balance_regex,
                         smooth_parent,
                         exclude_internal_modules=True,
@@ -765,7 +766,7 @@ def get_lowest_common_parent(names: List[str], module: Module) -> Tuple[str, Module]:
     while True:
         if parent_name == "":
             return "", module
-        parent = get_layer_by_name(parent_name, module)
+        parent = attrgetter(parent_name)(module)
         if not isinstance(parent, torch.nn.ModuleList):
            return parent_name, parent
         parent_name = ".".join(parent_name.split(".")[:-1])
14 changes: 7 additions & 7 deletions src/llmcompressor/modifiers/distillation/output/base.py
@@ -11,7 +11,7 @@
 )
 from llmcompressor.utils.fsdp.context import summon_full_params_context
 from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped, set_wrapped_model
-from llmcompressor.utils.pytorch.module import get_layers, set_layer
+from compressed_tensors import match_named_modules
 
 __all__ = ["OutputDistillationModifier"]

@@ -61,8 +61,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             else:
                 model_target, teacher_target = target, target
 
-            model_layers = get_layers(model_target, state.model)
-            teacher_layers = get_layers(teacher_target, state.teacher_model)
+            model_layers = match_named_modules(model_target, state.model)
+            teacher_layers = match_named_modules(teacher_target, state.teacher_model)
 
             if len(model_layers) < 1:
                 raise ValueError(f"no model layers found for target {target}")
@@ -85,8 +85,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper, state.model)
-                set_layer(key, teacher_wrapper, state.teacher_model)
+                Module.set_submodule(key, student_wrapper, state.model)
+                Module.set_submodule(key, teacher_wrapper, state.teacher_model)
 
         self.wrapped_kd_model_ = self._create_model_wrapper(
             student_model=maybe_get_wrapped(state.model),
@@ -109,8 +109,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:
 
         with summon_full_params_context(state.teacher_model, offload_to_cpu=True):
             for key, (student_wrapper, teacher_wrapper) in self.wrappers_.items():
-                set_layer(key, student_wrapper.layer, state.model)
-                set_layer(key, teacher_wrapper.layer, state.teacher_model)
+                Module.set_submodule(key, student_wrapper.layer, state.model)
[Collaborator comment] we're sure we want to call a class method on Module here and not a method on an instance?
+                Module.set_submodule(key, teacher_wrapper.layer, state.teacher_model)
                 del student_wrapper
                 del teacher_wrapper
 
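On the reviewer's question above: in recent PyTorch versions `set_submodule` is an instance method with signature `set_submodule(target, module)`, so an unbound call must pass the owning module first, not last. A minimal sketch of the distinction, using toy modules for illustration only:

```python
import torch.nn as nn

root = nn.Sequential(nn.Linear(4, 4))

# Instance call: the model being edited is bound as self.
root.set_submodule("0", nn.Identity())

# Unbound call: self must come first, so the class-form equivalent is
# nn.Module.set_submodule(root, "0", ...), not nn.Module.set_submodule("0", ..., root).
nn.Module.set_submodule(root, "0", nn.Linear(4, 4))
```

If that reading of the signature is right, the diff's `Module.set_submodule(key, student_wrapper, state.model)` would bind the string `key` as `self`; the instance form `state.model.set_submodule(key, student_wrapper)` would avoid the issue.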
20 changes: 11 additions & 9 deletions src/llmcompressor/modifiers/obcq/sgpt_base.py
@@ -13,12 +13,13 @@
 from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.pytorch.module import (
-    get_layers,
     get_no_split_params,
-    get_prunable_layers,
-    match_targets,
 )
+from compressed_tensors import match_named_modules
+
+def get_prunable_targets():
+    """Return the list of prunable layer types."""
+    return ["Linear", "Conv1d", "Conv2d", "Conv3d", "QATLinear", "QATConv2d", "QATConv3d", "Conv1D"]
 
 class SparsityModifierBase(Modifier):
     """
@@ -114,8 +115,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
 
         # infer module and sequential targets
         self.sequential_targets = self._infer_sequential_targets(model)
-        layers = get_layers(self.sequential_targets, model)
-        self._target_layers = get_layers(
+        layers = match_named_modules(self.sequential_targets, model)
+        self._target_layers = match_named_modules(
             self.targets, model
         )  # layers containing targets
 
@@ -149,11 +150,11 @@ def on_start(self, state: State, event: Event, **kwargs):
                 layer_sparsity = self.sparsity[index]
             else:
                 layer_sparsity = self.sparsity
-
-            for name, module in get_prunable_layers(layer).items():
+            prunable_targets = get_prunable_targets()
+            for name, module in match_named_modules(layer, prunable_targets).items():
                 name = f"{layer_name}.{name}"
 
-                if match_targets(name, self.ignore)[0]:
+                if match_named_modules(name, self.ignore)[0]:
                     continue
 
                 # HACK: previously, embeddings were not quantized because they were not
@@ -210,7 +211,8 @@ def _infer_owl_layer_sparsity(
 
         groups = {}
         for name, layer in layers.items():
-            prunable_layers = get_prunable_layers(layer)
+            prunable_targets = get_prunable_targets()
+            prunable_layers = match_named_modules(layer, prunable_targets)
             z = [
                 m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                 for n, m in prunable_layers.items()
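
To make the new helper's role concrete: `get_prunable_targets()` returns class-name strings, and `compressed_tensors.match_named_modules` matches modules by class name or `re:` regex. A rough usage sketch, assuming the `(model, targets)` argument order used in the AWQ hunk above (the sgpt hunks pass the arguments in the opposite order) and assuming the function yields `(name, module)` pairs:

```python
import torch.nn as nn

from compressed_tensors import match_named_modules

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Matching on the "Linear" class name should select layers 0 and 2, skipping ReLU.
for name, module in match_named_modules(model, ["Linear"]):
    print(name, type(module).__name__)
```

Whether the helper returns a dict or yields pairs determines whether the `.items()` calls in the hunks above are needed; worth confirming against the compressed-tensors API.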
5 changes: 2 additions & 3 deletions src/llmcompressor/modifiers/pruning/constant/base.py
@@ -8,8 +8,7 @@
     LayerParamMasking,
     param_mask_name,
 )
-from llmcompressor.utils.pytorch.module import get_layers_params
-
+from compressed_tensors import match_named_parameters
 __all__ = ["ConstantPruningModifier"]
 

@@ -29,7 +28,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if not state.model:
             return False
 
-        self.parameterized_layers_ = get_layers_params(self.targets, state.model)
+        self.parameterized_layers_ = match_named_parameters(self.targets, state.model)
 
         for layer_param_name, parameterized_layer in self.parameterized_layers_.items():
             self.add_mask(
5 changes: 2 additions & 3 deletions src/llmcompressor/modifiers/pruning/magnitude/base.py
@@ -16,8 +16,7 @@
     PruningMaskCreatorArgs,
     PruningMaskFactory,
 )
-from llmcompressor.utils.pytorch.module import get_layers_params
-
+from compressed_tensors import match_named_parameters
 __all__ = ["MagnitudePruningModifier"]
 

@@ -73,7 +72,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             self.mask_structure
         )
 
-        self.parameterized_layers_ = get_layers_params(state.model)
+        self.parameterized_layers_ = match_named_parameters(state.model)
 
         for layer_param_name, parameterized_layer in self.parameterized_layers_.items():
             self.add_mask(
14 changes: 5 additions & 9 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -14,12 +14,8 @@
     handle_mapping_resolution_errors,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    match_targets,
-)
-
+from compressed_tensors import match_modules_set
+from compressed_tensors import match_named_modules
 MINIMUM_SMOOTHING_SCALE = 1e-5
 
 
@@ -204,13 +200,13 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         """
         resolved_mappings = []
         for to_balance, to_smooth in self.mappings:
-            to_smooth_layers = get_layers(to_smooth, model)
+            to_smooth_layers = match_named_modules(to_smooth, model)
             for layer_name, smooth_layer in to_smooth_layers.items():
-                if not match_targets(layer_name, self.ignore)[0]:
+                if not match_named_modules(layer_name, self.ignore)[0]:
                     balance_layers = []
                     for balance_suffix in to_balance:
                         # find the submodule that matches the activation layer
-                        _, balance_layer = get_matching_layer(
+                        _, balance_layer = match_modules_set(
                             balance_suffix, layer_name, model
                         )
                         if balance_layer:
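
On the `match_modules_set` substitution: as I understand the compressed-tensors helper, it takes `(model, targets)` and yields one tuple of modules per matched set, which differs from `get_matching_layer`'s single `(name, module)` return. A hedged sketch under that assumption, with a toy attention block for illustration only:

```python
import torch.nn as nn

from compressed_tensors import match_modules_set


class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.k_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)


model = nn.ModuleList([Attn(), Attn()])

# Assumed behavior: one (q, k, v) tuple is yielded per block, grouping the
# modules that match each target expression in order.
for q, k, v in match_modules_set(model, ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]):
    print(q, k, v)
```

If that is the actual API, the call in the hunk above keeps `get_matching_layer`'s argument order and unpacking, which may be worth a second look.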
4 changes: 2 additions & 2 deletions src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -12,10 +12,10 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.layer_sequential.helpers import (
     capture_first_layer_intermediates,
-    match_modules,
     maybe_inject_pos_embeddings,
     to_next_layer_kwargs,
 )
+from compressed_tensors import match_named_modules
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
     dispatch_for_sequential,
@@ -67,7 +67,7 @@ def __call__(
         # find layers
         modifiers = session.lifecycle.recipe.modifiers
         sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
-        layers = match_modules(model, sequential_targets)
+        layers = match_named_modules(model, sequential_targets)
 
         LifecycleCallbacks.calibration_epoch_start()
 
3 changes: 2 additions & 1 deletion src/llmcompressor/pipelines/sequential/helpers.py
@@ -24,6 +24,7 @@
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
 from llmcompressor.utils.pytorch.module import get_no_split_params
+from compressed_tensors import match_named_modules
 
 from .ast_helpers import autowrap_forwards
 
@@ -100,7 +101,7 @@ def trace_subgraphs(
     :return: a list of Subgraphs in order of execution
     """
     # find modules
-    targets = match_modules(model, sequential_targets)
+    targets = match_named_modules(model, sequential_targets)
     ancestors = get_sequential_ancestors(model, targets)
     offloaded = set(m for m in model.modules() if has_offloaded_params(m))
 
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/compression/helpers.py
@@ -8,8 +8,8 @@
 from tqdm import tqdm
 
 from llmcompressor.modifiers import Modifier
-from llmcompressor.pytorch.utils import get_linear_layers
 from llmcompressor.pytorch.utils.helpers import tensor_sparsity
+from compressed_tensors import match_named_modules
 
 __ALL__ = [
     "tensor_follows_mask_structure",
@@ -76,8 +76,8 @@ def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]:
     # check for the common sparsity structures
     structures = {"2:4"}
     for sparsity_structure in structures:
-        linear_modules = get_linear_layers(model)
         offloaded_params = get_state_dict_offloaded_model(model)
+        linear_modules = match_named_modules(model, ["Linear"])
 
         linear_modules_with_sparsity_structure = [
             tensor_follows_mask_structure(offloaded_params[f"{name}.weight"])