File tree Expand file tree Collapse file tree 1 file changed +9
-7
lines changed
src/llmcompressor/modifiers/awq Expand file tree Collapse file tree 1 file changed +9
-7
lines changed Original file line number Diff line number Diff line change @@ -537,15 +537,17 @@ def _smooth(module):
537
537
v .batch_intermediates .clear ()
538
538
self ._assert_all_activations_consumed ()
539
539
540
- def _run_samples (self , module : Module ) -> List [ torch .Tensor ] :
540
+ def _get_flattened_output (self , module : Module ) -> torch .Tensor :
541
541
outputs = [
542
542
module (** batch_kwargs ) for batch_kwargs in self ._parent_args_cache [module ]
543
543
]
544
- return [
545
- # If Tuple, assume that first argument is the input
546
- output [0 ] if isinstance (output , Tuple ) else output
547
- for output in outputs
548
- ]
544
+ return torch .cat (
545
+ [
546
+ # If Tuple, assume that the first element is the output tensor
547
+ (output [0 ] if isinstance (output , Tuple ) else output ).flatten ()
548
+ for output in outputs
549
+ ]
550
+ )
549
551
550
552
def _compute_best_scale (
551
553
self ,
@@ -614,7 +616,7 @@ def _compute_best_scale(
614
616
)
615
617
616
618
# W * X
617
- int_w_outputs = self ._run_samples (parent_module )
619
+ int_w_output = self ._get_flattened_output (parent_module )
618
620
619
621
# compute mean squared error (L2 norm)
620
622
loss = F .mse_loss (int_w_output , fp16_output ).item ()
You can’t perform that action at this time.
0 commit comments