-
Notifications
You must be signed in to change notification settings - Fork 453
feat: add distributed weight-parallel support to AWQ modifier #2442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Iterator, Literal | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from compressed_tensors.offload.dist_utils import is_distributed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from compressed_tensors.quantization import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| QuantizationStrategy, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| disable_quantization, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -20,6 +21,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from loguru import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from pydantic import ConfigDict, PrivateAttr, field_validator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch import distributed as dist | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.nn import Module | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils._pytree import tree_leaves | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tqdm import tqdm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -497,6 +499,41 @@ def cache_smooth_activations_hook( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "forward", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _reduce_activation_means(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """All-reduce cached activation means across data-parallel ranks. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``_smooth_activation_means`` stores ``(mean, count)`` pairs where | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``mean`` is a running average over the local data partition. To | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| recover a globally-consistent mean we: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 1. Convert each entry back to a sum: ``sum = mean * count`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 2. All-reduce the sum and count tensors across ranks. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 3. Re-derive the global mean: ``mean = total_sum / total_count`` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| After this method returns every rank holds identical activation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| statistics, which guarantees that the subsequent grid search in | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``_compute_best_scale`` produces the same best scales on every | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rank — eliminating the need for a broadcast step. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| world_size = dist.get_world_size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if world_size <= 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for name, (mean, count) in self._smooth_activation_means.items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device = mean.device | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Recover the local sum from the running mean | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| local_sum = mean * count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| count_tensor = torch.tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [count], dtype=torch.int64, device=device | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.all_reduce(local_sum, op=dist.ReduceOp.SUM) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_count = count_tensor.item() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_mean = local_sum / total_count | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a potential for a division-by-zero error if To prevent this, you should handle the case where
Suggested change
Comment on lines
+523
to
+534
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device = mean.device | |
| # Recover the local sum from the running mean | |
| local_sum = mean * count | |
| count_tensor = torch.tensor( | |
| [count], dtype=torch.int64, device=device | |
| ) | |
| dist.all_reduce(local_sum, op=dist.ReduceOp.SUM) | |
| dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) | |
| total_count = count_tensor.item() | |
| global_mean = local_sum / total_count | |
| orig_device = mean.device | |
| orig_dtype = mean.dtype | |
| # Recover the local sum from the running mean | |
| local_sum = mean * count | |
| count_tensor = torch.tensor( | |
| [count], dtype=torch.int64, device=orig_device | |
| ) | |
| backend = dist.get_backend() if dist.is_initialized() else None | |
| # For NCCL, tensors must be on CUDA; optionally upcast sums to fp32 | |
| if backend in (dist.Backend.NCCL, "nccl"): | |
| if not local_sum.is_cuda: | |
| if not torch.cuda.is_available(): | |
| raise RuntimeError( | |
| "NCCL backend requires CUDA tensors for all_reduce, " | |
| "but CUDA is not available." | |
| ) | |
| reduce_device = torch.device("cuda", torch.cuda.current_device()) | |
| local_sum = local_sum.to(reduce_device, dtype=torch.float32) | |
| count_tensor = count_tensor.to(reduce_device) | |
| else: | |
| # Ensure a supported dtype for reduction | |
| local_sum = local_sum.to(dtype=torch.float32) | |
| # Perform distributed reduction | |
| dist.all_reduce(local_sum, op=dist.ReduceOp.SUM) | |
| dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) | |
| total_count = int(count_tensor.item()) | |
| # Compute global mean and move back to original device/dtype | |
| global_mean = (local_sum / total_count).to(dtype=orig_dtype, device=orig_device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line is susceptible to a division-by-zero error if count_t.item() is 0, which can occur if a particular mapping or expert is not activated by any sample across all ranks in the distributed group, or if all tokens are masked out by loss_mask. This vulnerability can lead to a NaN result in _reduce_activation_means or a ZeroDivisionError in _compute_loss. A ZeroDivisionError causes a crash (DoS), while NaN weights can corrupt the model. To prevent this, a check should be added to handle the zero count case, returning 0.0 as the total loss would also be zero.
| return (loss_t.item() / count_t.item()) | |
| return (loss_t.item() / count_t.item()) if count_t.item() > 0 else 0.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_reduce_activation_meansmethod iterates over the keys ofself._smooth_activation_meansand performs collectiveall_reduceoperations. However,self._smooth_activation_meansis populated dynamically during the calibration phase based on which layers are activated by the input data. In a distributed data-parallel setting, different ranks process different batches of data. In sparse models like Mixture of Experts (MoE), it is highly likely that certain experts or layers are activated on some ranks but not others. This results in inconsistent sets of keys inself._smooth_activation_meansacross ranks. Whendist.all_reduceis called inconsistently (i.e., some ranks call it while others do not), it leads to a permanent hang (deadlock) of the distributed process.