
Commit fe01901

get_targets interface
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 40cacb3 commit fe01901

17 files changed: +179 additions, -145 deletions

src/llmcompressor/args/dataset_arguments.py

Lines changed: 1 addition & 6 deletions
@@ -183,12 +183,7 @@ class DatasetArguments(CustomDatasetArguments):
             ),
         },
     )
-    batch_size: int = field(
-        default=1,
-        metadata={
-            "help": "TODO"
-        }
-    )
+    batch_size: int = field(default=1, metadata={"help": "TODO"})
     # --- pipeline arguments --- #
     pipeline: str | None = field(
         default="independent",

src/llmcompressor/modifiers/autoround/base.py

Lines changed: 4 additions & 7 deletions
@@ -21,7 +21,8 @@
 from llmcompressor.modifiers.quantization.calibration import apply_calibration_status
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_if_target_shared_embedding,
+    targets_embeddings,
+    untie_word_embeddings,
 )
 from llmcompressor.utils.pytorch.module import get_no_split_params

@@ -109,7 +110,6 @@ class AutoRoundModifier(Modifier, QuantizationMixin):
     enable_torch_compile: bool = True

     # private variables
-    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
     _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)
     _q_input: Optional[torch.Tensor] = PrivateAttr(default=None)

@@ -124,10 +124,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         QuantizationMixin.initialize_quantization(self, state.model)

         # prepare module names
-        self._module_names = {
-            m: name
-            for name, m in match_named_modules(state.model, self.targets, self.ignore)
-        }
         self._add_temporary_names(state.model)
         # freeze all model parameters
         for _, param in state.model.named_parameters():

@@ -142,7 +138,8 @@ def start_calibration(self, model: torch.nn.Module):

         :param model: model to prepare for calibration
         """
-        untie_if_target_shared_embedding(model, self._module_names.keys())
+        if targets_embeddings(model, self.get_targets(model)):
+            untie_word_embeddings(model)

         for _, module in match_named_modules(model, self.targets, self.ignore):
             # Note: No need to register observers for auto-round
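
The AutoRound hunks drop the cached _module_names mapping and route the embedding check through the shared helpers and the new get_targets hook. The same pattern, written out as a standalone hedged sketch (assumes an already-initialized modifier and a loaded model; `modifier` and `model` are placeholders):

    from llmcompressor.transformers.compression.compressed_tensors_utils import (
        targets_embeddings,
        untie_word_embeddings,
    )

    # untie shared word embeddings only when the modifier actually targets them
    if targets_embeddings(model, modifier.get_targets(model)):
        untie_word_embeddings(model)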

src/llmcompressor/modifiers/awq/base.py

Lines changed: 8 additions & 0 deletions
@@ -7,6 +7,7 @@
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
+    match_modules_set,
     match_named_modules,
     update_offload_parameter,
 )

@@ -26,6 +27,7 @@
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
+from llmcompressor.typing import NamedModules
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import get_layer_by_name

@@ -306,6 +308,12 @@ def on_finalize(self, state: State, **kwargs) -> bool:

         return True

+    def get_targets(self, model: torch.nn.Module) -> NamedModules:
+        for mapping in self.mappings:
+            yield from match_modules_set(
+                model, (*mapping.balance_layers, mapping.smooth_layer)
+            )
+
     def _set_resolved_mappings(self, model: Module) -> None:
         """
         Transforms the list of activations to smooth and their corresponding weights
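
Here get_targets enumerates targets per AWQ mapping: match_modules_set groups each set of balance layers with its smooth layer. A rough consumption sketch, assuming match_modules_set yields one group of matched modules per repetition of the pattern (the regex targets and `model` are illustrative, not from this commit):

    from compressed_tensors.utils import match_modules_set

    balance_layers = ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
    smooth_layer = "re:.*input_layernorm"

    # one matched group per decoder layer containing all four modules
    for module_set in match_modules_set(model, (*balance_layers, smooth_layer)):
        ...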

src/llmcompressor/modifiers/distillation/output/base.py

Lines changed: 12 additions & 0 deletions
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Tuple, Union

+import torch
 from torch.nn import Module

 from llmcompressor.core import Event, EventType, State

@@ -9,6 +10,7 @@
     KDModelWrapper,
     KDModuleWrapper,
 )
+from llmcompressor.typing import NamedModules
 from llmcompressor.utils.fsdp.context import summon_full_params_context
 from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped, set_wrapped_model
 from llmcompressor.utils.pytorch.module import get_layers, set_layer

@@ -138,6 +140,16 @@ def on_end(self, state: State, event: Event, **kwargs):
             teacher_wrapper.kd_enabled = False
         self.wrapped_kd_model_.kd_enabled = False

+    def get_targets(self, model: torch.nn.Module) -> NamedModules:
+        module_targets = dict()
+        targets = self.targets if isinstance(self.targets, list) else [self.targets]
+        for target in targets:
+            # only return targets of student model, not teacher model
+            target = target[0] if isinstance(target, tuple) else target
+            module_targets.update(get_layers(target, model))
+
+        return module_targets.items()
+
     def _create_model_wrapper(
         self, student_model: Module, teacher_model: Module, state: State
     ) -> KDModelWrapper:
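
Because output-distillation targets may be (student_layer, teacher_layer) pairs, get_targets keeps only the student-side pattern before matching layers. A small self-contained illustration of that filtering step (layer names are made up):

    targets = [("model.layers.0", "teacher.layers.0"), "model.layers.1"]

    # keep only the student-side pattern before matching modules
    student_patterns = [t[0] if isinstance(t, tuple) else t for t in targets]
    print(student_patterns)  # ['model.layers.0', 'model.layers.1']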

src/llmcompressor/modifiers/modifier.py

Lines changed: 11 additions & 1 deletion
@@ -1,6 +1,7 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import Iterable, Optional

+import torch
 from pydantic import ConfigDict

 from llmcompressor.core.events import Event, EventType

@@ -238,3 +239,12 @@ def on_event(self, state: State, event: Event, **kwargs):
         :param kwargs: Additional arguments for updating the model
         """
         pass
+
+    def get_targets(
+        self, model: torch.nn.Module
+    ) -> Iterable[tuple[str, torch.nn.Module]]:
+        """
+        Return all of the named modules which will be updated by this modifier. This
+        function can only be called after the modifier has been initialized.
+        """
+        raise NotImplementedError()
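
This base-class hook defines the contract that the concrete modifiers in this commit implement. A hedged sketch of a custom subclass satisfying it by matching its configured targets against the model, mirroring the QuantizationMixin implementation further down (the class and attribute values here are illustrative):

    import torch
    from compressed_tensors.utils import match_named_modules

    class MyModifier:  # stand-in for a Modifier subclass
        targets = ["Linear"]
        ignore = ["lm_head"]

        def get_targets(self, model: torch.nn.Module):
            # (name, module) pairs this modifier would update
            return match_named_modules(model, self.targets, self.ignore)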

src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_base.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
+from llmcompressor.typing import NamedModules
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,

@@ -192,6 +193,9 @@ def on_end(self, state: State, event: Event, **kwargs):
         self.ended_ = True
         self.remove_hooks()

+    def get_targets(self, model: torch.nn.Module) -> NamedModules:
+        return get_layers(self.targets, model).items()
+
     def _infer_sequential_targets(self, model: torch.nn.Module) -> str | list[str]:
         match self.sequential_targets:
             case None:
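
The SparseGPT implementation leans on get_layers, which resolves target patterns to a name-to-module mapping, so .items() already has the (name, module) shape the interface expects. A hedged consumption sketch (the pattern and `model` are placeholders):

    from llmcompressor.utils.pytorch.module import get_layers

    named_layers = get_layers("re:model.layers.\\d+$", model)  # hypothetical pattern and model
    for name, layer in named_layers.items():
        print(name, type(layer).__name__)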

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 __all__ = ["GPTQModifier"]


-class GPTQModifier(Modifier, QuantizationMixin):
+class GPTQModifier(QuantizationMixin, Modifier):
     """
     Implements the GPTQ algorithm from https://arxiv.org/abs/2210.17323. This modifier
     uses activations to calibrate a hessian matrix, which is then used to determine
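
Reordering the bases changes Python's method resolution order, presumably so that QuantizationMixin.get_targets (added in the mixin.py hunk below) is found before Modifier.get_targets, which raises NotImplementedError. A toy illustration of the effect; the classes below are stand-ins, not the real ones:

    class Modifier:
        def get_targets(self, model):
            raise NotImplementedError()

    class QuantizationMixin:
        def get_targets(self, model):
            return [("dummy", model)]

    class GPTQModifier(QuantizationMixin, Modifier):
        pass

    print([cls.__name__ for cls in GPTQModifier.__mro__])
    print(GPTQModifier().get_targets(object()))  # resolves to the mixin implementation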

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 8 additions & 6 deletions
@@ -35,8 +35,10 @@
 )
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_if_target_shared_embedding,
+    targets_embeddings,
+    untie_word_embeddings,
 )
+from llmcompressor.typing import NamedModules

 __all__ = ["QuantizationMixin"]

@@ -182,11 +184,8 @@ def start_calibration(self, model: torch.nn.Module):

         :param model: model to prepare for calibration
         """
-
-        matched_module_generator = (
-            x[1] for x in match_named_modules(model, self.resolved_targets, self.ignore)
-        )
-        untie_if_target_shared_embedding(model, matched_module_generator)
+        if targets_embeddings(model, self.get_targets(model)):
+            untie_word_embeddings(model)

         for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)

@@ -263,6 +262,9 @@ def resolve_quantization_config(self) -> QuantizationConfig:
             ignore=ignore,
         )

+    def get_targets(self, model: torch.nn.Module) -> NamedModules:
+        return match_named_modules(model, self.resolved_targets, self.ignore)
+
     def _initialize_observers(self, module: torch.nn.Module):
         if not hasattr(module, "quantization_scheme"):
             return
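
The commit does not show the implementation of targets_embeddings, only its call shape: it takes the model and the modifier's named targets and answers whether the shared word embeddings are among them. The sketch below is an assumption about that shape, using Hugging Face accessors, and is not the library code:

    def targets_embeddings_sketch(model, named_targets) -> bool:
        # assumption: an HF-style model exposing get_input_embeddings()/get_output_embeddings()
        embeddings = {
            id(m)
            for m in (model.get_input_embeddings(), model.get_output_embeddings())
            if m is not None
        }
        return any(id(module) in embeddings for _, module in named_targets)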

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 14 additions & 3 deletions
@@ -2,7 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.utils import align_module_device
+from compressed_tensors.utils import align_module_device, match_modules_set
 from loguru import logger
 from pydantic import ConfigDict, Field
 from torch.nn import Module

@@ -13,6 +13,7 @@
     get_layer_mappings_from_architecture,
     handle_mapping_resolution_errors,
 )
+from llmcompressor.typing import NamedModules
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.pytorch.module import (
     get_layers,

@@ -54,6 +55,7 @@ class SmoothQuantMapping:

     smooth_name: str
     smooth_layer: Module
+    balance_names: List[str]
     balance_layers: List[Module]


@@ -178,6 +180,13 @@ def on_finalize(self, state: State, **kwargs) -> bool:

         return True

+    def get_targets(self, model: torch.nn.Module) -> NamedModules:
+        if not self.initialized_:
+            raise ValueError("Cannot get targets before modifier has been initialized")
+
+        for balance_targets, smooth_target in self.mappings:
+            yield from match_modules_set(model, (*balance_targets, smooth_target))
+
     def _infer_mappings_from_model(
         self,
         model: Module,

@@ -207,18 +216,20 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         to_smooth_layers = get_layers(to_smooth, model)
         for layer_name, smooth_layer in to_smooth_layers.items():
             if not match_targets(layer_name, self.ignore)[0]:
+                balance_names = []
                 balance_layers = []
                 for balance_suffix in to_balance:
                     # find the submodule that matches the activation layer
-                    _, balance_layer = get_matching_layer(
+                    balance_name, balance_layer = get_matching_layer(
                         balance_suffix, layer_name, model
                     )
                     if balance_layer:
+                        balance_names.append(balance_name)
                         balance_layers.append(balance_layer)
                 # each mapping can contain multiple layers to balance, but only
                 # one layer to smooth
                 mapping = SmoothQuantMapping(
-                    layer_name, smooth_layer, balance_layers
+                    layer_name, smooth_layer, balance_names, balance_layers
                 )
                 resolved_mappings.append(mapping)
         return resolved_mappings
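
SmoothQuantMapping now records balance_names alongside balance_layers, and _resolve_mappings constructs it positionally in that order. A hedged construction example with toy modules (the import path and keyword construction are assumptions; the field order comes from the diff):

    import torch
    from llmcompressor.modifiers.smoothquant.base import SmoothQuantMapping

    mapping = SmoothQuantMapping(
        smooth_name="model.layers.0.input_layernorm",
        smooth_layer=torch.nn.LayerNorm(8),
        balance_names=["model.layers.0.self_attn.q_proj"],
        balance_layers=[torch.nn.Linear(8, 8)],
    )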

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 19 additions & 11 deletions
@@ -13,8 +13,10 @@
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_if_target_shared_embedding,
+    targets_embeddings,
+    untie_word_embeddings,
 )
+from llmcompressor.typing import NamedModules

 __all__ = ["QuIPModifier"]

@@ -102,18 +104,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:

     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
+        model = state.model

-        def matched_module_generator():
-            for scheme in self.transform_config.config_groups.values():
-                for arg in scheme.apply:
-                    gen = match_named_modules(state.model, arg.targets, arg.ignore)
-                    for _, module in gen:
-                        yield module
+        # untie embeddings if they will be targeted by transforms
+        if targets_embeddings(model, self.get_targets(model)):
+            untie_word_embeddings(model)

-        # Untie embeddings if they will be targeted by transforms
-        untie_if_target_shared_embedding(state.model, matched_module_generator())
-
-        apply_transform_config(state.model, self.transform_config)
+        apply_transform_config(model, self.transform_config)

     def on_event(self, state: State, event: Event, **kwargs):
         if event.type_ == EventType.CALIBRATION_EPOCH_START:

@@ -136,6 +133,17 @@ def on_finalize(self, state: State, **kwargs) -> bool:

         return True

+    def get_targets(self, model: torch.nn.Module) -> NamedModules:
+        if not self.initialized_:
+            raise ValueError("Cannot get targets before modifier has been initialized")
+
+        return [
+            (name, module)
+            for scheme in self.transform_config.config_groups.values()
+            for arg in scheme.apply
+            for name, module in match_named_modules(model, arg.targets, arg.ignore)
+        ]
+
     def _create_config(self) -> TransformConfig:
         config_groups = dict()
         if "v" in self.rotations:
