From 9245fa5f0b1884757929a4f7765971dc62b12290 Mon Sep 17 00:00:00 2001
From: evian
Date: Sat, 19 Jul 2025 18:04:04 +0800
Subject: [PATCH 1/3] [KV Cache] support kv cache int8 per channel quantization

Signed-off-by: evian
---
 .../modifiers/quantization/cache.py | 20 +++++++++++++++++--
 src/llmcompressor/observers/base.py |  7 +++++--
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
index dd3640dda..4d99ae4e7 100644
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ b/src/llmcompressor/modifiers/quantization/cache.py
@@ -94,6 +94,14 @@ def update(
             _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
             _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
 
+        # reshape for per channel scenario
+        num_heads = key_states.shape[1]
+        head_dim = key_states.shape[-1]
+        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+        key_states = key_states.transpose(1, 2).flatten(2)
+        value_states = value_states.transpose(1, 2).flatten(2)
+
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
         )
@@ -106,6 +114,14 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )
 
+        # reshape for per channel scenario
+        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+        qdq_key_states = qdq_key_states.view(
+            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        qdq_value_states = qdq_value_states.view(
+            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states
 
         return keys_to_return, values_to_return
@@ -155,8 +171,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
             zps = self.v_zps
 
         scale, zp = observer(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
-        _pad_and_append_at_idx_(zps, layer_idx, zp)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
+        _pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())
 
         q_tensor = quantize(
             x=tensor,
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index 3ee446cf3..e82ffc899 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -181,8 +181,11 @@ def get_qparams(
                 self._zero_point[:, group_index] = zero_point.squeeze(1)
 
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # assume observed is transposed, because its the output, hence use dim 0
-            self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+            # 1. dim=2 scenario: in kv cache quant scenario which is
+            # [batch_size, seq_len - residual_length, num_heads * head_dim]
+            # 2. dim=0 scenario: assume observed is transposed, because it's the output, hence use dim 0
+            dim = 2 if observed.dim() == 3 else 0
+            self._scale, self._zero_point = self.get_qparams_along_dim(observed, dim)
 
         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
             # use dim 1, assume the obsersed.shape = [batch, token, hidden]
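
Note on the mechanics of PATCH 1/3: flattening [batch_size, num_heads, seq_len, head_dim] into [batch_size, seq_len, num_heads * head_dim] puts the channel axis last, so a per-channel observer reduces over the batch and sequence axes and produces one scale per channel. A minimal standalone sketch of the round-trip (plain torch, toy sizes; illustrative only, not part of the patch):

import torch

# toy sizes, assumptions for illustration only
batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64

key_states = torch.randn(batch_size, num_heads, seq_len, head_dim)

# [B, H, S, D] -> [B, S, H*D]: channels move to the last axis
flat = key_states.transpose(1, 2).flatten(2)
assert flat.shape == (batch_size, seq_len, num_heads * head_dim)

# one scale per channel (symmetric int8 shown for concreteness)
scale = flat.abs().amax(dim=(0, 1)) / 127.0
assert scale.shape == (num_heads * head_dim,)

# inverse: [B, S, H*D] -> [B, H, S, D], the layout attention expects
restored = flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
assert torch.equal(restored, key_states)

This is also why num_heads and head_dim are captured before the flatten: they are needed to undo it after quantize/dequantize.
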
From 38200b31d48ecd2dd49ff6f432f28c59f6c9464f Mon Sep 17 00:00:00 2001
From: evian
Date: Wed, 23 Jul 2025 15:43:40 +0800
Subject: [PATCH 2/3] [KV Cache] fix ci

Signed-off-by: evian
---
 .../modifiers/quantization/cache.py | 34 +++++++++++--------
 src/llmcompressor/observers/base.py |  8 +++--
 2 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
index 4d99ae4e7..5f441cf30 100644
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ b/src/llmcompressor/modifiers/quantization/cache.py
@@ -94,13 +94,14 @@ def update(
             _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
             _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
 
-        # reshape for per channel scenario
-        num_heads = key_states.shape[1]
-        head_dim = key_states.shape[-1]
-        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
-        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
-        key_states = key_states.transpose(1, 2).flatten(2)
-        value_states = value_states.transpose(1, 2).flatten(2)
+        if key_states.dim() == 4:
+            # reshape for per channel scenario
+            num_heads = key_states.shape[1]
+            head_dim = key_states.shape[-1]
+            # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+            # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+            key_states = key_states.transpose(1, 2).flatten(2)
+            value_states = value_states.transpose(1, 2).flatten(2)
 
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
@@ -114,13 +115,18 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )
 
-        # reshape for per channel scenario
-        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
-        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
-        qdq_key_states = qdq_key_states.view(
-            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
-        qdq_value_states = qdq_value_states.view(
-            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        if key_states.dim() == 4:
+            # reshape for per channel scenario
+            # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+            # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+            qdq_key_states = qdq_key_states.view(
+                qdq_key_states.shape[0], qdq_key_states.shape[1],
+                num_heads, head_dim
+            ).transpose(1, 2).contiguous()
+            qdq_value_states = qdq_value_states.view(
+                qdq_value_states.shape[0], qdq_value_states.shape[1],
+                num_heads, head_dim
+            ).transpose(1, 2).contiguous()
 
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states
 
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index e82ffc899..45417ff6e 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -183,9 +183,13 @@ def get_qparams(
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # 1. dim=2 scenario: in kv cache quant scenario which is
             # [batch_size, seq_len - residual_length, num_heads * head_dim]
-            # 2. dim=0 scenario: assume observed is transposed, because it's the output, hence use dim 0
+            # 2. dim=0 scenario: assume observed is transposed,
+            # because it's the output, hence use dim 0
             dim = 2 if observed.dim() == 3 else 0
-            self._scale, self._zero_point = self.get_qparams_along_dim(observed, dim)
+            self._scale, self._zero_point = self.get_qparams_along_dim(
+                observed,
+                dim
+            )
 
         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
             # use dim 1, assume the obsersed.shape = [batch, token, hidden]
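
Note on the observer change in PATCH 2/3: the reduction axis for the CHANNEL strategy is keyed off the rank of observed. A small sketch of that selection logic (simplified min/max in place of the real Observer machinery; channel_minmax is a made-up helper):

import torch

def channel_minmax(observed: torch.Tensor):
    # mirrors the dim choice in get_qparams: a 3-D input is assumed to be
    # the kv cache path, [batch, seq, num_heads * head_dim], with channels
    # on dim 2; anything else is treated as a transposed output with
    # channels on dim 0
    dim = 2 if observed.dim() == 3 else 0
    reduce_dims = tuple(d for d in range(observed.dim()) if d != dim)
    return observed.amin(dim=reduce_dims), observed.amax(dim=reduce_dims)

# kv cache case: one (min, max) pair per channel
kv = torch.randn(2, 16, 512)
lo, hi = channel_minmax(kv)
assert lo.shape == hi.shape == (512,)

# transposed-output case: channels along dim 0
w = torch.randn(768, 128)
lo, hi = channel_minmax(w)
assert lo.shape == (768,)
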
From a8faf8dce72f5a82634fcb20c5b7a00bb034fa2d Mon Sep 17 00:00:00 2001
From: evian
Date: Mon, 4 Aug 2025 20:53:04 +0800
Subject: [PATCH 3/3] [KV Cache] fix per channel shape

Signed-off-by: evian
---
 src/llmcompressor/modifiers/quantization/cache.py       | 5 +++--
 src/llmcompressor/modifiers/quantization/calibration.py | 3 +++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
index 5f441cf30..f82432b71 100644
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ b/src/llmcompressor/modifiers/quantization/cache.py
@@ -94,7 +94,8 @@ def update(
             _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
             _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
 
-        if key_states.dim() == 4:
+        kv_states_dim = key_states.dim()
+        if kv_states_dim == 4:
             # reshape for per channel scenario
             num_heads = key_states.shape[1]
             head_dim = key_states.shape[-1]
@@ -115,7 +116,7 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )
 
-        if key_states.dim() == 4:
+        if kv_states_dim == 4:
             # reshape for per channel scenario
             # from [batch_size, seq_len - residual_length, num_heads * head_dim]
             # to [batch_size, num_heads, seq_len - residual_length, head_dim]
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index b10a4cb31..9836ee78c 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -256,6 +256,9 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
     kv_cache = getattr(module, "kv_cache")
     k_scale = kv_cache.k_scales[module.layer_idx]
     v_scale = kv_cache.v_scales[module.layer_idx]
+    if kv_cache.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+        k_scale = k_scale.unsqueeze(-1)
+        v_scale = v_scale.unsqueeze(-1)
     update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
     update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
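
Note on PATCH 3/3: after the first flatten, key_states is already 3-D, so re-testing key_states.dim() == 4 before the inverse reshape could never succeed; caching the rank up front restores the inverse path, and per-channel scales get a trailing axis before being registered as parameters. A toy repro of both points (standalone sketch, made-up sizes):

import torch

key_states = torch.randn(2, 8, 16, 64)  # [B, H, S, D]

# pre-patch pattern: the guard is re-evaluated after the flatten,
# so the inverse reshape behind it is unreachable
if key_states.dim() == 4:
    key_states = key_states.transpose(1, 2).flatten(2)  # now 3-D
assert key_states.dim() == 3  # a second `dim() == 4` test is False here

# patched pattern: capture the rank once, before any reshape
key_states = torch.randn(2, 8, 16, 64)
kv_states_dim = key_states.dim()
if kv_states_dim == 4:
    key_states = key_states.transpose(1, 2).flatten(2)
# ... quantize / dequantize would happen here on the 3-D view ...
if kv_states_dim == 4:  # still True, so the inverse reshape runs
    key_states = key_states.view(2, 16, 8, 64).transpose(1, 2).contiguous()
assert key_states.shape == (2, 8, 16, 64)

# scale shape fix from the same patch: per-channel scales come back
# squeezed to [num_channels], and the calibration hook unsqueezes a
# trailing axis before update_parameter_data
k_scale = torch.rand(8 * 64)
assert k_scale.unsqueeze(-1).shape == (8 * 64, 1)
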