Skip to content

Commit 56047ac

Browse files
authored
[Tracing] Skip non-ancestors of sequential targets (#1389)
## Purpose ## * Reduce model support burden by skipping any modules which are not call graph ancestors of the sequential targets * Rather than requiring the user to specify a list of ignored modules, only trace what is necessary to disjointly execute sequential targets * In the future, the ignore field will be used to skip untraceable function/method names * This change does not change functionality because all ignored modules are already non-ancestors of sequential targets ## Changes ## * Remove `ignore` modules requirement (all ignored modules are already non-ancestors of sequential targets) * Implement `get_sequential_ancestors` which returns all ancestors of the sequential targets * Modify tracer to skip anything that is not a sequential ancestor or has offloaded modules * The two sets rarely overlap, and when they do, the module is skipped for safety and the user is warned ## Testing ## * Added tests for `get_sequential_ancestors` * #1335 ## Follow ups ## * #1390 --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 69d0963 commit 56047ac

File tree

3 files changed

+73
-16
lines changed

3 files changed

+73
-16
lines changed

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from compressed_tensors import has_offloaded_params
77
from compressed_tensors.quantization import find_name_or_class_matches
8+
from loguru import logger
89
from torch.fx import Graph, GraphModule, Node
910
from torch.fx.graph import PythonCode
1011
from torch.fx.proxy import Argument
@@ -77,15 +78,15 @@ def trace_subgraphs(
7778
:param sample_input: inputs whose values will change during execution but whose
7879
__len__, __bool__, and __contains__ values are assumed constant across batches
7980
:param sequential_targets: list of patterns matching sequential targets
80-
:param ignore: list of patterns matching modules to ignore during tracing
81+
:param ignore: TODO: unused, in the future will specify functions and methods to
82+
skip during tracing
8183
:return: a list of Subgraphs in order of execution
8284
"""
8385
# find modules
8486
sequential_targets = match_modules(model, sequential_targets)
85-
ignore = match_modules(model, ignore)
8687

8788
# initialize arguments
88-
tracer = get_tracer(model, sequential_targets, ignore)
89+
tracer = get_tracer(model, sequential_targets)
8990
concrete_args = populate_concrete_args(model, sample_input)
9091

9192
# trace
@@ -115,23 +116,31 @@ def trace_subgraphs(
115116
return subgraphs
116117

117118

118-
def get_tracer(
119-
model: Module, sequential_targets: Set[Module], ignore: Set[Module]
120-
) -> HFTracer:
119+
def get_tracer(model: Module, sequential_targets: Set[Module]) -> HFTracer:
121120
"""
122121
Get a tracer specialized for the given model. The resulting tracer will not trace
123-
inside of sequential targets, ignored targets, or offloaded modules.
122+
inside of sequential targets, nor any modules which are not call graph ancestors of
123+
sequential targets
124124
125-
Tracing within sequential targets and ignored targets is unnecessary, and tracing
126-
within offloaded modules may result in meta tensors being added to the model graph
125+
Tracing within sequential targets is unnecessary, and tracing within offloaded
126+
modules may result in meta tensors being added to the model graph
127127
128128
:param model: model being traced
129129
:param sequential_targets: modules which are sequential targets
130-
:param ignore: modules which are ignored
131130
"""
132-
# TODO: redefine skip_trace_modules to all non-ancestors of sequential_targets
131+
sequential_ancestors = get_sequential_ancestors(model, sequential_targets)
133132
offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m))
134-
skip_trace_modules = sequential_targets | offloaded_modules | ignore
133+
134+
# check unlikely case that ancestors have direct params which are offloaded
135+
offloaded_ancestors = offloaded_modules & sequential_ancestors
136+
if offloaded_ancestors:
137+
names = set(module.__class__.__name__ for module in offloaded_ancestors)
138+
logger.warning(
139+
"The following modules are call graph ancestors of sequential targets,"
140+
f"but also contain offloaded modules: {names}.\n"
141+
"These modules will not be traced, and any sequential target children will "
142+
"be executed jointly, which may lead to OOM errors"
143+
)
135144

136145
class SequentialTracer(HFTracer):
137146
def create_arg(self, a: Any) -> Argument:
@@ -144,9 +153,7 @@ def create_arg(self, a: Any) -> Argument:
144153
return super().create_arg(a)
145154

146155
def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
147-
return module in skip_trace_modules or super().is_leaf_module(
148-
module, module_qualified_name
149-
)
156+
return module not in sequential_ancestors or module in offloaded_modules
150157

151158
def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
152159
if isinstance(root, Module):
@@ -400,3 +407,28 @@ def add_line_numbers(text: str) -> str:
400407
lines = text.splitlines()
401408
numbered_lines = [f"{i + 1} {line}" for i, line in enumerate(lines)]
402409
return "\n".join(numbered_lines)
410+
411+
412+
def get_sequential_ancestors(model: Module, targets: Set[Module]) -> Set[Module]:
413+
"""
414+
Find modules which are call graph ancestors of the given sequential targets
415+
416+
:param model: model containing sequential targets
417+
:param targets: sequential targets to find ancestors of
418+
:return: call graph ancestors of sequential targets
419+
"""
420+
ancestors = set()
421+
422+
def is_ancestor(module: Module) -> bool:
423+
if module in ancestors or module in targets:
424+
return True
425+
426+
# eagerly compute list in order to avoid early stopping and :. missing ancestors
427+
_is_ancestor = any([is_ancestor(child) for child in module.children()])
428+
if _is_ancestor:
429+
ancestors.add(module)
430+
431+
return _is_ancestor
432+
433+
is_ancestor(model)
434+
return ancestors

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def run_pipeline(
4444
:param model: model being calibrated
4545
:param dataloader: loads data for calibration
4646
:param sequential_targets: patterns which match to the layer modules of the model
47-
:param ignore: patterns which match to modules which should be ignored by tracing
47+
:param ignore: TODO: unused, in the future will specify functions and methods to
48+
skip during tracing
4849
"""
4950
# trace subgraphs
5051
sample_input = next(iter(dataloader))
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
3+
from llmcompressor.pipelines.sequential.helpers import get_sequential_ancestors
4+
5+
6+
class DummyModel(torch.nn.Module):
7+
def __init__(self):
8+
super().__init__()
9+
self.seq = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU())
10+
self.fc = torch.nn.Linear(20, 5)
11+
12+
def forward(self, x):
13+
x = self.seq(x)
14+
return self.fc(x)
15+
16+
17+
def test_get_sequential_ancestors():
18+
model = DummyModel()
19+
20+
assert get_sequential_ancestors(model, set()) == set()
21+
assert get_sequential_ancestors(model, {model}) == set()
22+
assert get_sequential_ancestors(model, {model.fc}) == {model}
23+
assert get_sequential_ancestors(model, {model.seq[0]}) == {model, model.seq}
24+
assert get_sequential_ancestors(model, {model.seq[1]}) == {model, model.seq}

0 commit comments

Comments
 (0)