Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
f60200a
[Distributed] Add distributed utilities for DDP calibration
EtelisIBM Feb 22, 2026
c4d630d
[Distributed] Add recompute_qparams_from_observer helper
EtelisIBM Feb 22, 2026
89d1ade
[Distributed] Partition weight calibration across DDP ranks
EtelisIBM Feb 22, 2026
ac0cc2a
[Tests] Add unit tests for distributed utilities
EtelisIBM Feb 22, 2026
76cf40f
[Tests] Add multi-GPU integration tests for DDP quantization
EtelisIBM Feb 22, 2026
0f3e1f9
[Examples] Add distributed W8A8 quantization example
EtelisIBM Feb 22, 2026
9975edc
[Distributed] Fix broadcast_module_parameter for CPU-resident models
EtelisIBM Feb 22, 2026
3320812
[Distributed] Refactor DDP activation sync per review feedback
EtelisIBM Feb 23, 2026
87f4b0d
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis Feb 24, 2026
766a70c
Merge remote-tracking branch 'upstream/main' into feature/quantizatio…
EtelisIBM Mar 19, 2026
d44c4ab
[Distributed] Address review feedback for DDP activation observer sync
EtelisIBM Mar 19, 2026
5fa31b2
Merge branch 'main' into feature/quantization-modifier-ddp
HDCharles Mar 19, 2026
0e3a843
[Distributed] Use as_broadcastable and simplify moving-average sync
EtelisIBM Mar 20, 2026
82d808c
Merge branch 'feature/quantization-modifier-ddp' of https://github.co…
EtelisIBM Mar 20, 2026
d6b3575
Update src/llmcompressor/observers/moving_base.py
Etelis Mar 23, 2026
9680d9e
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis Mar 23, 2026
688b309
Merge branch 'main' into feature/quantization-modifier-ddp
kylesayrs Mar 24, 2026
7f31744
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis Mar 24, 2026
4f80617
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis Mar 24, 2026
f959d4b
fix formatting and moving-average test mock path
EtelisIBM Mar 25, 2026
7baf545
Merge branch 'feature/quantization-modifier-ddp' of https://github.co…
EtelisIBM Mar 25, 2026
09e817b
Merge branch 'main' into feature/quantization-modifier-ddp
HDCharles Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import torch
import torch.distributed as dist
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048

# Initialize distributed.
# Usage: torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
# NOTE(review): uses the global rank as the CUDA device index, which is only
# correct for single-node launches; a multi-node run would need LOCAL_RANK.
torch.cuda.set_device(rank)

if rank == 0:
    print(f"Running distributed quantization with {world_size} GPUs")

# Load model to CPU for sequential onloading.
# device_map=None keeps all weights on CPU; the oneshot pipeline moves
# layers to GPU as needed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype="auto",
    device_map=None,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load and partition dataset across ranks.
# Each rank loads a disjoint slice of the calibration data.
# NOTE: integer division drops up to (world_size - 1) trailing samples when
# NUM_CALIBRATION_SAMPLES is not divisible by world_size.
samples_per_rank = NUM_CALIBRATION_SAMPLES // world_size
start = samples_per_rank * rank
end = start + samples_per_rank

ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[{start}:{end}]")
# Shuffle within the rank-local slice; fixed seed keeps runs reproducible.
ds = ds.shuffle(seed=42)


