Skip to content

[Distributed][SmoothQuant] Add distributed activation scale reduction (#2180)#2471

Merged
kylesayrs merged 6 commits intovllm-project:mainfrom
dzhengAP:feature/smoothquant-distributed
Mar 18, 2026
Merged

[Distributed][SmoothQuant] Add distributed activation scale reduction (#2180)#2471
kylesayrs merged 6 commits intovllm-project:mainfrom
dzhengAP:feature/smoothquant-distributed

Conversation

@dzhengAP
Copy link
Copy Markdown
Contributor

@dzhengAP dzhengAP commented Mar 16, 2026

Summary

Implements distributed support for SmoothQuantModifier as part of the
weight-parallel optimization tracked in #2180, assigned to @dzhengAP.

What this PR does

In a distributed calibration run, each rank observes a disjoint partition
of the calibration dataset. Activation statistics (per-channel min/max)
are collected locally via forward hooks. Before smoothing scales are
computed, _reduce_activation_scales() all-reduces those statistics
across all ranks so every rank has the global activation profile, then
each rank independently computes identical smoothing scales (cheap op,
no weight broadcast needed).

This follows the AWQ strategy described in #2180. Single-GPU behavior
is completely unchanged — all new code is guarded by is_distributed().

Changes

  • _reduce_activation_scales(): all-reduces min/max_channel_vals
    across ranks using async MIN/MAX collectives batched with wait_for_comms
  • _apply_smoothing(): calls _reduce_activation_scales() as first step
  • Unit tests: 5 mock-based tests verifying call contract (no GPU needed)
  • DDP example: examples/quantization_w8a8_int8/smoothquant_ddp_example.py
  • Multi-GPU integration tests verifying weight equivalence vs single-GPU

Test results

  • Unit tests: all 5 passed (pytest -m unit)
  • DDP example: ran successfully on 2x V100 32GB, both ranks completed
    in ~698s, peak GPU mem 1.66 GB per rank

Distributed Speedup Benchmarks

Model: Qwen/Qwen2-7B-Instruct, 512 calibration samples, 4x V100 32GB

GPUs Total Time Peak GPU Mem Speedup
1 GPU 94.1 min 8.93 GB 1.00x
2 GPU 58.7 min 7.06 GB 1.60x
4 GPU 28.7 min 7.06 GB 3.28x

Benchmark script: examples/quantization_w8a8_int8/benchmark_smoothquant_ddp.py

  • ruff: all checks passed

Closes part of #2180
cc @kylesayrs

Add _reduce_activation_scales() to SmoothQuantModifier to support
weight-parallel distributed compression as specified in RFC vllm-project#2180.

In a distributed setting, each rank observes a disjoint partition of
the calibration dataset and collects local per-channel min/max activation
statistics via forward hooks. Before smoothing scales are computed,
_reduce_activation_scales() all-reduces these statistics across all ranks
using MIN/MAX collective ops, ensuring every rank has global activation
statistics.

Following the pattern established for AWQ in the RFC: since the scale
computation is cheap, it is duplicated across ranks rather than computed
on one rank and broadcast, avoiding an extra distributed communication
step on the weight tensors.

This is a strict no-op in single-GPU mode (guarded by is_distributed()).

Also adds:
- Unit tests (mock-based, no GPU) verifying the all_reduce call contract
- DDP example script for examples/quantization_w8a8_int8/
- Multi-GPU integration tests verifying weight equivalence vs single-GPU

Testing:
- Unit tests: all 5 passed (pytest -m unit)
- DDP example: ran successfully on 2x V100 32GB, both ranks completed
  in ~698s, peak GPU mem 1.66 GB per rank
- ruff: all checks passed

Relates to: vllm-project#2180

Signed-off-by: David Zheng <dzheng@apple.com>
@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces distributed capabilities to the SmoothQuant quantization process. It enables efficient calibration in a distributed setting by allowing each GPU to process a disjoint subset of the calibration data, then aggregating the activation statistics across all ranks. This ensures that all ranks compute identical smoothing scales, facilitating consistent and scalable W8A8 quantization for large language models.

Highlights

  • Distributed SmoothQuant Support: Implemented distributed support for SmoothQuantModifier to enable weight-parallel quantization, addressing issue [RFC] [Performance Refactor][Distributed] Sequential Onloading with Data-Parallel Calibration and Weight-Parallel Optimization #2180.
  • Activation Scale Reduction: Introduced _reduce_activation_scales() to perform an all-reduce operation on per-channel min/max activation statistics across all ranks, ensuring global consistency for smoothing scale computation.
  • Smoothing Application Update: Modified _apply_smoothing() to call _reduce_activation_scales() as its initial step, guaranteeing that smoothing scales are computed using globally aggregated statistics.
  • New DDP Example: Added a new distributed data parallel (DDP) example script (smoothquant_ddp_example.py) demonstrating distributed SmoothQuant and GPTQ W8A8 quantization.
  • Comprehensive Testing: Included 5 mock-based unit tests and multi-GPU integration tests to verify the correctness and behavior of distributed SmoothQuant, including weight equivalence against single-GPU runs.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • examples/quantization_w8a8_int8/smoothquant_ddp_example.py
    • Added a new example script for distributed SmoothQuant and GPTQ W8A8 quantization.
  • src/llmcompressor/modifiers/transform/smoothquant/base.py
    • Imported torch.distributed and is_distributed for distributed functionality.
    • Imported wait_for_comms for managing asynchronous communication operations.
    • Added _reduce_activation_scales method to perform all-reduce operations on activation statistics.
    • Modified _apply_smoothing to incorporate the distributed activation scale reduction.
  • tests/llmcompressor/modifiers/transform/smoothquant/test_smoothquant_distributed.py
    • Added unit tests to verify the _reduce_activation_scales method's behavior in distributed and non-distributed contexts.
    • Added unit tests to confirm the correct number and type of all-reduce calls for multiple layers.
    • Added a unit test to ensure _apply_smoothing calls _reduce_activation_scales first.
    • Added a unit test to check wait_for_comms is called even with empty scales.
    • Added multi-GPU integration tests to validate the DDP example script's execution and output.
    • Added multi-GPU integration tests to verify that distributed SmoothQuant weights match single-GPU reference weights.
Activity
  • Unit tests: all 5 passed.
  • DDP example: ran successfully on 2x V100 32GB, both ranks completed in ~698s, peak GPU mem 1.66 GB per rank.
  • Ruff: all checks passed.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 16, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces distributed support for the SmoothQuantModifier, which is a great enhancement. The core logic for all-reducing activation statistics across ranks is well-implemented and efficient. The addition of comprehensive unit and integration tests, including a multi-GPU correctness test, is excellent and ensures the reliability of the new feature. I have one suggestion for the new DDP example script to prevent a potential race condition during saving.

Prevent race condition where all ranks write to the same save directory.

Addresses gemini-code-assist review on PR vllm-project#2471

Signed-off-by: David Zheng <dzheng@apple.com>
dispatch_model and generate require all ranks to participate.
Add dist.barrier() before generation and only log output on rank 0.

Signed-off-by: David Zheng <dzheng@apple.com>
@dzhengAP
Copy link
Copy Markdown
Contributor Author

Updated in 6646930:

  • Fixed race condition: save is now guarded by if rank == 0 (per Gemini review)
  • Fixed NCCL broadcast timeout in sample generation: dispatch_model and generate require all ranks to participate, added dist.barrier() before generation and only log output on rank 0

@HDCharles HDCharles self-requested a review March 16, 2026 15:32
@HDCharles HDCharles assigned HDCharles and unassigned HDCharles Mar 16, 2026
@HDCharles HDCharles requested review from HDCharles and removed request for HDCharles March 16, 2026 15:38
@HDCharles HDCharles assigned HDCharles and unassigned HDCharles Mar 16, 2026
@HDCharles HDCharles removed their request for review March 16, 2026 15:41
@HDCharles HDCharles assigned HDCharles and unassigned HDCharles Mar 16, 2026
@HDCharles HDCharles self-requested a review March 16, 2026 15:44
@HDCharles HDCharles assigned HDCharles and unassigned HDCharles Mar 16, 2026
Copy link
Copy Markdown
Collaborator

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks really great! I think I'd like to see some basic model tests to validate that accuracy does not degrade. For example, see: #2457

@HDCharles HDCharles assigned HDCharles and unassigned HDCharles Mar 17, 2026
Copy link
Copy Markdown
Collaborator

@HDCharles HDCharles left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few nits, looks good otherwise

@HDCharles HDCharles assigned HDCharles and unassigned HDCharles Mar 17, 2026
@HDCharles HDCharles requested a review from kylesayrs March 18, 2026 20:09
@mergify
Copy link
Copy Markdown
Contributor

mergify bot commented Mar 18, 2026

The quality checks have failed. Please run make style and make quality under
the root directory to adddress the lint failures. You will need to install the
dev optional install to get the required linting packages:
https://github.com/vllm-project/llm-compressor/blob/main/CONTRIBUTING.md

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
@dzhengAP dzhengAP force-pushed the feature/smoothquant-distributed branch from d2f8f6e to 500a1e2 Compare March 18, 2026 20:36
@HDCharles HDCharles added ready When a PR is ready for review smoothquant For any issue / PR related to SmoothQuant support enhancement New feature or request dist Work pertaining to distributed work and removed documentation Improvements or additions to documentation labels Mar 18, 2026
@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 18, 2026
@dzhengAP
Copy link
Copy Markdown
Contributor Author

@kylesayrs all review comments addressed, quality checks passing, 2 approvals. Ready for merge when you have a chance!

@kylesayrs kylesayrs merged commit a3d2a3f into vllm-project:main Mar 18, 2026
15 of 23 checks passed
dzhengAP pushed a commit to dzhengAP/llm-compressor that referenced this pull request Mar 19, 2026
Prevent race condition where all ranks write to the same save directory.

Addresses gemini-code-assist review on PR vllm-project#2471

Signed-off-by: David Zheng <dzheng@apple.com>
dzhengAP pushed a commit to dzhengAP/llm-compressor that referenced this pull request Mar 19, 2026
- Move is_distributed() guard to _apply_smoothing() hotpath so
  _reduce_activation_scales() only handles the distributed case,
  improving readability for single-GPU readers
- Remove redundant unit tests (subsumed by 2n_calls test; empty
  scales test unnecessary per reviewer feedback)
- Remove test_smoothquant_ddp_script_runs_cleanly (too expensive for CI)
- Switch integration test model to nm-testing/tinysmokellama-3.2
  (CI-friendly tiny model per HDCharles suggestion)
- Switch DDP example model to Qwen/Qwen2-7B-Instruct (DDP more
  meaningful for larger models)
- Fix --nproc arg conflict with torchrun, rename to --num_gpus
- Add benchmark_smoothquant_ddp.py for reproducing speedup numbers

Distributed speedup on 4x V100 32GB (Qwen2-7B-Instruct, 512 samples):
  1 GPU:  94.1 min | 8.93 GB peak mem | 1.00x
  2 GPU:  58.7 min | 7.06 GB peak mem | 1.60x
  4 GPU:  28.7 min | 7.06 GB peak mem | 3.28x

Addresses review comments from HDCharles on PR vllm-project#2471

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

dist Work pertaining to distributed work documentation Improvements or additions to documentation enhancement New feature or request ready When a PR is ready for review smoothquant For any issue / PR related to SmoothQuant support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants