
Commit 07c657e

fixes and formatting
Summary: fix smoothquant logic to align with AWQ
Signed-off-by: HDCharles <[email protected]>
1 parent 2b138a7 commit 07c657e

6 files changed: +62 -58 lines changed


src/llmcompressor/modifiers/awq/base.py

Lines changed: 17 additions & 21 deletions
@@ -7,16 +7,17 @@
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
+    get_lowest_common_ancestor_name,
     match_modules_set,
     match_named_modules,
     update_offload_parameter,
-    get_lowest_common_ancestor_name,
 )
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
-from tqdm import tqdm
 from torch.utils._pytree import tree_flatten
+from tqdm import tqdm
+
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.awq.mappings import (
@@ -30,7 +31,10 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layer_by_name
+from llmcompressor.utils.pytorch.module import (
+    get_layer_by_name,
+    get_module_to_name_dict,
+)

 __all__ = ["AWQModifier"]

@@ -321,30 +325,20 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-
-        module_to_name = {}
-        for name, module in model.named_modules():
-            if module in module_to_name:
-                logger.info(
-                    f"Warning, {name} and {module_to_name[module]} both "
-                    "share the same module the same module, "
-                    "may have trouble resolving mappings."
-                )
-            module_to_name[module] = name
-
+        module_to_name = get_module_to_name_dict(model)
         for mapping in self.mappings:
             for smooth_layers, *nested_balance_layers in match_modules_set(
                 model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
             ):
-                assert len(smooth_layers)==1, (
-                    "AWQ mappings need to match a single smoothlayer for each mapping but got "
-                    f"{[module_to_name.get(smooth_layer) for smooth_layer in smooth_layers]} "
-                    f"when matching {mapping.smooth_layer}"
+                assert len(smooth_layers) == 1, (
+                    "AWQ mappings need to match a single smooth layer for each "
+                    f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}"
+                    f" for mapping: {mapping}"
                 )
                 smooth_layer = smooth_layers[0]
                 smooth_name = module_to_name.get(smooth_layer)

-                #[[b00, b01, b02...], [b10, b11, b12,...], ...] v
+                # [[b00, b01, b02...], [b10, b11, b12,...], ...] v
                 # [b00, b01, b02, ..., b10, b11, b12, ...]
                 balance_layers = tree_flatten(nested_balance_layers)[0]
                 balance_names = [
@@ -371,7 +365,9 @@ def _set_resolved_mappings(self, model: Module) -> None:
             else:
                 # for multiple balance layers, find lowest common parent
                 ancestor_name = get_lowest_common_ancestor_name(balance_names)
-                ancestor_name, ancestor = get_lowest_non_module_list_ancestor(ancestor_name, model)
+                ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
+                    ancestor_name, model
+                )

             resolved_mappings.append(
                 ResolvedMapping(
@@ -807,7 +803,7 @@ def _accumulate_mean(

 def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]:
     """
-    Given a name and a model, finds lowest ancestor of 
+    Given a name and a model, finds lowest ancestor of
     named module that's not a ModuleList
     i.e. module_list.module_dict.module_list -> module_list.module_dict
     i.e. module_list.module_dict -> module_list.module_dict
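The unpacking above relies on torch's pytree utilities: match_modules_set yields one list of matched modules per pattern, so the balance matches arrive nested one level deep and tree_flatten collapses them. A minimal sketch of that step, with strings standing in for modules:

from torch.utils._pytree import tree_flatten

# Stand-in for match_modules_set output: one inner list per balance pattern,
# mirroring the "[[b00, b01, ...], [b10, b11, ...], ...]" comment in the diff.
nested_balance_layers = [["b00", "b01", "b02"], ["b10", "b11"]]

# tree_flatten returns (leaves, treespec); only the flat leaves are used here.
balance_layers, _treespec = tree_flatten(nested_balance_layers)
assert balance_layers == ["b00", "b01", "b02", "b10", "b11"]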

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 20 additions & 24 deletions
@@ -1,11 +1,12 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple

 import torch
-from compressed_tensors.utils import align_module_device, match_named_modules
+from compressed_tensors.utils import align_module_device, match_modules_set
 from loguru import logger
 from pydantic import ConfigDict, Field
 from torch.nn import Module
+from torch.utils._pytree import tree_flatten

 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
@@ -14,7 +15,7 @@
     handle_mapping_resolution_errors,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import get_layer_by_name
+from llmcompressor.utils.pytorch.module import get_module_to_name_dict

 MINIMUM_SMOOTHING_SCALE = 1e-5

@@ -95,7 +96,7 @@ class SmoothQuantModifier(Modifier):
     """

     smoothing_strength: float = 0.5
-    mappings: Optional[List[Union[Tuple, List]]] = None
+    mappings: Optional[List[Tuple[List[str], str]]] = None
     ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None
     calibration_function: Optional[Callable] = None
@@ -198,27 +199,22 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         be balanced.
         """
         resolved_mappings = []
-        for to_balance, to_smooth in self.mappings:
-            to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth
-
-            for smooth_name, smooth_layer in match_named_modules(
-                model, to_smooth_list, self.ignore
+        module_to_name = get_module_to_name_dict(model)
+        for mapping in self.mappings:
+            for *nested_balance_layers, smooth_layers in match_modules_set(
+                model, tree_flatten(mapping)[0], self.ignore
             ):
-                # Search for balance layers within the parent scope
-                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-                smooth_parent = get_layer_by_name(smooth_parent_name, model)
-
-                balance_layers = [
-                    balance_layer
-                    for _, balance_layer in match_named_modules(
-                        smooth_parent, to_balance, self.ignore
-                    )
-                ]
-
-                if balance_layers:
-                    resolved_mappings.append(
-                        SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
-                    )
+                assert len(smooth_layers) == 1, (
+                    "SmoothQuant mappings must match a single smooth layer for each "
+                    f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}"
+                    f" for mapping: {mapping}"
+                )
+                smooth_layer = smooth_layers[0]
+                smooth_name = module_to_name.get(smooth_layers[0])
+                balance_layers = tree_flatten(nested_balance_layers)[0]
+                resolved_mappings.append(
+                    SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
+                )

         return resolved_mappings
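With the tightened schema, each SmoothQuant mapping is a (balance_layer_patterns, smooth_layer_pattern) tuple rather than a loose Tuple/List union. A hedged usage sketch; the regex patterns are illustrative, not values taken from this commit:

from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

# Each entry pairs a list of balance-layer patterns with one smooth-layer
# pattern, matching the new Optional[List[Tuple[List[str], str]]] type.
modifier = SmoothQuantModifier(
    smoothing_strength=0.5,
    mappings=[
        (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"),
        (["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"),
    ],
)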

src/llmcompressor/modifiers/smoothquant/utils.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 import functools
 from collections import namedtuple
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple

 from loguru import logger

@@ -10,7 +10,7 @@
     "DEFAULT_SMOOTHQUANT_MAPPINGS",
 ]

-LayerMapType = Tuple[Union[List[str], str], Union[List[str], str]]
+LayerMapType = Tuple[List[str], str]
 LayerMap: LayerMapType = namedtuple("LayerMap", ["balance_layers", "smooth_layers"])

 DEFAULT_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [
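This is also why tree_flatten(mapping)[0] works in _resolve_mappings above: torch's pytree treats plain tuples and namedtuples such as LayerMap as containers, so either mapping form flattens to one pattern list, balance patterns first and the smooth pattern last. A small sketch (patterns illustrative):

from collections import namedtuple

from torch.utils._pytree import tree_flatten

LayerMap = namedtuple("LayerMap", ["balance_layers", "smooth_layers"])

mapping = LayerMap(
    balance_layers=["re:.*gate_proj", "re:.*up_proj"],
    smooth_layers="re:.*post_attention_layernorm",
)
# Both the namedtuple and the inner list are pytree nodes; strings are leaves.
flat, _spec = tree_flatten(mapping)
assert flat == ["re:.*gate_proj", "re:.*up_proj", "re:.*post_attention_layernorm"]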

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 2 additions & 2 deletions
@@ -9,9 +9,9 @@
     TransformScheme,
     apply_transform_config,
 )
-from torch.utils._pytree import tree_flatten
 from compressed_tensors.utils import TorchDtype, get_head_dim
 from pydantic import Field, ValidationInfo, field_validator
+from torch.utils._pytree import tree_flatten
 from transformers import PreTrainedModel

 from llmcompressor.core import Event, EventType, State

@@ -204,7 +204,7 @@ def _fuse_norms(self, model: PreTrainedModel):
         for mapping in self.norm_mappings:
             for norm, *linears in match_modules_set(
                 model, (mapping.norm, *mapping.linears)
-                ):
+            ):
                 # match_modules_set returns a list of lists
                 assert len(norm) == 1
                 fuse_norm_linears(norm[0], tree_flatten(linears)[0])

src/llmcompressor/utils/pytorch/module.py

Lines changed: 14 additions & 0 deletions
@@ -10,6 +10,7 @@
 import torch
 from compressed_tensors import InternalModule
 from compressed_tensors.quantization.utils import is_module_quantized
+from loguru import logger
 from torch.nn import Linear, Module, Parameter
 from torch.nn.modules.conv import _ConvNd
 from transformers import PreTrainedModel

@@ -369,3 +370,16 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module:
     if not layer_name:
         return module
     return attrgetter(layer_name)(module)
+
+
+def get_module_to_name_dict(model: Module) -> dict[Module, str]:
+    module_to_name = {}
+    for name, module in model.named_modules():
+        if module in module_to_name:
+            logger.info(
+                f"Warning, {name} and {module_to_name[module]} both "
+                "share the same module, "
+                "may have trouble resolving mappings."
+            )
+        module_to_name[module] = name
+    return module_to_name
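A short usage sketch of the new helper on a toy model (the two-layer Sequential is illustrative): building the dict once gives O(1) module-to-name lookups while resolving mappings, instead of re-walking named_modules() per query.

from torch.nn import Linear, Sequential

from llmcompressor.utils.pytorch.module import get_module_to_name_dict

model = Sequential(Linear(4, 4), Linear(4, 4))
module_to_name = get_module_to_name_dict(model)

assert module_to_name[model[1]] == "1"  # reverse lookup keyed by module identity
assert module_to_name[model] == ""  # named_modules() registers the root as ""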

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 7 additions & 9 deletions
@@ -222,21 +222,20 @@ def test_get_lowest_non_module_list_ancestor():
             )
         }
     )
-
-    ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
-        "", model
-    )
+
+    ancestor_name, ancestor = get_lowest_non_module_list_ancestor("", model)
     assert ancestor_name == "" and ancestor == model

-    ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
-        "experts", model
-    )
+    ancestor_name, ancestor = get_lowest_non_module_list_ancestor("experts", model)
     assert ancestor_name == "" and ancestor == model

     ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
         "experts.1.gate_proj", model
     )
-    assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"]
+    assert (
+        ancestor_name == "experts.1.gate_proj"
+        and ancestor == model["experts"][1]["gate_proj"]
+    )


 @pytest.mark.unit

@@ -298,4 +297,3 @@ def test_moe_multiple_balance_layers():

     assert mapping.parent_name == "layer.mlp"
     assert mapping.parent == mlp
-
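For readers following the test, here is a minimal sketch of the rule these assertions pin down, per the docstring in awq/base.py; this is an assumed re-implementation for illustration, not the repo's code:

from torch.nn import Module, ModuleList

def lowest_non_module_list_ancestor(name: str, model: Module):
    # Assumed logic: strip trailing path segments while the module at `name`
    # is itself a ModuleList; an empty name resolves to the model root.
    while name:
        module = model.get_submodule(name)
        if not isinstance(module, ModuleList):
            return name, module
        name = ".".join(name.split(".")[:-1])
    return "", model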