def preprocess(example):
    """Render one chat sample's messages into a single text string.

    Applies the tokenizer's chat template with ``tokenize=False`` so the
    rendered text can be tokenized later in a separate map step.
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
    )
    return {"text": rendered}


# Render every sample's chat messages to plain text (adds a "text" column).
ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    """Tokenize one rendered text sample for calibration.

    Truncates to MAX_SEQUENCE_LENGTH with no padding; special tokens are
    not added here (add_special_tokens=False) since the text was produced
    by the chat template.
    """
    return tokenizer(
        sample["text"],
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        padding=False,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# QuantizationModifier automatically detects torch.distributed and:
# * partitions weight calibration across ranks
# * all-reduces activation observer statistics at layer boundaries
recipe = [
    QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Apply algorithms. Each rank calibrates on its own slice of the dataset;
# per-rank sample count matches the slice loaded above.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=samples_per_rank,
)

# Save to disk compressed (rank 0 only).
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-distributed"
if rank == 0:
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)
    print(f"Model saved to {SAVE_DIR}")

# Fix: synchronize all ranks before teardown so the non-zero ranks do not
# destroy the process group and exit while rank 0 is still writing the
# checkpoint to disk.
dist.barrier()
dist.destroy_process_group()
36 changes: 36 additions & 0 deletions src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"calibrate_query_hook",
"calibrate_key_hook",
"calibrate_value_hook",
"recompute_qparams_from_observer",
]


Expand Down Expand Up @@ -235,6 +236,41 @@ def calibrate_value_hook(module: Module, value_states: torch.Tensor):
calibrate_activations(module, value_states, base_name="v")


def recompute_qparams_from_observer(module: Module, base_name: str):
    """
    Recompute scale and zero_point from an observer's accumulated
    past_min_vals/past_max_vals. Used after DDP all-reduce to update
    qparams from synchronized statistics.

    No-op when the module has no matching observer or the observer has
    not accumulated any statistics yet.

    :param module: module with quantization parameters
    :param base_name: "input", "output", "q", "k", or "v"
    """
    # local import (presumably to avoid an import cycle — keep it here)
    from compressed_tensors.quantization.utils import calculate_qparams

    observer: Observer = getattr(module, f"{base_name}_observer", None)
    if observer is None:
        return

    accumulated_min = getattr(observer, "past_min_vals", None)
    accumulated_max = getattr(observer, "past_max_vals", None)
    if accumulated_min is None or accumulated_max is None:
        return

    scale, zero_point = calculate_qparams(
        min_vals=accumulated_min,
        max_vals=accumulated_max,
        quantization_args=observer.args,
        global_scale=getattr(module, f"{base_name}_global_scale", None),
    )

    update_offload_parameter(module, f"{base_name}_scale", scale)
    # some schemes (e.g. symmetric) register no zero_point parameter
    if hasattr(module, f"{base_name}_zero_point"):
        update_offload_parameter(module, f"{base_name}_zero_point", zero_point)


def apply_calibration_status(module: Module):
scheme = getattr(module, "quantization_scheme", None)
if not scheme:
Expand Down
112 changes: 110 additions & 2 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,21 @@
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.quantization.calibration import (
recompute_qparams_from_observer,
update_weight_global_scale,
update_weight_zp_scale,
)
from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin
from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales
from llmcompressor.utils.distributed import (
all_reduce_max,
all_reduce_min,
broadcast_module_parameter,
build_module_to_rank_map,
get_rank,
is_distributed,
partition_modules_by_weight_size,
)

__all__ = ["QuantizationModifier"]

Expand All @@ -20,6 +30,9 @@ class QuantizationModifier(Modifier, QuantizationMixin):
the specified module(s) forward pass will emulate quantized execution and the
modifier will be enabled until training is completed.

In DDP mode, weight calibration is partitioned across ranks and activation
observer statistics are all-reduced at sequential layer boundaries.

:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param targets: list of layer names to quantize if a scheme is provided. Defaults
Expand Down Expand Up @@ -65,14 +78,23 @@ def on_initialize(self, state: State, **kwargs) -> bool:

def on_start(self, state: State, event: Event, **kwargs):
    """
    Begin calibrating activations and weights. Calibrate weights only once
    on start. In DDP mode, weight calibration is partitioned across ranks.
    """
    self.started_ = True
    QuantizationMixin.start_calibration(self, state.model)

    named_modules = list(
        match_named_modules(state.model, self.resolved_targets, self.ignore)
    )

    # dispatch to the distributed or single-process weight-calibration path
    calibrate_weights = (
        self._calibrate_weights_distributed
        if is_distributed()
        else self._calibrate_weights_single
    )
    calibrate_weights(state.model, named_modules)

def _calibrate_weights_single(self, model, named_modules):
"""Original single-process weight calibration."""
# TODO: this step can be combined with update_weight_zp_scale
# once update_fused_layer_weight_global_scales is removed
# and not required by vLLM
Expand All @@ -84,21 +106,107 @@ def on_start(self, state: State, event: Event, **kwargs):
# on targeted modules, we need to run on all modules.
# Because this call is idempotent, setting all global_scales to the
# min value, it is ok to run potentially multiple times for all modules
for module in state.model.modules():
for module in model.modules():
update_fused_layer_weight_global_scales(module)

for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
update_weight_zp_scale(module)

def _calibrate_weights_distributed(self, model, named_modules):
    """
    DDP-partitioned weight calibration. Each rank calibrates a subset of
    modules and broadcasts results to all ranks.
    """
    rank = get_rank()
    owner_of = build_module_to_rank_map(named_modules)
    assigned = partition_modules_by_weight_size(named_modules)

    # global scales: compute for assigned modules only, then share so that
    # every rank can run the fuse step below
    for _, module in tqdm.tqdm(
        assigned, desc=f"[Rank {rank}] Updating global scales"
    ):
        update_weight_global_scale(module)
    for _, module in named_modules:
        broadcast_module_parameter(
            module, "weight_global_scale", owner_of[module]
        )

    # fuse global_scales on all ranks (idempotent)
    for module in model.modules():
        update_fused_layer_weight_global_scales(module)

    # scale / zero_point: compute for assigned modules only, then share
    for _, module in tqdm.tqdm(
        assigned, desc=f"[Rank {rank}] Calibrating weights"
    ):
        update_weight_zp_scale(module)
    for _, module in named_modules:
        owner = owner_of[module]
        broadcast_module_parameter(module, "weight_scale", owner)
        # symmetric schemes may register no zero_point parameter
        if hasattr(module, "weight_zero_point"):
            broadcast_module_parameter(module, "weight_zero_point", owner)

def on_event(self, state: State, event: Event, **kwargs):
    """Route lifecycle events to calibration start/sync/end handlers."""
    etype = event.type_

    if etype == EventType.CALIBRATION_EPOCH_START and not self.started_:
        self.on_start(state, None)

    # observer statistics are synchronized at sequential-layer boundaries
    # and once more at the end of the calibration epoch
    if etype in (EventType.SEQUENTIAL_EPOCH_END, EventType.CALIBRATION_EPOCH_END):
        self._sync_activation_observers(state.model)

    if etype == EventType.CALIBRATION_EPOCH_END and not self.ended_:
        self.on_end(state, None)

def _sync_activation_observers(self, model):
    """
    All-reduce activation observer min/max values across DDP ranks,
    then recompute scale/zp from the global statistics.
    No-op if not distributed.

    :param model: model whose targeted modules carry activation observers
    """
    if not is_distributed():
        return

    # (observer attribute, reduction) pairs, replacing four copy-pasted
    # hasattr/None-check stanzas: min statistics shrink under all-reduce-min,
    # max statistics grow under all-reduce-max. The past_global_* variants
    # are only present for the TENSOR_GROUP strategy.
    reductions = (
        ("past_min_vals", all_reduce_min),
        ("past_max_vals", all_reduce_max),
        ("past_global_min_vals", all_reduce_min),
        ("past_global_max_vals", all_reduce_max),
    )

    for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
        for base_name in ("input", "output", "q", "k", "v"):
            observer = getattr(module, f"{base_name}_observer", None)
            if observer is None:
                continue

            # all-reduce each accumulated statistic that exists and is set
            for attr, reduce_fn in reductions:
                vals = getattr(observer, attr, None)
                if vals is not None:
                    setattr(observer, attr, reduce_fn(vals))

            # derive scale/zero_point from the synchronized statistics
            recompute_qparams_from_observer(module, base_name)

def on_end(self, state: State, event: Event, **kwargs):
"""
Finish calibrating by removing observers and calibration hooks
Expand Down
1 change: 1 addition & 0 deletions src/llmcompressor/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .dev import *
from .helpers import *
from .dist import *
from .distributed import *
Loading
Loading