Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
f60200a
[Distributed] Add distributed utilities for DDP calibration
EtelisIBM Feb 22, 2026
c4d630d
[Distributed] Add recompute_qparams_from_observer helper
EtelisIBM Feb 22, 2026
89d1ade
[Distributed] Partition weight calibration across DDP ranks
EtelisIBM Feb 22, 2026
ac0cc2a
[Tests] Add unit tests for distributed utilities
EtelisIBM Feb 22, 2026
76cf40f
[Tests] Add multi-GPU integration tests for DDP quantization
EtelisIBM Feb 22, 2026
0f3e1f9
[Examples] Add distributed W8A8 quantization example
EtelisIBM Feb 22, 2026
9975edc
[Distributed] Fix broadcast_module_parameter for CPU-resident models
EtelisIBM Feb 22, 2026
3320812
[Distributed] Refactor DDP activation sync per review feedback
EtelisIBM Feb 23, 2026
87f4b0d
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis Feb 24, 2026
766a70c
Merge remote-tracking branch 'upstream/main' into feature/quantizatio…
EtelisIBM Mar 19, 2026
d44c4ab
[Distributed] Address review feedback for DDP activation observer sync
EtelisIBM Mar 19, 2026
5fa31b2
Merge branch 'main' into feature/quantization-modifier-ddp
HDCharles Mar 19, 2026
0e3a843
[Distributed] Use as_broadcastable and simplify moving-average sync
EtelisIBM Mar 20, 2026
82d808c
Merge branch 'feature/quantization-modifier-ddp' of https://github.co…
EtelisIBM Mar 20, 2026
d6b3575
Update src/llmcompressor/observers/moving_base.py
Etelis Mar 23, 2026
9680d9e
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis Mar 23, 2026
688b309
Merge branch 'main' into feature/quantization-modifier-ddp
kylesayrs Mar 24, 2026
7f31744
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis Mar 24, 2026
4f80617
Merge branch 'main' into feature/quantization-modifier-ddp
Etelis Mar 24, 2026
f959d4b
fix formatting and moving-average test mock path
EtelisIBM Mar 25, 2026
7baf545
Merge branch 'feature/quantization-modifier-ddp' of https://github.co…
EtelisIBM Mar 25, 2026
09e817b
Merge branch 'main' into feature/quantization-modifier-ddp
HDCharles Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#############################################################################
# Distributed W8A8 quantization example with activation observer sync.
# run this with `torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py`
# or change nproc_per_node to your desired configuration
#############################################################################

import torch
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.datasets.utils import get_rank_partition
from llmcompressor.modifiers.quantization import QuantizationModifier

# Model to quantize.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Calibration dataset (chat-format samples).
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# NOTE(review): presumably the TOTAL sample count across all ranks, with
# get_rank_partition slicing this rank's share — confirm against its docs.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048

###### DDP MODEL LOAD CHANGE #####
# Initialize torch.distributed, then load the model with offloaded weights
# so each rank avoids materializing the full model on-device up front.
init_dist()
with load_offloaded_model():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype="auto", device_map="auto_offload"
    )
##################################

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

###### DDP DATA LOAD CHANGE #####
# Each rank loads only its own partition of the calibration split.
ds = load_dataset(
    DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
)
##################################

# Fixed seed so every rank shuffles its partition deterministically.
ds = ds.shuffle(seed=42)

def preprocess(example):
    """Render the sample's chat messages into a single text string.

    Applies the tokenizer's chat template without tokenizing, producing a
    ``text`` field for the tokenization step that follows.
    """
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


ds = ds.map(preprocess)


def tokenize(sample):
    """Tokenize the rendered chat text, truncating to MAX_SEQUENCE_LENGTH.

    Special tokens are omitted because the chat template already inserted
    them, and padding is deferred to the calibration dataloader.
    """
    encode_kwargs = dict(
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )
    return tokenizer(sample["text"], **encode_kwargs)


# Drop the raw columns; only the tokenizer outputs are needed downstream.
ds = ds.map(tokenize, remove_columns=ds.column_names)

