6 changes: 5 additions & 1 deletion docs/observers.md
@@ -65,7 +65,11 @@ from llmcompressor.observers import Observer
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 
 args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
-observer = Observer.load_from_registry("minmax", quantization_args=args)
+observer = Observer.load_from_registry(
+    "minmax",
+    base_name="weight",
+    quantization_args=args,
+)
 
 x = torch.randn(64, 512)
 scale, zero_point = observer(x)
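For context on what the docs example produces: with strategy="group" and group_size=128, the observer emits one (scale, zero_point) pair per 128-column group. The snippet below is a minimal, standalone sketch of that computation (hypothetical helper name, asymmetric 4-bit assumed); it is not the llmcompressor observer implementation, which also covers symmetric ranges, moving averages, and the other strategies.

```python
import torch

def groupwise_minmax_qparams(x: torch.Tensor, num_bits: int = 4, group_size: int = 128):
    # One (scale, zero_point) pair per contiguous group of `group_size` columns.
    qmax = 2**num_bits - 1  # 15 quantization levels above zero for asymmetric 4-bit
    rows, cols = x.shape
    groups = x.reshape(rows, cols // group_size, group_size)
    mins = groups.amin(dim=-1)
    maxs = groups.amax(dim=-1)
    scale = (maxs - mins).clamp(min=1e-8) / qmax
    zero_point = torch.round(-mins / scale).clamp(0, qmax)
    return scale, zero_point

x = torch.randn(64, 512)
scale, zero_point = groupwise_minmax_qparams(x)
print(scale.shape)  # torch.Size([64, 4]): 512 columns / 128 per group
```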
10 changes: 6 additions & 4 deletions src/llmcompressor/modifiers/quantization/cache.py
@@ -86,13 +86,15 @@ def update(
"""

if len(self.k_observers) <= layer_idx:
k_observer_name = self.quantization_args.observer
k_observer = Observer.load_from_registry(
k_observer_name, quantization_args=self.quantization_args
self.quantization_args.observer,
base_name="k",
args=self.quantization_args,
)
v_observer_name = self.quantization_args.observer
v_observer = Observer.load_from_registry(
v_observer_name, quantization_args=self.quantization_args
self.quantization_args.observer,
base_name="v",
args=self.quantization_args,
)

# NOTE: User may ignore some layers in configuration,
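The cache builds one observer pair per attention layer on demand, keyed by base_name "k" and "v". A rough sketch of that lazy-growth pattern follows; the names are hypothetical and the real cache class constructs its observers via Observer.load_from_registry with the quantization args shown above.

```python
# Hypothetical sketch of the lazy per-layer observer pattern, not the real cache class.
class KVObserverStore:
    def __init__(self, make_observer):
        self.make_observer = make_observer  # factory: base_name -> observer
        self.k_observers = []
        self.v_observers = []

    def get(self, layer_idx: int):
        # Grow both lists until the requested layer has its own observer pair.
        while len(self.k_observers) <= layer_idx:
            self.k_observers.append(self.make_observer("k"))
            self.v_observers.append(self.make_observer("v"))
        return self.k_observers[layer_idx], self.v_observers[layer_idx]
```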
75 changes: 19 additions & 56 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -5,6 +5,7 @@
 from compressed_tensors.quantization import (
     DynamicType,
     KVCacheScaleType,
+    QuantizationArgs,
     QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,
@@ -19,12 +20,6 @@
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
-DEFAULT_MAXSHRINK = 0.20
-DEFAULT_PATIENCE = 5
-DEFAULT_AVERAGING_CONSTANT = 0.01
-DEFAULT_GRID = 100.0
-DEFAULT_NORM = 2.4
-
 __all__ = [
     "initialize_observer",
     "update_weight_zp_scale",
@@ -54,31 +49,19 @@ def initialize_observer(
     :param base_name: str used to name the observer attribute
 
     """
-
-    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    if not quantization_scheme:
-        # no quantization scheme nothing to do
-        return
-
-    quantization_args = getattr(quantization_scheme, arg_name, None)
-    # dont need observers for dynamic
-    if quantization_args is not None and quantization_args.dynamic in (
-        False,
-        DynamicType.LOCAL,
-    ):
-        observer_kwargs = quantization_args.observer_kwargs or {}
+    if base_name == "weight":
+        arg_name = "weights"
+    elif base_name == "output":
+        arg_name = "output_activations"
+    else:  # input, q, k, v
+        arg_name = "input_activations"
+
+    args: QuantizationArgs = getattr_chain(
+        module, f"quantization_scheme.{arg_name}", None
+    )
+    if args is not None and args.dynamic is not True:
         observer = Observer.load_from_registry(
-            quantization_args.observer,
-            quantization_args=quantization_args,
-            averaging_constant=observer_kwargs.get(
-                "averaging_constant", DEFAULT_AVERAGING_CONSTANT
-            ),
-            # used by mse observer only, will be ignored by minmax observer
-            maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK),
-            patience=observer_kwargs.get("patience", DEFAULT_PATIENCE),
-            grid=observer_kwargs.get("grid", DEFAULT_GRID),
-            norm=observer_kwargs.get("norm", DEFAULT_NORM),
+            args.observer, base_name=base_name, args=args, module=module
         )
         module.register_module(f"{base_name}_observer", observer)
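The refactored lookup above leans on getattr_chain to resolve the per-tensor args from a dotted path. A simplified sketch of what such a helper does (illustrative only, not the actual llmcompressor.utils.helpers implementation):

```python
from functools import reduce

def getattr_chain_sketch(obj, path: str, default=None):
    # Follow "a.b.c" attribute by attribute, returning `default` if any hop is missing.
    try:
        return reduce(getattr, path.split("."), obj)
    except AttributeError:
        return default

# e.g. getattr_chain_sketch(module, "quantization_scheme.weights", None)
```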

@@ -100,36 +83,17 @@ def call_observer(
base_name is "weight", then the module's weight tensor will be used
"""
with align_module_device(module):
if base_name == "weight":
value = module.weight
g_idx = getattr(module, "weight_g_idx", None)
elif value is not None:
g_idx = None
else:
raise ValueError(
"Must provide a value to observe if not using weight observer"
)

observer = getattr(module, f"{base_name}_observer")
value = module.weight if base_name == "weight" else value
observer: Observer = getattr(module, f"{base_name}_observer")

if should_calculate_gparam:
global_scale = observer(
value,
should_calculate_gparam=True,
)
global_scale = observer.get_global_scale(value)
update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
else:
global_scale = getattr(module, f"{base_name}_global_scale", None)

if should_calculate_qparams:
updated_scale, updated_zero_point = observer(
value, g_idx=g_idx, global_scale=global_scale
)
# register or update scale & zero_point parameters (supports block shapes)
scale_name = f"{base_name}_scale"
zp_name = f"{base_name}_zero_point"
update_offload_parameter(module, scale_name, updated_scale)
update_offload_parameter(module, zp_name, updated_zero_point)
scale, zero_point = observer(value)
update_offload_parameter(module, f"{base_name}_scale", scale)
update_offload_parameter(module, f"{base_name}_zero_point", zero_point)
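During calibration, call_observer is invoked once per tensor type: with no explicit value for weights (the module's own weight is used) and with the intercepted tensor for activations. Below is a hedged sketch of how an input activation might be routed through it via a forward pre-hook; the hook name is hypothetical and the keyword arguments are assumed from the calls visible in this diff.

```python
# Hypothetical calibration hook: feed each incoming batch to call_observer so
# the registered "input_observer" can refresh input_scale / input_zero_point.
def register_input_calibration_hook(module):
    def hook(mod, args):
        value = args[0] if args else None
        call_observer(
            mod,
            base_name="input",
            value=value,
            should_calculate_gparam=False,
            should_calculate_qparams=True,
        )

    return module.register_forward_pre_hook(hook)
```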


def update_weight_global_scale(module: Module):
@@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module):
         should_calculate_gparam=True,
         should_calculate_qparams=False,
     )
-    module.weight_observer.reset()
 
 
 def update_weight_zp_scale(module: Module):
@@ -95,8 +95,10 @@ def quantize_weight(

     # create observer for calculating quantization parameters
     observer = Observer.load_from_registry(
-        quant_args.observer,
-        quantization_args=quant_args,
+        "minmax",
+        base_name="weight",
+        args=quant_args,
+        module=module,
         averaging_constant=1.0,  # ignore moving average
     )
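The averaging_constant=1.0 override makes sense here because the full weight tensor is observed in a single pass: with the usual exponential-moving-average update (sketched below, not the exact observer code), a constant of 1.0 makes the running statistic equal to the latest observation, so the moving average is effectively disabled.

```python
def ema_update(running, observed, averaging_constant=0.01):
    # Typical moving-average update for an observed min or max statistic.
    if running is None:
        return observed
    return (1.0 - averaging_constant) * running + averaging_constant * observed

# averaging_constant=0.01 -> smooth statistics across many calibration batches
# averaging_constant=1.0  -> always returns `observed`: one full pass over the
#                            weight tensor, no history, as in quantize_weight above
```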
