22from typing import Any , Dict , List , Optional , Union
33
44import torch
5- from compressed_tensors .utils import align_module_device , update_offload_parameter
5+ from compressed_tensors .utils import (
6+ align_module_device ,
7+ get_execution_device ,
8+ update_offload_parameter ,
9+ )
610from loguru import logger
711from pydantic import ConfigDict
812from torch .nn import Module
1115from llmcompressor .core import State
1216from llmcompressor .modifiers import Modifier
1317from llmcompressor .modifiers .utils .pytorch_helpers import run_calibration_forward
14- from llmcompressor .pytorch .utils import tensor_forward_with_input_args
1518from llmcompressor .utils .fsdp .helpers import get_fsdp_parent
1619from llmcompressor .utils .helpers import calibration_forward_context
1720from llmcompressor .utils .pytorch .module import (
@@ -217,7 +220,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
217220 self ._resolved_mappings = resolved_mappings
218221 return
219222
220- def _setup_scale_hooks (self ):
223+ def _setup_scale_hooks (self ) -> None :
221224 """
222225 Attach a forward hook to each activation we want to smooth. This allows us to
223226 calculate the dynamic range during calibration
@@ -243,7 +246,7 @@ def hook_fn(module, inp, out):
243246 self .register_hook (layer , create_hook_fn (name ), "forward" )
244247
245248 @torch .no_grad ()
246- def _calibrate (self , model : Module , calibration_dataloader : List ):
249+ def _calibrate (self , model : Module , calibration_dataloader : List ) -> None :
247250 """
248251 Catch the output dynamic ranges of each layer that will be smoothed by running
249252 forward passes with calibration_dataloader
@@ -264,7 +267,7 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
264267 calibration_dataloader ,
265268 )
266269
267- def _concat_collected_activations (self ):
270+ def _concat_collected_activations (self ) -> None :
268271 """
269272 Concatenate the collected activation values from each forward pass into a single
270273 tensor for each layer
@@ -277,7 +280,7 @@ def _concat_collected_activations(self):
277280 self ._scales [name ] = torch .cat (self ._scales [name ], dim = 0 )
278281
279282 @torch .no_grad ()
280- def _apply_smoothing (self , model : Module ):
283+ def _apply_smoothing (self , model : Module ) -> None :
281284 """
282285 Calculate the best scaling factors for each layer to smooth activations and
283286 apply the scaling factors to the weights of the next layer to offset the
@@ -484,7 +487,7 @@ def _compute_loss(
484487 fp16_output : torch .Tensor ,
485488 int_w_output : torch .Tensor ,
486489 device : torch .device ,
487- ):
490+ ) -> torch . Tensor :
488491 loss = 0.0
489492 fp16_output_flat = fp16_output .view (- 1 )
490493 int_w_output_flat = int_w_output .view (- 1 )
@@ -579,7 +582,7 @@ def _forward_input_with_kwargs(
579582 module : Module ,
580583 inputs : torch .Tensor ,
581584 input_kwargs : Optional [Dict [str , Any ]] = None ,
582- ):
585+ ) -> torch . Tensor :
583586 """
584587 Forward pass with input arguments
585588
@@ -590,43 +593,44 @@ def _forward_input_with_kwargs(
590593 """
591594 kwargs = input_kwargs or self ._module_kwargs
592595 kwargs = _sanitize_kwargs (kwargs , module )
593- return tensor_forward_with_input_args (
594- module = module ,
595- inputs = inputs ,
596- input_kwargs = kwargs ,
597- )[0 ]
596+
597+ inputs = inputs .to (get_execution_device (module ))
598+
599+ return module (inputs , ** kwargs )[0 ]
598600
599601
600- def _sanitize_kwargs (inputs_kwargs , module ) :
602+ def _sanitize_kwargs (input_kwargs : Dict [ str , Any ], module : Module ) -> Dict [ str , Any ] :
601603 """
602- Remove the arguments that are not supported in the module's
603- forward pass to avoid breaking behaviour between different versions
604- of transformers.
604+ Sanitize input keyword arguments to match the module's forward method signature,
605+ excluding `use_cache`, which should not be passed to the module.
605606
606607 Args:
607608 input_kwargs (`dict`):
608609 The input dictionary to pass to the model layer
609610 module (`torch.nn.Module`):
610611 Target module to quantize.
611612 """
613+
612614 params = inspect .signature (module .forward ).parameters
613- sanitized_kwargs = {}
614- for k , v in inputs_kwargs . items ():
615- if k in params and k != "use_cache" :
616- sanitized_kwargs [ k ] = v
617- # In case forward pass has optional dependencies that don't default to None.
615+
616+ # Filter out any kwargs not in module.forward signature
617+ sanitized_kwargs = { k : v for k , v in input_kwargs . items () if k in params }
618+
619+ # Edge Case: forward pass has optional dependencies that don't default to None.
618620 # This is the case for `LlamaAttention.forward` which has input
619621 # `attention_mask: Optional[torch.Tensor],` (with no `= None` default)
620622 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L246
621623 for k , v in params .items ():
622624 if (
623625 k not in sanitized_kwargs
624- and k != "use_cache"
625626 and v .default is inspect .Parameter .empty
626627 and str (v .annotation ).startswith ("typing.Optional" )
627628 ):
628629 sanitized_kwargs [k ] = None
629630
631+ # Exclude `use_cache` entirely
632+ sanitized_kwargs .pop ("use_cache" , None )
633+
630634 return sanitized_kwargs
631635
632636
0 commit comments