
Commit 44dbf91 (parent: fe01901)

wip

Signed-off-by: Kyle Sayers <[email protected]>

File tree: 7 files changed (+39, -32 lines)

  examples/quantization_w4a16/llama3_example.py
  src/llmcompressor/modifiers/smoothquant/base.py
  src/llmcompressor/modifiers/utils/pytorch_helpers.py
  src/llmcompressor/pipelines/basic/pipeline.py
  src/llmcompressor/pipelines/sequential/helpers.py
  src/llmcompressor/pipelines/sequential/pipeline.py
  src/llmcompressor/utils/helpers.py

examples/quantization_w4a16/llama3_example.py

Lines changed: 2 additions & 1 deletion

@@ -16,7 +16,7 @@

 # Select number of samples. 512 samples is a good place to start.
 # Increasing the number of samples can improve accuracy.
-NUM_CALIBRATION_SAMPLES = 512
+NUM_CALIBRATION_SAMPLES = 12
 MAX_SEQUENCE_LENGTH = 2048

 # Load dataset and preprocess.
@@ -57,6 +57,7 @@ def tokenize(sample):
 oneshot(
     model=model,
     dataset=ds,
+    batch_size=12,
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
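
Note: with NUM_CALIBRATION_SAMPLES and batch_size both set to 12, the calibration set fits in a single forward batch. A minimal sketch of that relationship, assuming the standard ceil-division batching of a dataloader with drop_last=False (the names below are illustrative, not part of the example file):

import math

NUM_CALIBRATION_SAMPLES = 12
BATCH_SIZE = 12

# Number of calibration batches a dataloader would yield with ceil division:
# the previous default of 512 samples would give 43 batches at this batch
# size, while the wip values above give exactly one.
num_batches = math.ceil(NUM_CALIBRATION_SAMPLES / BATCH_SIZE)
print(num_batches)  # 1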

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 1 addition & 4 deletions

@@ -55,7 +55,6 @@ class SmoothQuantMapping:

     smooth_name: str
     smooth_layer: Module
-    balance_names: List[str]
     balance_layers: List[Module]


@@ -216,20 +215,18 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
             to_smooth_layers = get_layers(to_smooth, model)
             for layer_name, smooth_layer in to_smooth_layers.items():
                 if not match_targets(layer_name, self.ignore)[0]:
-                    balance_names = []
                     balance_layers = []
                     for balance_suffix in to_balance:
                         # find the submodule that matches the activation layer
                         balance_name, balance_layer = get_matching_layer(
                             balance_suffix, layer_name, model
                         )
                         if balance_layer:
-                            balance_names.append(balance_name)
                             balance_layers.append(balance_layer)
                     # each mapping can contain multiple layers to balance, but only
                     # one layer to smooth
                     mapping = SmoothQuantMapping(
-                        layer_name, smooth_layer, balance_names, balance_layers
+                        layer_name, smooth_layer, balance_layers
                     )
                     resolved_mappings.append(mapping)
         return resolved_mappings
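
The mapping no longer carries a parallel list of balance layer names; the layer modules themselves are enough for the smoothing pass. A sketch of how the dataclass reads after this change, assuming the @dataclass decorator and imports from the surrounding file:

from dataclasses import dataclass
from typing import List

from torch.nn import Module


@dataclass
class SmoothQuantMapping:
    # Layer whose activations are smoothed, plus the downstream layers whose
    # weights absorb the corresponding smoothing scales.
    smooth_name: str
    smooth_layer: Module
    balance_layers: List[Module]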

src/llmcompressor/modifiers/utils/pytorch_helpers.py

Lines changed: 1 addition & 4 deletions

@@ -7,14 +7,11 @@
 tensor operations for compression workflows.
 """

-from typing import TYPE_CHECKING, Dict
+from typing import Dict

 import torch
 from torch.nn import Module

-if TYPE_CHECKING:
-    pass
-
 __all__ = [
     "apply_pad_mask_to_batch",
     "is_moe_model",

src/llmcompressor/pipelines/basic/pipeline.py

Lines changed: 9 additions & 2 deletions

@@ -6,11 +6,11 @@
 from compressed_tensors.utils import get_execution_device
 from torch.utils.data.dataloader import DataLoader

-from llmcompressor.core import LifecycleCallbacks
+from llmcompressor.core import LifecycleCallbacks, active_session
 from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
-from llmcompressor.utils import calibration_forward_context, dispatch_for_generation
+from llmcompressor.utils import calibration_forward_context, dispatch_for_generation, targets_lm_head, disable_lm_head

 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -38,13 +38,20 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
+        session = active_session()
+        modifiers = session.lifecycle.recipe.modifiers
+
         dispatch_for_generation(model)  # basic dispatch is identical to generation
         model_device = get_execution_device(model)

         LifecycleCallbacks.calibration_epoch_start()

         with contextlib.ExitStack() as stack:
             stack.enter_context(calibration_forward_context(model))
+            # Optional disable lm_head
+            if not targets_lm_head(model, modifiers):
+                stack.enter_context(disable_lm_head(model))
+
             for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
                 batch = apply_pad_mask_to_batch(batch)
                 batch = tensors_to_device(batch, model_device)
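
Both calibration pipelines now look up the active session's modifiers and, when none of them target the lm_head, enter disable_lm_head(model) so calibration skips the large hidden-to-vocab projection. The body of disable_lm_head is not part of this diff; the following is only a plausible sketch of such a context manager (hypothetical name and behavior, not the actual llm-compressor helper), swapping the head's forward for a cheap stand-in and restoring it on exit:

import contextlib

import torch


@contextlib.contextmanager
def disable_lm_head_sketch(model: torch.nn.Module):
    """Hypothetical: temporarily skip the model's output head during calibration."""
    # transformers models expose the head via get_output_embeddings(); fall
    # back to a no-op when the attribute is missing.
    get_head = getattr(model, "get_output_embeddings", None)
    lm_head = get_head() if callable(get_head) else None
    if lm_head is None:
        yield
        return

    original_forward = lm_head.forward

    def _skip_forward(hidden_states, *args, **kwargs):
        # Return an empty tensor instead of projecting to the vocabulary,
        # avoiding the hidden_size x vocab_size matmul during calibration.
        return hidden_states.new_empty(0)

    lm_head.forward = _skip_forward
    try:
        yield
    finally:
        lm_head.forward = original_forward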

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 0 additions & 12 deletions

@@ -24,9 +24,6 @@
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.sequential.transformers_helpers import HFTracer
-from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    targets_embeddings,
-)
 from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
 from llmcompressor.utils.pytorch.module import get_no_split_params

@@ -40,7 +37,6 @@
     "Subgraph",
     "get_sequential_targets",
     "dispatch_for_sequential",
-    "targets_lm_head",
 ]


@@ -499,14 +495,6 @@ def get_sequential_targets(
     return sequential_targets


-def targets_lm_head(model: PreTrainedModel, modifiers: list[Modifier]) -> bool:
-    targets = sum(
-        (list(modifier.get_targets(model)) for modifier in modifiers), start=[]
-    )
-
-    return targets_embeddings(model, targets, check_input=True, check_output=False)
-
-
 def add_line_numbers(text: str) -> str:
     lines = text.splitlines()
     numbered_lines = [f"{i + 1} {line}" for i, line in enumerate(lines)]

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 6 additions & 3 deletions

@@ -13,13 +13,14 @@
 from llmcompressor.pipelines.sequential.helpers import (
     dispatch_for_sequential,
     get_sequential_targets,
-    targets_lm_head,
     trace_subgraphs,
 )
 from llmcompressor.utils.helpers import (
     DISABLE_QAC_MODIFIERS,
     DisableQuantization,
     calibration_forward_context,
+    targets_lm_head,
+    disable_lm_head,
 )

 if TYPE_CHECKING:
@@ -83,13 +84,15 @@ def __call__(
             type(mod).__name__ in DISABLE_QAC_MODIFIERS
             for mod in session.lifecycle.recipe.modifiers
         )
-        skip_lm_head = not targets_lm_head(model, modifiers)

         with contextlib.ExitStack() as stack:
-            stack.enter_context(calibration_forward_context(model, skip_lm_head))
+            stack.enter_context(calibration_forward_context(model))
             # Optionally disable quantization
             if not dataset_args.quantization_aware_calibration or disable_qac:
                 stack.enter_context(DisableQuantization(model))
+            # Optional disable lm_head
+            if not targets_lm_head(model, modifiers):
+                stack.enter_context(disable_lm_head(model))

             # prepare intermediates cache
             activations = IntermediatesCache.from_dataloader(dataloader, model_device)
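
Taken together with the helpers.py change below, the caller-side pattern shifts: instead of threading a skip_lm_head flag into calibration_forward_context, each pipeline keeps that context minimal and stacks the optional behaviors itself. A condensed sketch of the pattern, assuming the three helpers are importable from llmcompressor.utils as in the basic pipeline diff above and that run_calibration stands in for the pipeline's batch loop:

import contextlib

from llmcompressor.utils import (
    calibration_forward_context,
    disable_lm_head,
    targets_lm_head,
)


def calibrate(model, modifiers, run_calibration):
    # Previously: calibration_forward_context(model, skip_lm_head) bundled the
    # lm_head behavior. Now each optional context is entered explicitly.
    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        if not targets_lm_head(model, modifiers):
            stack.enter_context(disable_lm_head(model))
        run_calibration(model)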

src/llmcompressor/utils/helpers.py

Lines changed: 20 additions & 6 deletions

@@ -18,7 +18,7 @@
 from collections import OrderedDict
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, TYPE_CHECKING
 from urllib.parse import urlparse

 import numpy
@@ -28,6 +28,9 @@
 from loguru import logger
 from transformers import PreTrainedModel

+if TYPE_CHECKING:
+    from llmcompressor.modifiers.modifier import Modifier
+
 __all__ = [
     "ALL_TOKEN",
     "ALL_PRUNABLE_TOKEN",
@@ -68,6 +71,8 @@
     "calibration_forward_context",
     "patch_attr",
     "disable_hf_kernels",
+    "disable_lm_head",
+    "targets_lm_head",
     "DISABLE_QAC_MODIFIERS",
 ]

@@ -1042,23 +1047,20 @@ def disable_hf_kernels(module: torch.nn.Module):


 @contextlib.contextmanager
-def calibration_forward_context(model: torch.nn.Module, skip_lm_head: bool = False):
+def calibration_forward_context(model: torch.nn.Module):
     """
     Context in which all calibration forward passes should occur.

     - Remove gradient calculations
     - Disable the KV cache
     - Disable train mode and enable eval mode
     - Disable hf kernels which could bypass hooks
-    - Disable lm_head of model (optional)
     """
     with contextlib.ExitStack() as stack:
         stack.enter_context(torch.no_grad())
         stack.enter_context(disable_cache(model))
         stack.enter_context(eval_context(model))
         stack.enter_context(disable_hf_kernels(model))
-        if skip_lm_head:
-            stack.enter_context(disable_lm_head(model))

         yield

@@ -1091,7 +1093,19 @@ def disable_lm_head(model: torch.nn.Module):
     yield


-# TODO: deprecate
+def targets_lm_head(model: PreTrainedModel, modifiers: list["Modifier"]) -> bool:
+    """ Returns True if the given modifiers target the lm_head """
+    from llmcompressor.transformers.compression.compressed_tensors_utils import (
+        targets_embeddings
+    )
+
+    targets = sum(
+        (list(modifier.get_targets(model)) for modifier in modifiers), start=[]
+    )
+    return targets_embeddings(model, targets, check_input=True, check_output=False)
+
+
+
 @contextlib.contextmanager
 def patch_attr(base: object, attr: str, value: Any):
     """