# QuantizationModifier automatically detects torch.distributed and
# all-reduces activation observer statistics at layer boundaries
recipe = [
    QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Run calibration. Each rank consumes only its dataset partition; observer
# sync inside the modifier keeps qparams identical across ranks.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

print("Saving...")
# e.g. "Meta-Llama-3-8B-Instruct-W8A8-DDP2" for a two-rank run.
# NOTE(review): every rank saves to the same directory concurrently —
# confirm save_pretrained is rank-aware here, or guard with rank 0.
SAVE_DIR = (
    MODEL_ID.rstrip("/").split("/")[-1]
    + "-W8A8-DDP"
    + str(torch.distributed.get_world_size())
)
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

torch.distributed.destroy_process_group()
2 changes: 2 additions & 0 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,12 @@ def on_event(self, state: State, event: Event, **kwargs):

elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
# Run smoothing in case of sequential pipeline
QuantizationMixin.sync_activation_observers(self, state.model)
self._apply_smoothing(state.model)

elif event.type_ == EventType.CALIBRATION_EPOCH_END:
# Run smoothing in case of basic pipeline
QuantizationMixin.sync_activation_observers(self, state.model)
self._apply_smoothing(state.model)

if not self.ended_:
Expand Down
2 changes: 2 additions & 0 deletions src/llmcompressor/modifiers/gptq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,11 @@ def on_event(self, state: State, event: Event, **kwargs):
self.on_start(state, None)

if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
QuantizationMixin.sync_activation_observers(self, state.model)
self.compress_modules()

if event.type_ == EventType.CALIBRATION_EPOCH_END:
QuantizationMixin.sync_activation_observers(self, state.model)
self.compress_modules()

if not self.ended_:
Expand Down
11 changes: 10 additions & 1 deletion src/llmcompressor/modifiers/quantization/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class QuantizationModifier(Modifier, QuantizationMixin):
the specified module(s) forward pass will emulate quantized execution and the
modifier will be enabled until training is completed.

In DDP mode, activation observer statistics are all-reduced across ranks at
sequential layer boundaries so all ranks share identical quantization parameters.

:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param targets: list of layer names to quantize if a scheme is provided. Defaults
Expand Down Expand Up @@ -65,14 +68,16 @@ def on_initialize(self, state: State, **kwargs) -> bool:

def on_start(self, state: State, event: Event, **kwargs):
"""
Begin calibrating activations and weights. Calibrate weights only once on start
Begin calibrating activations and weights. Calibrate weights only once
on start. Each rank calibrates weights independently.
"""
self.started_ = True
QuantizationMixin.start_calibration(self, state.model)

named_modules = list(
match_named_modules(state.model, self.resolved_targets, self.ignore)
)

# TODO: this step can be combined with update_weight_zp_scale
# once update_fused_layer_weight_global_scales is removed
# and not required by vLLM
Expand All @@ -95,7 +100,11 @@ def on_event(self, state: State, event: Event, **kwargs):
if not self.started_:
self.on_start(state, None)

if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
QuantizationMixin.sync_activation_observers(self, state.model)

if event.type_ == EventType.CALIBRATION_EPOCH_END:
QuantizationMixin.sync_activation_observers(self, state.model)
if not self.ended_:
self.on_end(state, None)

Expand Down
56 changes: 54 additions & 2 deletions src/llmcompressor/modifiers/quantization/quantization/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
IMPL_ATTR,
KV_CACHE_ATTR,
)
from compressed_tensors.offload.dist_utils import is_distributed
from compressed_tensors.quantization import (
DynamicType,
QuantizationArgs,
Expand All @@ -18,7 +19,7 @@
is_preset_scheme,
preset_name_to_scheme,
)
from compressed_tensors.utils import match_named_modules
from compressed_tensors.utils import match_named_modules, update_offload_parameter
from pydantic import Field, PrivateAttr, field_validator
from torch.utils.hooks import RemovableHandle

Expand All @@ -37,7 +38,11 @@
validate_group_size_divisibility,
)
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.utils import targets_embeddings, untie_word_embeddings
from llmcompressor.utils import (
targets_embeddings,
untie_word_embeddings,
wait_for_comms,
)

__all__ = ["QuantizationMixin"]

Expand Down Expand Up @@ -257,6 +262,53 @@ def end_calibration(self, model: torch.nn.Module):

model.apply(enable_quantization) # keep quantization enabled

def sync_activation_observers(self, model: torch.nn.Module):
    """
    All-reduce activation observer min/max values across DDP ranks,
    then recompute scale/zp from the global statistics. No-op when
    not distributed.

    :param model: model containing quantized modules
    """
    if not is_distributed():
        return

    pending_comms = []
    modules_to_update = []

    # Issue all async all-reduces first so collectives can overlap,
    # then wait on every handle at once below.
    for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
        for base_name in ("input", "output", "q", "k", "v"):
            observer = getattr(module, f"{base_name}_observer", None)
            if observer is None:
                continue
            pending_comms.extend(observer.synchronize())
            modules_to_update.append((module, base_name, observer))

    wait_for_comms(pending_comms)

    # finalize averaging for moving-average observers
    for _, _, observer in modules_to_update:
        if hasattr(observer, "finalize_synchronize"):
            observer.finalize_synchronize()

    # recompute qparams from synchronized statistics
    for module, base_name, observer in modules_to_update:
        # recompute global scale if using TENSOR_GROUP strategy
        global_scale = observer.recompute_global_scale()
        if global_scale is not None:
            update_offload_parameter(
                module, f"{base_name}_global_scale", global_scale
            )

        result = observer.recompute_qparams()
        if result is not None:
            scale, zero_point = result
            update_offload_parameter(module, f"{base_name}_scale", scale)
            # some schemes (e.g. symmetric) register no zero_point parameter
            if hasattr(module, f"{base_name}_zero_point"):
                update_offload_parameter(
                    module, f"{base_name}_zero_point", zero_point
                )

