Skip to content

Commit a7bf319

Browse files
committed
update
1 parent 9435d81 commit a7bf319

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,31 @@ def update_weight_zp_scale(module: Module):
120120

121121
if module.quantization_scheme.weights is not None:
122122
# set weight scale and zero_point up front, calibration data doesn't affect it
123+
124+
transform_data = getattr(module, "transform_data", None)
125+
if transform_data is not None:
126+
# transforms are stored in the order they were added, which is the order in which they should be applied
127+
untransformed_weight = module.weight.data.clone()
128+
for transform_name, transform_values in transform_data.data.items():
129+
transform = getattr(module, transform_name)
130+
apply = transform_values.get("apply")
131+
call_args = transform_values.get("call_args")
132+
if call_args:
133+
transformed_weight = apply(
134+
input_tensor=module.weight, transform=transform, **call_args
135+
)
136+
else:
137+
transformed_weight = apply(
138+
input_tensor=module.weight, transform=transform
139+
)
140+
module.weight.data.copy_(transformed_weight)
141+
123142
call_observer(module=module, base_name="weight")
124143

144+
# TODO: confirm whether the untransformed weight should be restored after calibration, or the transformed weight kept
145+
if transform_data is not None:
146+
module.weight.data.copy_(untransformed_weight)
147+
125148

126149
def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
127150
"""

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
is_preset_scheme,
1111
preset_name_to_scheme,
1212
)
13+
from compressed_tensors.transforms.transform_config import TransformationConfig
1314
from loguru import logger
1415
from pydantic import Field, field_validator
1516
from torch.nn import Module
@@ -74,6 +75,7 @@ class QuantizationModifier(Modifier):
7475
"""
7576

7677
config_groups: Optional[Dict[str, QuantizationScheme]] = None
78+
transforms_config: Optional[TransformationConfig] = None
7779
ignore: List[str] = Field(default_factory=list)
7880
targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
7981
scheme: Optional[Union[str, Dict[str, Any]]] = None
@@ -210,7 +212,9 @@ def _check_calibration_data(self, config: QuantizationConfig):
210212
def _apply_modifier_to_model(self, model: Module):
211213
modifier_as_config = self.create_init_config()
212214
# Add step to attach kv_cache to the model, if present within the config
213-
apply_quantization_config(model, modifier_as_config)
215+
apply_quantization_config(
216+
model, modifier_as_config, transforms_config=self.transforms_config
217+
)
214218
model.apply(set_unset_kv_cache)
215219
return modifier_as_config
216220

0 commit comments

Comments
 (0)