Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions examples/quantization_w8a8_int8/smoothquant_ddp_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""
Distributed SmoothQuant + GPTQ W8A8 quantization using Data-Parallel Calibration.

Run with:
torchrun --standalone --nproc_per_node=NUM_GPUS smoothquant_ddp_example.py

Each rank loads a disjoint partition of the calibration dataset.
SmoothQuantModifier all-reduces per-channel activation statistics across ranks
before computing smoothing scales (identical on every rank, no weight broadcast
needed). GPTQModifier then applies distributed W8A8 quantization.

This script intentionally mirrors the structure of
examples/quantization_w4a16/llama3_ddp_example.py so it is easy to diff.
"""

import time

import torch
import torch.distributed as dist
from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model
from datasets import load_dataset
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.datasets.utils import get_rank_partition
from llmcompressor.modifiers.gptq import GPTQModifier
from llmcompressor.modifiers.transform.smoothquant import SmoothQuantModifier

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # HF hub id of the model to quantize
DATASET_ID = "HuggingFaceH4/ultrachat_200k"  # calibration dataset on the HF hub
DATASET_SPLIT = "train_sft"  # dataset split used for calibration
# Calibration sample budget; passed to get_rank_partition, which carves the
# split into per-rank partitions (see module docstring).
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048  # truncation length for tokenized calibration samples

# ---------------------------------------------------------------------------
# Distributed init + offloaded model load
# ---------------------------------------------------------------------------
init_dist()

# Load with CPU offload so each rank only materializes weights on demand.
with load_offloaded_model():
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype="auto", device_map="auto_offload"
    )

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# ---------------------------------------------------------------------------
# Dataset: each rank loads its own disjoint slice of the calibration split
# ---------------------------------------------------------------------------
rank_split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
ds = load_dataset(DATASET_ID, split=rank_split).shuffle(seed=42)


def preprocess(example):
    """Render one sample's chat messages into a single untokenized string."""
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


def tokenize(sample):
    """Tokenize rendered text; truncate to MAX_SEQUENCE_LENGTH, no padding,
    no special tokens (the chat template already supplies them)."""
    return tokenizer(
        sample["text"],
        add_special_tokens=False,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding=False,
    )


# Render chats to text, then tokenize, dropping the intermediate text columns.
rendered_ds = ds.map(preprocess)
ds = rendered_ds.map(tokenize, remove_columns=rendered_ds.column_names)

# ---------------------------------------------------------------------------
# Recipe: SmoothQuant (distributed-aware) followed by GPTQ W8A8
# ---------------------------------------------------------------------------
smoothquant = SmoothQuantModifier(smoothing_strength=0.8)
gptq = GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"])
recipe = [smoothquant, gptq]

# ---------------------------------------------------------------------------
# Run oneshot calibration + quantization, timing it and tracking peak memory
# ---------------------------------------------------------------------------
torch.cuda.reset_peak_memory_stats()
started = time.time()

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

elapsed = time.time() - started
peak_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
rank = dist.get_rank()

logger.info(
    f"[Rank {rank}] Done in {elapsed:.1f}s | Peak GPU mem: {peak_mem_gb:.2f} GB"
)

# ---------------------------------------------------------------------------
# Sample generation (rank 0 only)
# ---------------------------------------------------------------------------
if rank == 0:
    logger.info("\n========== SAMPLE GENERATION ==========")
    # Move the (possibly offloaded) weights onto devices before generating.
    dispatch_model(model)
    inputs = tokenizer("Hello my name is", return_tensors="pt")
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    generated = model.generate(**inputs, max_new_tokens=50)
    logger.info(tokenizer.decode(generated[0]))
    logger.info("========================================\n")

# ---------------------------------------------------------------------------
# Save — every rank calls save_pretrained; per the original note it is
# dist-aware internally, so only one copy is written.
# ---------------------------------------------------------------------------
base_name = MODEL_ID.rstrip("/").split("/")[-1]
SAVE_DIR = f"{base_name}-W8A8-SmoothQuant-DDP{dist.get_world_size()}"

model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
logger.info(f"[Rank {rank}] Saved to {SAVE_DIR}")

dist.destroy_process_group()
44 changes: 44 additions & 0 deletions src/llmcompressor/modifiers/transform/smoothquant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Callable

import torch
import torch.distributed as dist
from compressed_tensors.offload import update_offload_parameter
from compressed_tensors.offload.dist_utils import is_distributed
from compressed_tensors.utils import match_modules_set, match_named_modules
from loguru import logger
from pydantic import ConfigDict, Field
Expand All @@ -15,6 +17,7 @@
get_layer_mappings_from_architecture,
handle_mapping_resolution_errors,
)
from llmcompressor.utils.dist import wait_for_comms
from llmcompressor.utils.pytorch.module import get_module_to_name_dict

MINIMUM_SMOOTHING_SCALE = 1e-5
Expand Down Expand Up @@ -290,6 +293,40 @@ def hook_fn(module, inp, out):
layer = mapping.smooth_layer
self.register_hook(layer, create_hook_fn(name), "forward")

def _reduce_activation_scales(self):
    """
    All-reduce the per-channel min/max activation statistics collected by
    each rank during calibration.

    Each rank observes a disjoint partition of the calibration dataset, so the
    global channel-wise min/max must be combined across ranks before smoothing
    scales are computed. ``dist.all_reduce`` with MIN/MAX ops leaves every rank
    with identical statistics, so each rank can independently compute the same
    smoothing scales without a broadcast of the final scale tensors.

    All reduces are issued asynchronously and awaited together, overlapping
    the communication for every layer's statistics.

    This is a no-op when not running in a distributed context.
    """
    if not is_distributed():
        return

    # Only the scale objects are needed; the layer names are irrelevant here.
    pending_comms = []
    for scale in self.scales_.values():
        pending_comms.append(
            dist.all_reduce(
                scale.min_channel_vals, op=dist.ReduceOp.MIN, async_op=True
            )
        )
        pending_comms.append(
            dist.all_reduce(
                scale.max_channel_vals, op=dist.ReduceOp.MAX, async_op=True
            )
        )
    wait_for_comms(pending_comms)

@torch.no_grad()
def _apply_smoothing(self, model: Module):
"""
Expand All @@ -299,8 +336,15 @@ def _apply_smoothing(self, model: Module):
Y = (Xdiag(scales)^(-1) * diag(scales)W) where W is the to_balance weights and
X is the to_smooth weights

In a distributed setting, activation statistics are first all-reduced across
ranks so that every rank computes identical smoothing scales. The scale
computation itself is duplicated across ranks (cheap), avoiding the need for
a broadcast of the final weight tensors.

This modifies the weights of the model in-place.
"""
self._reduce_activation_scales()

for mapping in self.resolved_mappings_:
if mapping.smooth_name not in self.scales_:
continue
Expand Down
Loading
Loading