
Commit f4b7db6

get_flattenend_output
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent: c1c27d0

1 file changed: 9 additions, 7 deletions

src/llmcompressor/modifiers/awq/base.py

Lines changed: 9 additions & 7 deletions
@@ -551,15 +551,17 @@ def _smooth(module):
             v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _get_flattened_output(self, module: Module) -> torch.Tensor:
         outputs = [
             module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
         ]
-        return [
-            # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
-            for output in outputs
-        ]
+        return torch.cat(
+            [
+                # If Tuple, assume that first argument is the input
+                (output[0] if isinstance(output, Tuple) else output).flatten()
+                for output in outputs
+            ]
+        )
 
     def _compute_best_scale(
         self,
@@ -628,7 +630,7 @@ def _compute_best_scale(
         )
 
         # W * X
-        int_w_outputs = self._run_samples(parent_module)
+        int_w_output = self._get_flattened_output(parent_module)
 
         # compute mean squared error (L2 norm)
        loss = F.mse_loss(int_w_output, fp16_output).item()

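The change concatenates every calibration batch's output into one flattened tensor so the MSE in _compute_best_scale can be computed with a single F.mse_loss call. Below is a minimal, self-contained sketch of that idea; the tensor shapes and variable names are illustrative assumptions, not taken from the repository.

import torch
import torch.nn.functional as F

# Hypothetical per-batch outputs from a parent module, e.g. two calibration
# batches with different batch sizes (shapes are illustrative assumptions).
outputs = [torch.randn(2, 4, 8), torch.randn(3, 4, 8)]

# Previous behavior returned a list of tensors, which F.mse_loss cannot
# consume directly. The new behavior flattens each output and concatenates
# them into one 1-D tensor.
int_w_output = torch.cat([out.flatten() for out in outputs])

# A reference flattened the same way (here just a perturbed copy) can then be
# compared in a single call, mirroring the updated loss computation.
fp16_output = int_w_output + 0.01 * torch.randn_like(int_w_output)
loss = F.mse_loss(int_w_output, fp16_output).item()
print(f"MSE loss: {loss:.6f}")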