58 changes: 58 additions & 0 deletions src/llmcompressor/modifiers/awq/base.py
@@ -3,6 +3,7 @@
from typing import Iterator, Literal

import torch
from compressed_tensors.offload.dist_utils import is_distributed
from compressed_tensors.quantization import (
    QuantizationStrategy,
    disable_quantization,
@@ -20,6 +21,7 @@
)
from loguru import logger
from pydantic import ConfigDict, PrivateAttr, field_validator
from torch import distributed as dist
from torch.nn import Module
from torch.utils._pytree import tree_leaves
from tqdm import tqdm
@@ -497,6 +499,41 @@ def cache_smooth_activations_hook(
"forward",
)

    def _reduce_activation_means(self) -> None:
        """All-reduce cached activation means across data-parallel ranks.

        ``_smooth_activation_means`` stores ``(mean, count)`` pairs where
        ``mean`` is a running average over the local data partition. To
        recover a globally-consistent mean we:

        1. Convert each entry back to a sum: ``sum = mean * count``
        2. All-reduce the sum and count tensors across ranks.
        3. Re-derive the global mean: ``mean = total_sum / total_count``

        After this method returns, every rank holds identical activation
        statistics, which guarantees that the subsequent grid search in
        ``_compute_best_scale`` produces the same best scales on every
        rank, eliminating the need for a broadcast step.
        """
        world_size = dist.get_world_size()
        if world_size <= 1:
            return

        for name, (mean, count) in self._smooth_activation_means.items():
            device = mean.device
            # Recover the local sum from the running mean
            local_sum = mean * count
            count_tensor = torch.tensor(
                [count], dtype=torch.int64, device=device
            )

            dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
            dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
Comment on lines +522 to +531, Contributor (security-high):
The _reduce_activation_means method iterates over the keys of self._smooth_activation_means and performs collective all_reduce operations. However, self._smooth_activation_means is populated dynamically during the calibration phase based on which layers are activated by the input data. In a distributed data-parallel setting, different ranks process different batches of data. In sparse models like Mixture of Experts (MoE), it is highly likely that certain experts or layers are activated on some ranks but not others. This results in inconsistent sets of keys in self._smooth_activation_means across ranks. When dist.all_reduce is called inconsistently (i.e., some ranks call it while others do not), it leads to a permanent hang (deadlock) of the distributed process.
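One way to avoid the hang, sketched here as a hypothetical fix (not part of the PR): gather every rank's key set first, then iterate the sorted union, so all ranks issue the same collectives in the same order. In the real method, `per_rank_keys` would come from `dist.all_gather_object` over each rank's `sorted(self._smooth_activation_means.keys())`; the ordering logic itself needs no process group and is shown standalone:

```python
def deterministic_key_union(per_rank_keys):
    """Sorted union of the key sets observed on each rank.

    Every rank iterates this same list, substituting a zero
    (mean, count) entry for keys it never saw locally, so all ranks
    issue the same number of all_reduce calls in the same order and
    cannot deadlock.
    """
    return sorted({k for keys in per_rank_keys for k in keys})


# Rank 0 activated experts 0 and 2; rank 1 activated experts 1 and 2.
order = deterministic_key_union(
    [["expert_0", "expert_2"], ["expert_1", "expert_2"]]
)
print(order)  # ['expert_0', 'expert_1', 'expert_2']
```

The zero entries contribute nothing to the reduced sums, so they do not bias the global statistics.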


            total_count = count_tensor.item()
            global_mean = local_sum / total_count
Contributor (high):

There's a potential for a division-by-zero error if total_count is 0. This can happen if a module (like an expert in a MoE model) is not activated by any calibration samples across all distributed ranks. This would lead to NaN values for global_mean, which will cause issues later.

To prevent this, you should handle the case where total_count is zero. When total_count is 0, the global sum of activations (local_sum after all-reduce) will also be 0, so setting global_mean to a tensor of zeros is a safe approach.

Suggested change:
-            global_mean = local_sum / total_count
+            global_mean = local_sum / total_count if total_count > 0 else torch.zeros_like(local_sum)
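A toy check of why the guard matters (assumed values, standalone): with tensor arithmetic, `0 / 0` silently yields NaN rather than raising, so an unguarded reduce would poison the statistics for any never-activated module.

```python
import torch

# A module no calibration sample activated on any rank: after the
# all-reduce, both the summed activations and the count are zero.
local_sum = torch.zeros(4)
total_count = 0

unguarded = local_sum / total_count  # elementwise 0/0 -> NaN tensor
guarded = (
    local_sum / total_count
    if total_count > 0
    else torch.zeros_like(local_sum)
)

print(torch.isnan(unguarded).all().item())   # True
print(torch.isfinite(guarded).all().item())  # True
```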

Comment on lines +523 to +534, Copilot AI, Mar 4, 2026:
_smooth_activation_means entries are stored on CPU (see the hook storing act_mean.cpu()), so mean.device will typically be cpu here. dist.all_reduce on CPU tensors only works with a CPU-capable backend (e.g., gloo); if the distributed context is initialized with NCCL (common for multi-GPU), this will raise at runtime. Consider moving local_sum/count_tensor to a CUDA device for the reduction (or using a dedicated gloo process group for CPU reductions), and ensure the dtype is supported for the chosen backend (e.g., cast sums to fp32/fp64 before all-reduce).

Suggested change:
-            device = mean.device
-            # Recover the local sum from the running mean
-            local_sum = mean * count
-            count_tensor = torch.tensor(
-                [count], dtype=torch.int64, device=device
-            )
-
-            dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
-            dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
-
-            total_count = count_tensor.item()
-            global_mean = local_sum / total_count
+            orig_device = mean.device
+            orig_dtype = mean.dtype
+            # Recover the local sum from the running mean
+            local_sum = mean * count
+            count_tensor = torch.tensor(
+                [count], dtype=torch.int64, device=orig_device
+            )
+
+            backend = dist.get_backend() if dist.is_initialized() else None
+            # For NCCL, tensors must be on CUDA; optionally upcast sums to fp32
+            if backend in (dist.Backend.NCCL, "nccl"):
+                if not local_sum.is_cuda:
+                    if not torch.cuda.is_available():
+                        raise RuntimeError(
+                            "NCCL backend requires CUDA tensors for all_reduce, "
+                            "but CUDA is not available."
+                        )
+                    reduce_device = torch.device("cuda", torch.cuda.current_device())
+                    local_sum = local_sum.to(reduce_device, dtype=torch.float32)
+                    count_tensor = count_tensor.to(reduce_device)
+            else:
+                # Ensure a supported dtype for reduction
+                local_sum = local_sum.to(dtype=torch.float32)
+
+            # Perform distributed reduction
+            dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
+            dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
+
+            total_count = int(count_tensor.item())
+            # Compute global mean and move back to original device/dtype
+            global_mean = (local_sum / total_count).to(dtype=orig_dtype, device=orig_device)

            self._smooth_activation_means[name] = (global_mean, total_count)
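Numerically, the sum-then-divide recombination described in the docstring (steps 1 through 3) is exact for running means. A toy check with assumed per-rank values:

```python
# Two ranks with local running means over different partition sizes.
means_counts = [(2.0, 3), (4.0, 1)]  # rank -> (local mean, local count)

# Step 1: mean * count recovers each local sum; steps 2-3: reduce
# the sums and counts, then re-divide.
total_sum = sum(m * c for m, c in means_counts)  # 6.0 + 4.0 = 10.0
total_count = sum(c for _, c in means_counts)    # 4
global_mean = total_sum / total_count

print(global_mean)  # 2.5, the mean over all 4 pooled samples
```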

    @torch.no_grad()
    def _apply_smoothing(self, model: Module) -> None:
        """
@@ -506,6 +543,13 @@ def _apply_smoothing(self, model: Module) -> None:

        :param model: model to apply smoothing to
        """
        # ── Distributed: all-reduce activation means across DP ranks ──
        # Each rank has computed activation means from its local data
        # partition. We average them so that every rank uses identical
        # statistics (and therefore computes the same best scales).
        if is_distributed():
            self._reduce_activation_means()

        # NOTE: When using SequentialPipeline, not all the mappings
        # will have cached activations in the segment being updated
        mappings_to_smooth = [
@@ -830,6 +874,20 @@ def _compute_loss(
            )
            num_elements += fp16_batch.numel()

        # ── Distributed: all-reduce MSE loss across DP ranks ──
        # Each rank has computed loss on its local data partition.
        # Sum losses and element counts across ranks so every rank
        # independently arrives at the same best_scales.
        if is_distributed():
            device = fp16_outputs[0].device if fp16_outputs else "cpu"
            loss_t = torch.tensor([loss], dtype=torch.float64, device=device)
            count_t = torch.tensor(
                [num_elements], dtype=torch.int64, device=device
            )
            dist.all_reduce(loss_t, op=dist.ReduceOp.SUM)
            dist.all_reduce(count_t, op=dist.ReduceOp.SUM)
            return (loss_t.item() / count_t.item())
Contributor (security-medium):

This line is susceptible to a division-by-zero error if count_t.item() is 0, which can occur when a particular mapping or expert is not activated by any sample across all ranks in the distributed group, or when all tokens are masked out by loss_mask. Because .item() converts the tensors to Python scalars, the division raises ZeroDivisionError and crashes the run (a DoS); the analogous tensor division in _reduce_activation_means instead silently produces NaN, which can corrupt the model. To prevent this, handle the zero-count case explicitly: the total loss is also zero in that case, so returning 0.0 is safe.

Suggested change:
-            return (loss_t.item() / count_t.item())
+            return (loss_t.item() / count_t.item()) if count_t.item() > 0 else 0.0
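On the normalization itself: summing losses and element counts before dividing, as the PR does, is what makes the global MSE exact; averaging the per-rank means instead would over-weight ranks with fewer elements. A small worked example with assumed per-rank values:

```python
# Assumed per-rank values: summed squared error and element counts.
rank_losses = [4.0, 9.0]
rank_counts = [4, 3]

# What the PR computes: all-reduce the sums, then one division.
global_mean = sum(rank_losses) / sum(rank_counts)  # 13 / 7

# Naive alternative: average of per-rank means, wrong when counts differ.
mean_of_means = (
    sum(l / c for l, c in zip(rank_losses, rank_counts)) / len(rank_counts)
)  # (1.0 + 3.0) / 2 = 2.0

print(round(global_mean, 3), mean_of_means)  # 1.857 2.0
```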


        # Normalize the loss by the total number of elements
        return (loss / num_elements).item()
