diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
index dd3640dda..f82432b71 100644
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ b/src/llmcompressor/modifiers/quantization/cache.py
@@ -94,6 +94,16 @@ def update(
         _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
         _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
 
+        kv_states_dim = key_states.dim()
+        if kv_states_dim == 4:
+            # reshape for the per-channel scenario
+            num_heads = key_states.shape[1]
+            head_dim = key_states.shape[-1]
+            # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+            # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+            key_states = key_states.transpose(1, 2).flatten(2)
+            value_states = value_states.transpose(1, 2).flatten(2)
+
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
         )
@@ -106,6 +116,19 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )
 
+        if kv_states_dim == 4:
+            # reshape back for the per-channel scenario
+            # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+            # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+            qdq_key_states = qdq_key_states.view(
+                qdq_key_states.shape[0], qdq_key_states.shape[1],
+                num_heads, head_dim
+            ).transpose(1, 2).contiguous()
+            qdq_value_states = qdq_value_states.view(
+                qdq_value_states.shape[0], qdq_value_states.shape[1],
+                num_heads, head_dim
+            ).transpose(1, 2).contiguous()
+
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states
 
         return keys_to_return, values_to_return
@@ -155,8 +178,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
             zps = self.v_zps
 
         scale, zp = observer(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
-        _pad_and_append_at_idx_(zps, layer_idx, zp)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
+        _pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())
 
         q_tensor = quantize(
             x=tensor,
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index b10a4cb31..9836ee78c 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -256,6 +256,9 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
     kv_cache = getattr(module, "kv_cache")
     k_scale = kv_cache.k_scales[module.layer_idx]
     v_scale = kv_cache.v_scales[module.layer_idx]
+    if kv_cache.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+        k_scale = k_scale.unsqueeze(-1)
+        v_scale = v_scale.unsqueeze(-1)
     update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
     update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
 
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index 3ee446cf3..45417ff6e 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -181,8 +181,15 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                # 1. dim=2: the kv cache quantization scenario, where observed is
+                #    [batch_size, seq_len - residual_length, num_heads * head_dim]
+                # 2. dim=0: assume observed is transposed,
+                #    because it's the output, hence use dim 0
+                dim = 2 if observed.dim() == 3 else 0
+                self._scale, self._zero_point = self.get_qparams_along_dim(
+                    observed,
+                    dim
+                )
 
             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume the obsersed.shape = [batch, token, hidden]
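For reviewers: below is a minimal, self-contained sketch of the round trip the two cache.py hunks implement — heads are folded into the channel axis so a per-channel observer sees [batch_size, seq_len, num_heads * head_dim], and the fake-quantized result is reshaped back before being handed to attention. The helper name `per_channel_kv_roundtrip` and the symmetric min/max math are illustrative assumptions; in the patch itself the scales come from the registered k/v observers and the compressed-tensors quantize/dequantize routines.

```python
import torch

def per_channel_kv_roundtrip(key_states: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Fake-quantize a [B, H, S, D] KV tensor with one scale per channel,
    mirroring the reshape performed in the cache.py hunks above."""
    batch, num_heads, seq_len, head_dim = key_states.shape

    # [B, H, S, D] -> [B, S, H*D]: fold heads into the channel axis
    flat = key_states.transpose(1, 2).flatten(2)

    # per-channel symmetric scale over dim 2, i.e. one scale per H*D channel
    qmax = 2 ** (n_bits - 1) - 1
    scale = (flat.abs().amax(dim=(0, 1)) / qmax).clamp(min=1e-9)

    # quantize/dequantize, standing in for compressed-tensors' kernels
    qdq = (flat / scale).round().clamp(-qmax - 1, qmax) * scale

    # [B, S, H*D] -> [B, H, S, D]: undo the fold before returning to attention
    return qdq.view(batch, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

key = torch.randn(2, 8, 16, 64)
assert per_channel_kv_roundtrip(key).shape == key.shape
```

The assert documents the invariant the diff relies on: `transpose(1, 2).flatten(2)` followed by `view(...).transpose(1, 2)` is an exact inverse of the layout change, so attention sees the original [B, H, S, D] shape.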
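The `scale.squeeze()` / `unsqueeze(-1)` pairing is shape bookkeeping rather than math: per-channel scales are cached as flat 1-D vectors, then re-expanded to column vectors before `update_parameter_data` writes them to the module. A hedged sketch of that round trip — the `[1, num_channels]` observer output and the `[num_channels, 1]` parameter layout are assumptions about the surrounding code, not something this diff states:

```python
import torch

num_heads, head_dim = 8, 64
num_channels = num_heads * head_dim

# assumed observer output for the CHANNEL strategy: leading singleton dim
scale = torch.rand(1, num_channels)

cached = scale.squeeze()         # -> [num_channels], as stored in kv_cache.k_scales
k_scale = cached.unsqueeze(-1)   # -> [num_channels, 1], per-channel parameter layout

assert k_scale.shape == (num_channels, 1)
```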