Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions src/llmcompressor/pipelines/sequential/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from compressed_tensors import has_offloaded_params
from compressed_tensors.quantization import find_name_or_class_matches
from loguru import logger
from torch.fx import Graph, GraphModule, Node
from torch.fx.graph import PythonCode
from torch.fx.proxy import Argument
Expand Down Expand Up @@ -77,15 +78,15 @@ def trace_subgraphs(
:param sample_input: inputs whose values will change during execution but whose
__len__, __bool__, and __contains__ values are assumed constant across batches
:param sequential_targets: list of patterns matching sequential targets
:param ignore: list of patterns matching modules to ignore during tracing
:param ignore: TODO: unused, in the future will specify functions and methods to
skip during tracing
:return: a list of Subgraphs in order of execution
"""
# find modules
sequential_targets = match_modules(model, sequential_targets)
ignore = match_modules(model, ignore)

# initialize arguments
tracer = get_tracer(model, sequential_targets, ignore)
tracer = get_tracer(model, sequential_targets)
concrete_args = populate_concrete_args(model, sample_input)

# trace
Expand Down Expand Up @@ -115,23 +116,31 @@ def trace_subgraphs(
return subgraphs


def get_tracer(
model: Module, sequential_targets: Set[Module], ignore: Set[Module]
) -> HFTracer:
def get_tracer(model: Module, sequential_targets: Set[Module]) -> HFTracer:
"""
Get a tracer specialized for the given model. The resulting tracer will not trace
inside of sequential targets, ignored targets, or offloaded modules.
inside of sequential targets, nor any modules which are not call graph ancestors of
sequential targets

Tracing within sequential targets and ignored targets is unnecessary, and tracing
within offloaded modules may result in meta tensors being added to the model graph
Tracing within sequential targets is unnecessary, and tracing within offloaded
modules may result in meta tensors being added to the model graph

:param model: model being traced
:param sequential_targets: modules which are sequential targets
:param ignore: modules which are ignored
"""
# TODO: redefine skip_trace_modules to all non-ancestors of sequential_targets
sequential_ancestors = get_sequential_ancestors(model, sequential_targets)
offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m))
skip_trace_modules = sequential_targets | offloaded_modules | ignore

# check unlikely case that ancestors have direct params which are offloaded
offloaded_ancestors = offloaded_modules & sequential_ancestors
if offloaded_ancestors:
names = set(module.__class__.__name__ for module in offloaded_ancestors)
logger.warning(
"The following modules are call graph ancestors of sequential targets,"
f"but also contain offloaded modules: {names}.\n"
"These modules will not be traced, and any sequential target children will "
"be executed jointly, which may lead to OOM errors"
)

class SequentialTracer(HFTracer):
def create_arg(self, a: Any) -> Argument:
Expand All @@ -144,9 +153,7 @@ def create_arg(self, a: Any) -> Argument:
return super().create_arg(a)

def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
return module in skip_trace_modules or super().is_leaf_module(
module, module_qualified_name
)
return module not in sequential_ancestors or module in offloaded_modules

def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
if isinstance(root, Module):
Expand Down Expand Up @@ -400,3 +407,28 @@ def add_line_numbers(text: str) -> str:
lines = text.splitlines()
numbered_lines = [f"{i + 1} {line}" for i, line in enumerate(lines)]
return "\n".join(numbered_lines)


def get_sequential_ancestors(model: Module, targets: Set[Module]) -> Set[Module]:
"""
Find modules which are call graph ancestors of the given sequential targets

:param model: model containing sequential targets
:param targets: sequential targets to find ancestors of
:return: call graph ancestors of sequential targets
"""
ancestors = set()

def is_ancestor(module: Module) -> bool:
if module in ancestors or module in targets:
return True

# eagerly compute list in order to avoid early stopping and :. missing ancestors
_is_ancestor = any([is_ancestor(child) for child in module.children()])
if _is_ancestor:
ancestors.add(module)

return _is_ancestor

is_ancestor(model)
return ancestors
3 changes: 2 additions & 1 deletion src/llmcompressor/pipelines/sequential/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def run_pipeline(
:param model: model being calibrated
:param dataloader: loads data for calibration
:param sequential_targets: patterns which match to the layer modules of the model
:param ignore: patterns which match to modules which should be ignored by tracing
:param ignore: TODO: unused, in the future will specify functions and methods to
skip during tracing
"""
# trace subgraphs
sample_input = next(iter(dataloader))
Expand Down
24 changes: 24 additions & 0 deletions tests/llmcompressor/pipelines/sequential/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch

from llmcompressor.pipelines.sequential.helpers import get_sequential_ancestors


class DummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.seq = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU())
self.fc = torch.nn.Linear(20, 5)

def forward(self, x):
x = self.seq(x)
return self.fc(x)


def test_get_sequential_ancestors():
model = DummyModel()

assert get_sequential_ancestors(model, set()) == set()
assert get_sequential_ancestors(model, {model}) == set()
assert get_sequential_ancestors(model, {model.fc}) == {model}
assert get_sequential_ancestors(model, {model.seq[0]}) == {model, model.seq}
assert get_sequential_ancestors(model, {model.seq[1]}) == {model, model.seq}