
Commit 830c904

expand observers to calculate gparams, add example for activations
1 parent dabff0a

File tree: 8 files changed, +201 -17 lines

8 files changed

+201
-17
lines changed
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct"
# MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp4 with per group 16 via ptq
# * calibrate a global_scale for activations, which will be used to
#   quantize activations to fp4 on the fly
recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"])

# Apply quantization.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[1] + "-NVFP4-v4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
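
As an optional sanity check (a sketch, not part of the 74-line example above), the calibrated activation global scales can be inspected on the model before saving; the input_global_scale parameter name is an assumption that follows the f"{base_name}_global_scale" naming used in calibration.py below.

for name, module in model.named_modules():
    gscale = getattr(module, "input_global_scale", None)
    if gscale is not None:
        # one tensor-wide scale per quantized Linear's input activations
        print(f"{name}: input_global_scale={gscale}")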

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 40 additions & 5 deletions
@@ -5,6 +5,7 @@
     KVCacheScaleType,
     QuantizationScheme,
     QuantizationStatus,
+    QuantizationStrategy,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
@@ -84,14 +85,48 @@ def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor]
             "Must provide a value to observe if not using weight observer"
         )
 
+    quantization_scheme = getattr(module, "quantization_scheme", None)
+    should_calculate_gparam = False
+    should_calculate_qparams = True
+
+    # TODO: will update to be the case for both weight and input in a follow-up
+    # weight global scale calculation is currently done in ct right now;
+    # should be moved here to unify global scale calculations
+    if (
+        quantization_scheme.strategy == QuantizationStrategy.TENSOR_GROUP
+        and base_name == "input"
+    ):
+        should_calculate_gparam = True
+        should_calculate_qparams = False
+
     observer = getattr(module, f"{base_name}_observer")
-    updated_scale, updated_zero_point = observer(
-        value, g_idx=g_idx, global_scale=global_scale
+    observer_outputs = observer(
+        value,
+        g_idx=g_idx,
+        global_scale=global_scale,
+        should_calculate_gparam=should_calculate_gparam,
+        should_calculate_qparams=should_calculate_qparams,
     )
 
-    # update scale and zero point
-    update_parameter_data(module, updated_scale, f"{base_name}_scale")
-    update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+    if should_calculate_qparams:
+        if should_calculate_gparam:
+            updated_scale, updated_zero_point, updated_global_scale = (
+                observer_outputs
+            )
+        else:
+            updated_scale, updated_zero_point = observer_outputs
+    else:
+        updated_global_scale = observer_outputs
+
+    if should_calculate_gparam:
+        update_parameter_data(
+            module, updated_global_scale, f"{base_name}_global_scale"
+        )
+
+    if should_calculate_qparams:
+        # update scale and zero point
+        update_parameter_data(module, updated_scale, f"{base_name}_scale")
+        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
 
 def update_weight_zp_scale(module: Module):

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 9 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 import torch
 from compressed_tensors.quantization import (
+    DynamicType,
     QuantizationArgs,
     QuantizationConfig,
     QuantizationScheme,
@@ -212,7 +213,10 @@ def _initialize_observers(self, module: torch.nn.Module):
             return
 
         scheme: QuantizationScheme = module.quantization_scheme
-        input = scheme.input_activations and not scheme.input_activations.dynamic
+        input = scheme.input_activations and scheme.input_activations.dynamic in (
+            False,
+            DynamicType.LOCAL,
+        )
         weight = scheme.weights is not None
         output = scheme.output_activations and not scheme.output_activations.dynamic
         is_attention = is_attention_module(module)
@@ -241,7 +245,10 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
                 continue
 
             scheme: QuantizationScheme = module.quantization_scheme
-            input = scheme.input_activations and not scheme.input_activations.dynamic
+            input = scheme.input_activations and scheme.input_activations.dynamic in (
+                False,
+                DynamicType.LOCAL,
+            )
             output = scheme.output_activations and not scheme.output_activations.dynamic
             is_attention = is_attention_module(module)
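
A hedged sketch of the predicate introduced above, using compressed-tensors' DynamicType and QuantizationScheme as imported in this file: input activations marked dynamic=DynamicType.LOCAL (as an NVFP4-style activation scheme is expected to be) still get an observer, because a static global scale must be calibrated even though the per-group scales are computed at runtime. The helper name is illustrative only.

from compressed_tensors.quantization import DynamicType, QuantizationScheme


def needs_input_observer(scheme: QuantizationScheme) -> bool:
    # Mirrors the condition above: fully static (dynamic=False) or locally
    # dynamic (dynamic=DynamicType.LOCAL) input activations need an observer;
    # fully dynamic activations do not.
    return bool(
        scheme.input_activations
        and scheme.input_activations.dynamic in (False, DynamicType.LOCAL)
    )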

src/llmcompressor/observers/base.py

Lines changed: 12 additions & 1 deletion
@@ -73,11 +73,14 @@ def post_calculate_qparams(self) -> None:
         Run any logic specific to its observers after running calculate_qparams
         """
 
+    # TODO: use a different name?
     def get_qparams(
         self,
         observed: Optional[Tensor] = None,
         g_idx: Optional[Tensor] = None,
         global_scale: Optional[Tensor] = None,
+        should_calculate_gparam: bool = False,
+        should_calculate_qparams: bool = True,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -101,6 +104,14 @@ def get_qparams(
                 QuantizationStrategy.TENSOR_GROUP,
                 QuantizationStrategy.GROUP,
             ):
+                # Global params are for the entire tensor
+                if should_calculate_gparam:
+                    return self.calculate_qparams(
+                        observed,
+                        should_calculate_gparam=True,
+                        should_calculate_qparams=False,
+                    )
+
                 rows = observed.shape[0]
                 columns = observed.shape[1]
                 num_groups = int(ceil(columns / group_size))
@@ -137,7 +148,7 @@ def get_qparams(
                         observed[:, start:end],
                         0,
                         tensor_id=group_index,
-                        global_scale=global_scale
+                        global_scale=global_scale,
                     )
 
                     self._scale[:, group_index] = scale.squeeze(1)
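
A rough illustration of why the global parameters are computed over the entire tensor before any per-group qparams, using assumed NVFP4-style constants (FP8 E4M3 max 448, FP4 E2M1 max 6; the numbers are illustrative only): each group's local scale is stored in FP8 relative to the tensor-wide global scale, so it only stays within FP8 range if the global scale was derived from the whole tensor.

FP4_MAX, FP8_MAX = 6.0, 448.0

tensor_abs_max = 1.2                                  # observed over the entire tensor
global_scale = FP8_MAX * FP4_MAX / tensor_abs_max     # 2240.0

group_abs_max = 0.3                                   # observed over one group
local_scale = group_abs_max * global_scale / FP4_MAX  # 112.0, representable in FP8
assert local_scale <= FP8_MAX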

src/llmcompressor/observers/helpers.py

Lines changed: 30 additions & 1 deletion
@@ -1,8 +1,14 @@
 from collections import Counter
+from typing import Optional
 
 import torch
+from compressed_tensors.quantization.quant_args import (
+    FP4_E2M1_DATA,
+    FP8_E4M3_DATA,
+    FloatArgs,
+)
 
-__all__ = ["get_observer_token_count"]
+__all__ = ["get_observer_token_count", "calculate_gparam"]
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
@@ -20,3 +26,26 @@ def get_observer_token_count(module: torch.nn.Module) -> Counter:
             module._num_observed_tokens
         )
     return token_counts
+
+
+def calculate_gparam(
+    updated_min_val: torch.Tensor,
+    updated_max_val: torch.Tensor,
+    scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
+    quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
+    dtype: Optional[torch.dtype] = torch.float32,
+):
+    """
+    Generate a global scale for an entire tensor (input_tensor).
+    The goal of the scale is to ensure that the quantization (local) scale
+    falls into the appropriate dtype range.
+
+    E.g. for NVFP4, group (local) scales are in dtype FP8. The global_scale
+    attempts to use the entire FP8 dtype range while mapping a per-group max
+    to the FP4 max.
+    """
+    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
+    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
+    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+    global_scale = scale_data.max * quant_data.max / max_val_pos
+    return global_scale.to(dtype)
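
A quick numeric check of calculate_gparam's formula under the default arguments (FP8_E4M3_DATA.max is 448.0 and FP4_E2M1_DATA.max is 6.0): for an observed range of [-0.5, 13.44], max_val_pos is 13.44 and the global scale comes out to 448.0 * 6.0 / 13.44 = 200.0.

import torch

min_val = torch.tensor([-0.5])
max_val = torch.tensor([13.44])
max_val_pos = torch.max(min_val.abs(), max_val.abs())
print(448.0 * 6.0 / max_val_pos)  # tensor([200.])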

src/llmcompressor/observers/min_max.py

Lines changed: 15 additions & 1 deletion
@@ -6,6 +6,7 @@
 from compressed_tensors.utils import deprecated
 
 from llmcompressor.observers.base import Observer
+from llmcompressor.observers.helpers import calculate_gparam
 
 __all__ = ["MinMaxObserver", "MovingAverageMinMaxObserver"]
 
@@ -35,6 +36,8 @@ def calculate_qparams(
         reduce_dims: Optional[Tuple[int]] = None,
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
+        should_calculate_gparam: bool = False,
+        should_calculate_qparams: bool = True,
     ) -> Tuple[torch.FloatTensor, torch.IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
@@ -83,13 +86,24 @@ def calculate_qparams(
             self.min_val[tensor_id] = updated_min_val
             self.max_val[tensor_id] = updated_max_val
 
-        return calculate_qparams(
+        if should_calculate_gparam:
+            global_scale = calculate_gparam(
+                updated_min_val=updated_min_val, updated_max_val=updated_max_val
+            )
+            if not should_calculate_qparams:
+                return global_scale
+
+        scale, zero_point = calculate_qparams(
             min_vals=updated_min_val,
             max_vals=updated_max_val,
             quantization_args=self.quantization_args,
            global_scale=global_scale,
        )
 
+        if should_calculate_gparam:
+            return scale, zero_point, global_scale
+        return scale, zero_point
+
     def get_qparams_along_dim(
         self,
         observed: torch.Tensor,

src/llmcompressor/observers/mse.py

Lines changed: 15 additions & 1 deletion
@@ -6,6 +6,7 @@
 from torch import FloatTensor, IntTensor, Tensor
 
 from llmcompressor.observers.base import Observer
+from llmcompressor.observers.helpers import calculate_gparam
 
 __all__ = ["MovingAverageMSEObserver"]
 
@@ -115,6 +116,8 @@ def calculate_qparams(
         reduce_dims: Optional[Tuple[int]] = None,
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
+        should_calculate_gparam: bool = False,
+        should_calculate_qparams: bool = True,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the mse-clipped min and max values of the observed tensor using
@@ -149,13 +152,24 @@ def calculate_qparams(
             self.min_val[tensor_id] = updated_min_val
             self.max_val[tensor_id] = updated_max_val
 
-        return calculate_qparams(
+        if should_calculate_gparam:
+            global_scale = calculate_gparam(
+                updated_min_val=updated_min_val, updated_max_val=updated_max_val
+            )
+            if not should_calculate_qparams:
+                return global_scale
+
+        scale, zero_point = calculate_qparams(
             min_vals=updated_min_val,
             max_vals=updated_max_val,
             quantization_args=self.quantization_args,
            global_scale=global_scale,
        )
 
+        if should_calculate_gparam:
+            return scale, zero_point, global_scale
+        return scale, zero_point
+
     def get_qparams_along_dim(
         self,
         observed,

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 6 additions & 6 deletions
@@ -61,13 +61,13 @@ def infer_quantization_format(
     )
     is_weight_only = len(input_args) == 0 and len(weight_args) > 0
 
-    if is_weight_only:  # w4a16 and w8a16
-        if (
-            weight_args[0].num_bits == 4
-            and weight_args[0].type == QuantizationType.FLOAT.value
-        ):
-            return CompressionFormat.nvfp4_pack_quantized
+    if (
+        weight_args[0].num_bits == 4
+        and weight_args[0].type == QuantizationType.FLOAT.value
+    ):
+        return CompressionFormat.nvfp4_pack_quantized
 
+    if is_weight_only:  # w4a16 and w8a16
         is_valid_pack = all(
             weight_arg.num_bits in [4, 8]
             and weight_arg.type == QuantizationType.INT.value