def has_config(self) -> bool:
"""
Determine if the user has specified a quantization config on this modifier
Expand Down
94 changes: 93 additions & 1 deletion src/llmcompressor/observers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from weakref import ref

import torch
Expand All @@ -8,6 +8,7 @@
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.utils import align_module_device
from torch import distributed as dist

from llmcompressor.observers.helpers import flatten_for_calibration

Expand Down Expand Up @@ -133,6 +134,63 @@ def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]:
with align_module_device(module):
return getattr(module, f"{self.base_name}_{name}", None)

def synchronize(self) -> List[dist.Work]:
    """All-reduce accumulated min/max statistics across DDP ranks.

    Issues async all-reduce operations on any accumulated state
    (``past_min_vals``, ``past_max_vals``, ``past_global_min_vals``,
    ``past_global_max_vals``). Memoryless observers return an empty list.

    :return: list of async communication handles
    """
    # Min statistics reduce with MIN, max statistics with MAX, so the
    # resulting ranges cover all ranks' observations.
    reductions = (
        ("past_min_vals", dist.ReduceOp.MIN),
        ("past_max_vals", dist.ReduceOp.MAX),
        ("past_global_min_vals", dist.ReduceOp.MIN),
        ("past_global_max_vals", dist.ReduceOp.MAX),
    )

    handles: List[dist.Work] = []
    for attr_name, reduce_op in reductions:
        stats = getattr(self, attr_name, None)
        if stats is None:
            # attribute absent or never populated: nothing to reduce
            continue
        handles.extend(_all_reduce_fp8_safe(stats, op=reduce_op))
    return handles

def recompute_global_scale(self) -> Optional[torch.Tensor]:
    """Recompute global scale from accumulated global min/max state.

    Used after :meth:`synchronize` to update the global scale from
    globally reduced statistics. Returns ``None`` for memoryless observers.

    :return: global scale tensor or ``None``
    """
    lo = getattr(self, "past_global_min_vals", None)
    hi = getattr(self, "past_global_max_vals", None)
    # memoryless observers carry no accumulated global statistics
    if lo is None or hi is None:
        return None
    return generate_gparam(lo, hi)

def recompute_qparams(self) -> Optional["ScaleZpTuple"]:
    """Recompute scale and zero_point from accumulated min/max state.

    Used after :meth:`synchronize` to update quantization parameters from
    globally reduced statistics. Returns ``None`` for memoryless observers.

    :return: (scale, zero_point) tuple or ``None``
    """
    lo = getattr(self, "past_min_vals", None)
    hi = getattr(self, "past_max_vals", None)
    # memoryless observers carry no accumulated statistics
    if lo is None or hi is None:
        return None

    # TENSOR_GROUP strategies require a global scale computed beforehand
    global_scale = self._get_module_param("global_scale")
    self._check_has_global_scale(global_scale)
    return calculate_qparams(
        min_vals=lo,
        max_vals=hi,
        quantization_args=self.args,
        global_scale=global_scale,
    )

def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
if (
self.args.strategy == QuantizationStrategy.TENSOR_GROUP
Expand All @@ -142,3 +200,37 @@ def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
"Cannot compute scale and zero points "
"without first computing global scale"
)


# FP8 dtypes are not supported by NCCL collective operations.
# Upcast to float32, perform the reduction, then cast back.
_FP8_DTYPES = {
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
}


def _all_reduce_fp8_safe(
tensor: torch.Tensor,
op: dist.ReduceOp,
) -> List[dist.Work]:
"""Issue an all-reduce, upcasting FP8 tensors to float32 first.

Returns a list of async work handles. For FP8 tensors the reduction is
performed synchronously (upcast -> reduce -> downcast) so the returned
list is empty.

:param tensor: tensor to reduce **in-place**
:param op: reduction operation
:return: list of async communication handles
"""
if tensor.dtype in _FP8_DTYPES:
orig_dtype = tensor.dtype
fp32 = tensor.float()
dist.all_reduce(fp32, op=op)
tensor.copy_(fp32.to(orig_dtype))
return []

return [dist.all_reduce(tensor, op=op, async_op=True)]
Loading