Skip to content

Commit 8293f18

Browse files
codeassist updates
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent d23090e commit 8293f18

File tree

1 file changed

+3
-1
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+3
-1
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,9 @@ def _compute_loss(
670670
fp16_output: torch.Tensor,
671671
int_w_output: torch.Tensor,
672672
) -> torch.Tensor:
673-
"""Compute MSE loss for each batch"""
673+
"""
674+
Compute MSE loss over the flattened output of all batches
675+
"""
674676
return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
675677

676678

0 commit comments

Comments
 (0)