Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any, Dict, Optional, Tuple

import torch
from compressed_tensors.quantization import QuantizationStatus, is_attention_module
from compressed_tensors.quantization import (
KVCacheScaleType,
QuantizationStatus,
is_attention_module,
)
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
Expand Down Expand Up @@ -194,8 +198,10 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
Hook to update k_scale and v_scale parameters when running kv_cache quantization.
"""
kv_cache = getattr(module, "kv_cache")
update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")
k_scale = kv_cache.k_scales[module.layer_idx]
v_scale = kv_cache.v_scales[module.layer_idx]
update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)


def set_unset_kv_cache(module: Module):
Expand Down