2 files changed: +22 −0 lines

File tree: src/llmcompressor/modifiers/quantization

@@ -282,6 +282,22 @@ def initialize_quantized_kv_cache(module: Module):
         setattr(module, "kv_cache", quantized_kv_cache)
 
 
+def initialize_attention_observers(module: Module):
+    input_args = getattr_chain(module, "quantization_scheme.input_activations", None)
+    if input_args is not None:
+        initialize_observer(module, "q", input_args)
+        initialize_observer(module, "k", input_args)
+        initialize_observer(module, "v", input_args)
+
+
+def calibrate_attention(
+    module: Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+):
+    calibrate_activations(module, value=query, base_name="q")
+    calibrate_activations(module, value=key, base_name="k")
+    calibrate_activations(module, value=value, base_name="v")
+
+
 def apply_calibration_status(module: Module):
     scheme = getattr(module, "quantization_scheme", None)
     if not scheme:
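Taken together, a minimal usage sketch of these two helpers during a calibration pass (hypothetical names, not part of the diff: `attn_module` is an attention module whose `quantization_scheme.input_activations` is set, and `q`, `k`, `v` are the tensors from its forward call; the intended wiring is the commented-out TODO hook in the second file below):

    # hypothetical usage sketch, not part of this diff
    initialize_attention_observers(attn_module)  # attaches observers under base names "q", "k", "v"

    # during each calibration forward pass, route the attention inputs
    # through those observers to update their quantization statistics
    calibrate_attention(attn_module, query=q, key=k, value=v)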
Second changed file (same directory):

@@ -232,6 +232,7 @@ def _initialize_observers(self, module: torch.nn.Module):
         # kv_cache activations. Within `apply_quantization_config`, the config is
         # modified to use attention output quantization if a kv_cache_scheme exists
         if is_attention and output:
+            # initialize_attention_observers(module)  # TODO: attnq
             initialize_quantized_kv_cache(module)
 
         # output activations
@@ -240,6 +241,11 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
+
+        # TODO: attnq
+        # attention_impl = enable_compressed_attention(model)
+        # hooks.add(self.register_hook(attention_impl, calibrate_attention, "calib"))
+
         for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
                 continue