File tree Expand file tree Collapse file tree 1 file changed +9
-7
lines changed
src/llmcompressor/modifiers/awq Expand file tree Collapse file tree 1 file changed +9
-7
lines changed Original file line number Diff line number Diff line change @@ -551,15 +551,17 @@ def _smooth(module):
551
551
v .batch_intermediates .clear ()
552
552
self ._assert_all_activations_consumed ()
553
553
554
- def _run_samples (self , module : Module ) -> List [ torch .Tensor ] :
554
+ def _get_flattened_output (self , module : Module ) -> torch .Tensor :
555
555
outputs = [
556
556
module (** batch_kwargs ) for batch_kwargs in self ._parent_args_cache [module ]
557
557
]
558
- return [
559
- # If Tuple, assume that first argument is the input
560
- output [0 ] if isinstance (output , Tuple ) else output
561
- for output in outputs
562
- ]
558
+ return torch .cat (
559
+ [
560
+ # If Tuple, assume that first argument is the input
561
+ (output [0 ] if isinstance (output , Tuple ) else output ).flatten ()
562
+ for output in outputs
563
+ ]
564
+ )
563
565
564
566
def _compute_best_scale (
565
567
self ,
@@ -628,7 +630,7 @@ def _compute_best_scale(
628
630
)
629
631
630
632
# W * X
631
- int_w_outputs = self ._run_samples (parent_module )
633
+ int_w_output = self ._get_flattened_output (parent_module )
632
634
633
635
# compute mean squared error (L2 norm)
634
636
loss = F .mse_loss (int_w_output , fp16_output ).item ()
You can’t perform that action at this time.
0 commit comments