Commit a8faf8d (1 parent: 38200b3)
Author: evian

[KV Cache] fix per channel shape

Signed-off-by: evian <[email protected]>

File tree: 2 files changed, +6 −2 lines

src/llmcompressor/modifiers/quantization/cache.py
Lines changed: 3 additions & 2 deletions

@@ -94,7 +94,8 @@ def update(
         _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
         _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
 
-        if key_states.dim() == 4:
+        kv_states_dim = key_states.dim()
+        if kv_states_dim == 4:
             # reshape for per channel scenario
             num_heads = key_states.shape[1]
             head_dim = key_states.shape[-1]
@@ -115,7 +116,7 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )
 
-        if key_states.dim() == 4:
+        if kv_states_dim == 4:
             # reshape for per channel scenario
             # from [batch_size, seq_len - residual_length, num_heads * head_dim]
             # to [batch_size, num_heads, seq_len - residual_length, head_dim]

src/llmcompressor/modifiers/quantization/calibration.py
Lines changed: 3 additions & 0 deletions

@@ -256,6 +256,9 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
     kv_cache = getattr(module, "kv_cache")
     k_scale = kv_cache.k_scales[module.layer_idx]
     v_scale = kv_cache.v_scales[module.layer_idx]
+    if kv_cache.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+        k_scale = k_scale.unsqueeze(-1)
+        v_scale = v_scale.unsqueeze(-1)
     update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
     update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
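The calibration-hook change is the scale-shape half of the fix: under the per-channel strategy the observer produces a flat 1-D tensor of scales, one per channel, while the stored k_scale/v_scale parameters appear to expect a column vector with a trailing singleton axis (the usual per-channel layout, so the scale broadcasts over the remaining dims). A hedged sketch, with num_channels and the consumer shape assumed purely for illustration:

import torch

num_channels = 1024                 # e.g. num_heads * head_dim (assumed)
k_scale = torch.rand(num_channels)  # observer output: [num_channels]

# The added unsqueeze turns the flat scale into a column vector,
# [num_channels] -> [num_channels, 1], so it broadcasts along the
# per-channel axis of whatever tensor it divides or multiplies.
k_scale = k_scale.unsqueeze(-1)
assert k_scale.shape == (num_channels, 1)

# Broadcasting check: a [num_channels, N] tensor scales cleanly.
states = torch.randn(num_channels, 7)
quantized = (states / k_scale).round().clamp(-128, 127)
dequantized = quantized * k_scale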
261264
