
Commit 0a146f8

hook with CompressedAttentionImpl
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: cec2914

File tree (2 files changed, +22 −0):

  src/llmcompressor/modifiers/quantization/calibration.py
  src/llmcompressor/modifiers/quantization/quantization/mixin.py

src/llmcompressor/modifiers/quantization/calibration.py (16 additions, 0 deletions)

@@ -282,6 +282,22 @@ def initialize_quantized_kv_cache(module: Module):
     setattr(module, "kv_cache", quantized_kv_cache)
 
 
+def initialize_attention_observers(module: Module):
+    input_args = getattr_chain(module, "quantization_scheme.input_activations", None)
+    if input_args is not None:
+        initialize_observer(module, "q", input_args)
+        initialize_observer(module, "k", input_args)
+        initialize_observer(module, "v", input_args)
+
+
+def calibrate_attention(
+    module: Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+):
+    calibrate_activations(module, value=query, base_name="q")
+    calibrate_activations(module, value=key, base_name="k")
+    calibrate_activations(module, value=value, base_name="v")
+
+
 def apply_calibration_status(module: Module):
     scheme = getattr(module, "quantization_scheme", None)
     if not scheme:
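
For context, the two helpers added above follow the existing observer pattern in calibration.py: one observer per attention tensor role ("q", "k", "v"), each updated with the raw activation during calibration so that quantization scales can be computed afterwards. The sketch below illustrates that pattern in isolation; MinMaxObserver and calibrate_attention_sketch are hypothetical stand-ins written for this illustration, not llmcompressor APIs.

import torch

class MinMaxObserver:
    # hypothetical stand-in: tracks running min/max to derive a symmetric int8 scale
    def __init__(self):
        self.min_val = torch.tensor(float("inf"))
        self.max_val = torch.tensor(float("-inf"))

    def observe(self, value: torch.Tensor):
        self.min_val = torch.minimum(self.min_val, value.min())
        self.max_val = torch.maximum(self.max_val, value.max())

    def scale(self) -> torch.Tensor:
        # symmetric 8-bit quantization scales by the largest absolute value seen
        return torch.maximum(self.min_val.abs(), self.max_val.abs()) / 127.0

# one observer per attention tensor role, mirroring initialize_attention_observers
observers = {name: MinMaxObserver() for name in ("q", "k", "v")}

def calibrate_attention_sketch(query, key, value):
    # mirrors calibrate_attention: each tensor updates its own observer
    observers["q"].observe(query)
    observers["k"].observe(key)
    observers["v"].observe(value)

# usage: run calibration batches through the model and forward q/k/v here
calibrate_attention_sketch(torch.randn(2, 8, 64), torch.randn(2, 8, 64), torch.randn(2, 8, 64))
print({name: float(obs.scale()) for name, obs in observers.items()})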

src/llmcompressor/modifiers/quantization/quantization/mixin.py (6 additions, 0 deletions)

@@ -232,6 +232,7 @@ def _initialize_observers(self, module: torch.nn.Module):
         # kv_cache activations. Within `apply_quantization_config`, the config is
         # modified to use attention output quantization if a kv_cache_scheme exists
         if is_attention and output:
+            # initialize_attention_observers(module)  # TODO: attnq
             initialize_quantized_kv_cache(module)
 
         # output activations
@@ -240,6 +241,11 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
+
+        # TODO: attnq
+        # attention_impl = enable_compressed_attention(model)
+        # hooks.add(self.register_hook(attention_impl, calibrate_attention, "calib"))
+
         for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
                 continue
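
The hook registration above is still commented out ("TODO: attnq"), but the commit message ("hook with CompressedAttentionImpl") indicates the intent: install calibrate_attention as a calibration hook on a shared attention implementation rather than on each attention module. Below is a rough, self-contained sketch of that wiring using a plain PyTorch forward pre-hook; AttentionImpl is a hypothetical stand-in, since CompressedAttentionImpl's interface is not part of this commit.

import torch

class AttentionImpl(torch.nn.Module):
    # hypothetical stand-in for CompressedAttentionImpl: receives q, k, v in forward()
    def forward(self, query, key, value):
        scores = query @ key.transpose(-1, -2) / query.shape[-1] ** 0.5
        return torch.softmax(scores, dim=-1) @ value

def calibrate_attention(module, query, key, value):
    # placeholder for the real calibration: observe q/k/v statistics here
    print("calibrating:", query.shape, key.shape, value.shape)

attention_impl = AttentionImpl()

# register the calibration callback as a forward pre-hook so it sees q/k/v
# before every attention call made during the calibration pass
handle = attention_impl.register_forward_pre_hook(
    lambda mod, args: calibrate_attention(mod, *args)
)

q = k = v = torch.randn(1, 4, 8, 16)
attention_impl(q, k, v)

handle.remove()  # handles are removable, matching the Set[RemovableHandle] return type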
