
Commit 7fb58df

[AWQ] Support for Calibration Datasets of varying feature dimension (#1536)
SUMMARY: AWQModifier currently expects all calibration batches to have the same feature dimension, and users report that this causes errors on vision-language datasets. This PR adds support in AWQModifier for calibration dataset batches with varying feature dimensions. Rather than concatenating all module outputs into a single torch tensor, the per-batch outputs are kept as a list and passed through one by one to compute the loss. This removes the need for the chunk-memory configuration and the logic around that calculation. Resolves #1524.

TEST PLAN:
- [x] Re-ran for `"meta-llama/Llama-3.2-3B-Instruct"`; the wikitext PPL of 13.30 is better than the 14.08 achieved previously, because the dataset is slightly different now.
- [x] Also confirmed the user-provided code in #1524 (comment) can be run with the smaller `"Qwen/Qwen2.5-VL-7B-Instruct"` model, up until it tries to access a jpg file on the user's local machine.

Signed-off-by: Brian Dellabetta <[email protected]>
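For illustration of the change described above, a minimal sketch of the failure mode and the batch-wise workaround. The tensor shapes and noise level are made up for this example and are not taken from the PR:

import torch

# Two calibration batches whose sequence lengths differ, as can happen with
# vision-language datasets. Shapes here are illustrative only.
fp16_outputs = [torch.randn(1, 128, 64), torch.randn(1, 97, 64)]
int_w_outputs = [o + 0.01 * torch.randn_like(o) for o in fp16_outputs]

# The previous approach, torch.cat(fp16_outputs, dim=0), raises a RuntimeError
# here because the sequence dimensions (128 vs. 97) do not match.

# The batch-wise approach keeps the list and accumulates squared error and
# element counts per batch, mirroring the new _compute_loss logic.
loss, num_elements = 0.0, 0
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
    loss += (fp16_batch - int_w_batch).float().pow(2).sum().item()
    num_elements += fp16_batch.numel()
loss /= num_elements  # normalized squared error over all calibration elements
print(f"loss: {loss:.6f}")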
1 parent 1648528 commit 7fb58df

File tree (1 file changed, +23 -38 lines):
  • src/llmcompressor/modifiers/awq/base.py
src/llmcompressor/modifiers/awq/base.py

Lines changed: 23 additions & 38 deletions
@@ -113,7 +113,6 @@ class AWQModifier(Modifier, QuantizationMixin):
         requirements but requires more time to move data between cpu and execution
         device. Defaults to None, so cached args are not offloaded. Consider setting
         to torch.device("cpu") if you are encountering OOM errors
-    :param max_chunk_memory: maximum memory to use for each chunk of input activations
     :param duo_scaling: whether to use duo scaling, which uses both input activations
         and weights to determine the scaling factor
     """
@@ -125,7 +124,6 @@ class AWQModifier(Modifier, QuantizationMixin):
     sequential_targets: Union[str, List[str], None] = None
     mappings: Optional[List[AWQMapping]] = None
     offload_device: Optional[torch.device] = None
-    max_chunk_memory: int = 1024 * 1024 * 1024
     duo_scaling: bool = True
 
     # Private vars set during validation
@@ -476,8 +474,8 @@ def _apply_smoothing(self, model: Module) -> None:
             with calibration_forward_context(model), HooksMixin.disable_hooks():
                 # [STEP 3]: Compute output of module
                 # could cache from hook, rather than recomputing here
-                fp16_output = self._run_samples(parent_module)
-                if fp16_output.numel() == 0:
+                fp16_outputs = self._run_samples(parent_module)
+                if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
                     logger.info(
                         f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                         "found to scale. This can occasionally occur in MoE models "
@@ -490,7 +488,7 @@ def _apply_smoothing(self, model: Module) -> None:
 
                 # [STEP 4]: Compute loss
                 best_scales = self._compute_best_scale(
-                    x_mean, w_mean, parent_module, balance_layers, fp16_output
+                    x_mean, w_mean, parent_module, balance_layers, fp16_outputs
                 )
 
     @torch.no_grad()
@@ -543,28 +541,25 @@ def smooth(module):
                 v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> torch.Tensor:
+    def _run_samples(self, module: Module) -> List[torch.Tensor]:
         with align_module_device(module):
             outputs = [
                 module(**batch_kwargs)
                 for batch_kwargs in self._parent_args_cache[module]
             ]
-            return torch.cat(
-                [
-                    # If Tuple, assume that first argument is the input
-                    output[0] if isinstance(output, Tuple) else output
-                    for output in outputs
-                ],
-                dim=0,
-            )
+            return [
+                # If Tuple, assume that first argument is the input
+                output[0] if isinstance(output, Tuple) else output
+                for output in outputs
+            ]
 
     def _compute_best_scale(
         self,
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
-        fp16_output: torch.Tensor,
+        fp16_outputs: List[torch.Tensor],
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -623,10 +618,10 @@ def _compute_best_scale(
 
             # W * X
             with HooksMixin.disable_hooks():
-                int_w_output = self._run_samples(parent_module)
+                int_w_outputs = self._run_samples(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_output, int_w_output, device)
+            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
 
             history.append(loss)
             if loss < best_error:
@@ -648,35 +643,25 @@ def _compute_best_scale(
     @torch.no_grad()
     def _compute_loss(
        self,
-        fp16_output: torch.Tensor,
-        int_w_output: torch.Tensor,
+        fp16_outputs: List[torch.Tensor],
+        int_w_outputs: List[torch.Tensor],
         device: torch.device,
     ) -> torch.Tensor:
         loss = 0.0
-        fp16_output_flat = fp16_output.view(-1)
-        int_w_output_flat = int_w_output.view(-1)
-        num_elements = fp16_output_flat.size(0)
-        element_size_bytes = fp16_output.element_size()
-
-        # Calculate chunk size dynamically based on max_chunk_memory
-        # Divide the max_chunk_memory by twice the element size
-        chunk_size = self.max_chunk_memory // (element_size_bytes * 2)
-        chunk_size = min(chunk_size, num_elements)
-
-        # Split the computation into chunks
-        fp16_chunks = torch.split(fp16_output_flat, chunk_size)
-        int_w_chunks = torch.split(int_w_output_flat, chunk_size)
-
-        # Compute the MSE loss for each chunk
-        for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks):
-            chunk_loss = (
-                (fp16_chunk.to(device) - int_w_chunk.to(device))
+        num_elements = 0
+
+        # Compute the MSE loss for each batch
+        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
+            batch_loss = (
+                (fp16_batch.to(device) - int_w_batch.to(device))
+                .view(-1)
                 .float()
                 .pow(2)
                 .sum()
                 .item()
             )
-            loss += chunk_loss
+            loss += batch_loss
+            num_elements += fp16_batch.numel()
 
         # Normalize the loss by the total number of elements
         loss /= num_elements
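As a sanity check on the refactor above, a minimal sketch (toy shapes, not part of the PR's test plan) showing that when batches do share a feature dimension, the batch-wise accumulation matches the normalized loss the old flatten-and-chunk version computed:

import torch

torch.manual_seed(0)
fp16_outputs = [torch.randn(2, 32, 16) for _ in range(3)]
int_w_outputs = [o + 0.05 * torch.randn_like(o) for o in fp16_outputs]

# Old behavior: concatenate, flatten, and normalize the total squared error.
fp16_flat = torch.cat(fp16_outputs, dim=0).view(-1)
int_w_flat = torch.cat(int_w_outputs, dim=0).view(-1)
old_loss = (fp16_flat - int_w_flat).float().pow(2).sum().item() / fp16_flat.numel()

# New behavior: accumulate per batch, then normalize by the total element count.
new_loss, num_elements = 0.0, 0
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
    new_loss += (fp16_batch - int_w_batch).float().pow(2).sum().item()
    num_elements += fp16_batch.numel()
new_loss /= num_elements

# Same loss value (up to float rounding), with no concatenation required.
assert abs(old_loss - new_loss) < 1e-6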
