Skip to content

Commit 908aa73

Browse files
get_flattened_output
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent b0480d2 commit 908aa73

File tree

1 file changed

+9
-7
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+9
-7
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -537,15 +537,17 @@ def _smooth(module):
537537
v.batch_intermediates.clear()
538538
self._assert_all_activations_consumed()
539539

540-
def _run_samples(self, module: Module) -> List[torch.Tensor]:
540+
def _get_flattened_output(self, module: Module) -> torch.Tensor:
541541
outputs = [
542542
module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
543543
]
544-
return [
545-
# If Tuple, assume that first argument is the input
546-
output[0] if isinstance(output, Tuple) else output
547-
for output in outputs
548-
]
544+
return torch.cat(
545+
[
546+
# If Tuple, assume that the first element is the module's output tensor
547+
(output[0] if isinstance(output, Tuple) else output).flatten()
548+
for output in outputs
549+
]
550+
)
549551

550552
def _compute_best_scale(
551553
self,
@@ -614,7 +616,7 @@ def _compute_best_scale(
614616
)
615617

616618
# W * X
617-
int_w_outputs = self._run_samples(parent_module)
619+
int_w_output = self._get_flattened_output(parent_module)
618620

619621
# compute mean squared error (L2 norm)
620622
loss = F.mse_loss(int_w_output, fp16_output).item()

0 commit comments

Comments
 (0)