Skip to content

Commit d159a3c

Browse files
committed
add nvfpp-1
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent cf9843b commit d159a3c

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

examples/quantization_w4a4_fp4/llama3_example.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/"
1010
# MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct"
1111
scheme_name = "NVFP4"
12-
scheme_name = "MXFP4"
12+
1313
# scheme_name = "MXFP8"
1414
# scheme_name = "FP8"
15-
15+
scheme_name = "NVFPP_B32"
16+
# scheme_name = "MXFP4"
17+
# scheme_name = ""
1618
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{scheme_name}"
1719
SAVE_DIR = f"/data5/yliu7/HF_HOME/{SAVE_DIR}"
1820
print(f"Saving to {SAVE_DIR}")
@@ -85,6 +87,7 @@ def tokenize(sample):
8587
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
8688
model.device
8789
)
90+
print(f"=========== Starting generation =================")
8891
output = model.generate(input_ids, max_new_tokens=10)
8992

9093
print(tokenizer.decode(output[0]))

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
is_kv_cache_quant_scheme,
1515
is_mx,
1616
is_mxfp4,
17+
use_global_scales
1718
)
1819
from compressed_tensors.utils import align_module_device, update_parameter_data
1920
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
@@ -149,7 +150,10 @@ def update_weight_global_scale(module: Module):
149150
return
150151
weight_quant_args = getattr_chain(module, "quantization_scheme.weights")
151152
if is_mx(quantization_args=weight_quant_args):
152-
# MX schemes do not use global scale
153+
# MX schemes do not use global scale
154+
return
155+
if not use_global_scales(quantization_args=weight_quant_args):
156+
# this scheme does not use global scales
153157
return
154158

155159
call_observer(
@@ -209,6 +213,8 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
209213
calculate_qparams = False
210214
if is_fp4(quantization_args=quantization_args):
211215
calculate_gparam = True
216+
if not use_global_scales(quantization_args=quantization_args):
217+
calculate_gparam = False
212218

213219
call_observer(
214220
module=module,

src/llmcompressor/modifiers/utils/helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44
from compressed_tensors.quantization import QuantizationStrategy
5-
from compressed_tensors.quantization.utils import is_fp4, is_mxfp4
5+
from compressed_tensors.quantization.utils import is_fp4, is_mxfp4, use_global_scales
66
from compressed_tensors.utils import align_modules, update_parameter_data
77
from torch.nn import Linear, Module
88

@@ -52,6 +52,8 @@ def _valid_tensor_group_quant(layer_list: List[Linear]):
5252
return False
5353
if not is_fp4(quantization_args=weight_quant_args):
5454
return False
55+
if not use_global_scales(quantization_args=weight_quant_args):
56+
return False
5557
return True
5658

5759
if _is_attention_module(submodule):

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from compressed_tensors.quantization.utils import is_module_quantized
1111
from loguru import logger
12+
from compressed_tensors.quantization.utils.helpers import is_nvfpp_b32
1213

1314
__all__ = ["infer_and_set_per_module_quantization_format"]
1415

@@ -26,6 +27,8 @@ def _get_quant_compression_format(
2627
if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
2728
if weight_args.is_mx:
2829
return CompressionFormat.mxfp4_pack_quantized
30+
if is_nvfpp_b32(weight_args):
31+
return CompressionFormat.nvfpp_b32_pack_quantized
2932
return CompressionFormat.nvfp4_pack_quantized
3033

3134
if is_weight_only: # w4a16 and w8a16

0 commit comments

Comments (0)