diff --git a/examples/big_models_with_sequential_onloading/llama3_8b_w8a8_distributed.py b/examples/big_models_with_sequential_onloading/llama3_8b_w8a8_distributed.py new file mode 100644 index 0000000000..10d52f88f9 --- /dev/null +++ b/examples/big_models_with_sequential_onloading/llama3_8b_w8a8_distributed.py @@ -0,0 +1,100 @@ +############################################################################# +# Distributed W8A8 quantization example with activation observer sync. +# run this with `torchrun --nproc_per_node=2 llama3_8b_w8a8_distributed.py` +# or change nproc_per_node to your desired configuration +############################################################################# + +import torch +from compressed_tensors.offload import dispatch_model, init_dist, load_offloaded_model +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.datasets.utils import get_rank_partition +from llmcompressor.modifiers.quantization import QuantizationModifier + +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +NUM_CALIBRATION_SAMPLES = 256 +MAX_SEQUENCE_LENGTH = 2048 + +###### DDP MODEL LOAD CHANGE ##### +init_dist() +with load_offloaded_model(): + model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, dtype="auto", device_map="auto_offload" + ) +################################## + +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +###### DDP DATA LOAD CHANGE ##### +ds = load_dataset( + DATASET_ID, split=get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES) +) +################################## + +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + 
max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# QuantizationModifier automatically detects torch.distributed and +# all-reduces activation observer statistics at layer boundaries +recipe = [ + QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]), +] + +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_model(model) +sample = tokenizer("Hello my name is", return_tensors="pt") +sample = {key: value.to(model.device) for key, value in sample.items()} +output = model.generate(**sample, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +print("Saving...") +SAVE_DIR = ( + MODEL_ID.rstrip("/").split("/")[-1] + + "-W8A8-DDP" + + str(torch.distributed.get_world_size()) +) +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) + +torch.distributed.destroy_process_group() diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index ef1e357f0f..a1e3236c21 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -264,10 +264,12 @@ def on_event(self, state: State, event: Event, **kwargs): elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: # Run smoothing in case of sequential pipeline + QuantizationMixin.sync_activation_observers(self, state.model) self._apply_smoothing(state.model) elif event.type_ == EventType.CALIBRATION_EPOCH_END: # Run smoothing in case of basic pipeline + QuantizationMixin.sync_activation_observers(self, state.model) self._apply_smoothing(state.model) if not self.ended_: diff --git a/src/llmcompressor/modifiers/gptq/base.py 
b/src/llmcompressor/modifiers/gptq/base.py index 384ef355ad..328bb77f4c 100644 --- a/src/llmcompressor/modifiers/gptq/base.py +++ b/src/llmcompressor/modifiers/gptq/base.py @@ -224,9 +224,11 @@ def on_event(self, state: State, event: Event, **kwargs): self.on_start(state, None) if event.type_ == EventType.SEQUENTIAL_EPOCH_END: + QuantizationMixin.sync_activation_observers(self, state.model) self.compress_modules() if event.type_ == EventType.CALIBRATION_EPOCH_END: + QuantizationMixin.sync_activation_observers(self, state.model) self.compress_modules() if not self.ended_: diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index ffd64377f7..30abd8337d 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -20,6 +20,9 @@ class QuantizationModifier(Modifier, QuantizationMixin): the specified module(s) forward pass will emulate quantized execution and the modifier will be enabled until training is completed. + In DDP mode, activation observer statistics are all-reduced across ranks at + sequential layer boundaries so all ranks share identical quantization parameters. + :param config_groups: dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. :param targets: list of layer names to quantize if a scheme is provided. Defaults @@ -65,7 +68,8 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): """ - Begin calibrating activations and weights. Calibrate weights only once on start + Begin calibrating activations and weights. Calibrate weights only once + on start. Each rank calibrates weights independently. 
""" self.started_ = True QuantizationMixin.start_calibration(self, state.model) @@ -73,6 +77,7 @@ def on_start(self, state: State, event: Event, **kwargs): named_modules = list( match_named_modules(state.model, self.resolved_targets, self.ignore) ) + # TODO: this step can be combined with update_weight_zp_scale # once update_fused_layer_weight_global_scales is removed # and not required by vLLM @@ -95,7 +100,11 @@ def on_event(self, state: State, event: Event, **kwargs): if not self.started_: self.on_start(state, None) + if event.type_ == EventType.SEQUENTIAL_EPOCH_END: + QuantizationMixin.sync_activation_observers(self, state.model) + if event.type_ == EventType.CALIBRATION_EPOCH_END: + QuantizationMixin.sync_activation_observers(self, state.model) if not self.ended_: self.on_end(state, None) diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 08f9d75842..936c47db76 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -5,6 +5,7 @@ IMPL_ATTR, KV_CACHE_ATTR, ) +from compressed_tensors.offload.dist_utils import is_distributed from compressed_tensors.quantization import ( DynamicType, QuantizationArgs, @@ -18,7 +19,7 @@ is_preset_scheme, preset_name_to_scheme, ) -from compressed_tensors.utils import match_named_modules +from compressed_tensors.utils import match_named_modules, update_offload_parameter from pydantic import Field, PrivateAttr, field_validator from torch.utils.hooks import RemovableHandle @@ -37,7 +38,11 @@ validate_group_size_divisibility, ) from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.utils import targets_embeddings, untie_word_embeddings +from llmcompressor.utils import ( + targets_embeddings, + untie_word_embeddings, + wait_for_comms, +) __all__ = ["QuantizationMixin"] @@ -257,6 +262,48 @@ def end_calibration(self, model: 
torch.nn.Module): model.apply(enable_quantization) # keep quantization enabled + def sync_activation_observers(self, model: torch.nn.Module): + """ + All-reduce activation observer min/max values across DDP ranks, + then recompute scale/zp from the global statistics. No-op when + not distributed. + + :param model: model containing quantized modules + """ + if not is_distributed(): + return + + pending_comms = [] + modules_to_update = [] + + for _, module in match_named_modules(model, self.resolved_targets, self.ignore): + for base_name in ("input", "output", "q", "k", "v"): + observer = getattr(module, f"{base_name}_observer", None) + if observer is None: + continue + pending_comms.extend(observer.synchronize()) + modules_to_update.append((module, base_name, observer)) + + wait_for_comms(pending_comms) + + # recompute qparams from synchronized statistics + for module, base_name, observer in modules_to_update: + # recompute global scale if using TENSOR_GROUP strategy + global_scale = observer.recompute_global_scale() + if global_scale is not None: + update_offload_parameter( + module, f"{base_name}_global_scale", global_scale + ) + + result = observer.recompute_qparams() + if result is not None: + scale, zero_point = result + update_offload_parameter(module, f"{base_name}_scale", scale) + if hasattr(module, f"{base_name}_zero_point"): + update_offload_parameter( + module, f"{base_name}_zero_point", zero_point + ) + def has_config(self) -> bool: """ Determine if the user has specified a quantization config on this modifier diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 384bbf6ead..eb687803dd 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -1,13 +1,15 @@ from abc import abstractmethod -from typing import Optional, Tuple +from typing import List, Optional, Tuple from weakref import ref import torch from compressed_tensors import InternalModule +from 
compressed_tensors.offload.dist_utils import as_broadcastable from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin from compressed_tensors.utils import align_module_device +from torch import distributed as dist from llmcompressor.observers.helpers import flatten_for_calibration @@ -133,6 +135,65 @@ def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: with align_module_device(module): return getattr(module, f"{self.base_name}_{name}", None) + def synchronize(self) -> List[dist.Work]: + """All-reduce accumulated min/max statistics across DDP ranks. + + Issues async all-reduce operations on any accumulated state + (``past_min_vals``, ``past_max_vals``, ``past_global_min_vals``, + ``past_global_max_vals``). Memoryless observers return an empty list. + + :return: list of async communication handles + """ + comms = [] + for attr, op in [ + ("past_min_vals", dist.ReduceOp.MIN), + ("past_max_vals", dist.ReduceOp.MAX), + ("past_global_min_vals", dist.ReduceOp.MIN), + ("past_global_max_vals", dist.ReduceOp.MAX), + ]: + val = getattr(self, attr, None) + if val is not None: + comms.append( + dist.all_reduce(as_broadcastable(val), op=op, async_op=True) + ) + return comms + + def recompute_global_scale(self) -> Optional[torch.Tensor]: + """Recompute global scale from accumulated global min/max state. + + Used after :meth:`synchronize` to update the global scale from + globally reduced statistics. Returns ``None`` for memoryless observers. 
+ + :return: global scale tensor or ``None`` + """ + global_min = getattr(self, "past_global_min_vals", None) + global_max = getattr(self, "past_global_max_vals", None) + if global_min is None or global_max is None: + return None + return generate_gparam(global_min, global_max) + + def recompute_qparams(self) -> Optional[ScaleZpTuple]: + """Recompute scale and zero_point from accumulated min/max state. + + Used after :meth:`synchronize` to update quantization parameters from + globally reduced statistics. Returns ``None`` for memoryless observers. + + :return: (scale, zero_point) tuple or ``None`` + """ + min_vals = getattr(self, "past_min_vals", None) + max_vals = getattr(self, "past_max_vals", None) + if min_vals is None or max_vals is None: + return None + + global_scale = self._get_module_param("global_scale") + self._check_has_global_scale(global_scale) + return calculate_qparams( + min_vals=min_vals, + max_vals=max_vals, + quantization_args=self.args, + global_scale=global_scale, + ) + def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]): if ( self.args.strategy == QuantizationStrategy.TENSOR_GROUP diff --git a/src/llmcompressor/observers/moving_base.py b/src/llmcompressor/observers/moving_base.py index f94c474284..b587af33e4 100644 --- a/src/llmcompressor/observers/moving_base.py +++ b/src/llmcompressor/observers/moving_base.py @@ -1,8 +1,10 @@ from abc import abstractmethod -from typing import Optional +from typing import List, Optional import torch +from compressed_tensors.offload.dist_utils import as_broadcastable from compressed_tensors.quantization.quant_args import QuantizationArgs +from torch import distributed as dist from llmcompressor.observers.base import MinMaxTuple, Observer @@ -97,6 +99,33 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: return min_vals, max_vals + def synchronize(self) -> List[dist.Work]: + """Average accumulated moving-average min/max statistics across DDP ranks. 
+ + Unlike :class:`StaticMinMaxObserver` which reduces via MIN/MAX, + moving-average observers divide by world_size first and then SUM + so that the result is the average across ranks. + + :return: list of async communication handles + """ + comms = [] + world_size = dist.get_world_size() + for attr in ( + "past_min_vals", + "past_max_vals", + "past_global_min_vals", + "past_global_max_vals", + ): + val = getattr(self, attr, None) + if val is not None: + val.div_(world_size) + comms.append( + dist.all_reduce( + as_broadcastable(val), op=dist.ReduceOp.SUM, async_op=True + ) + ) + return comms + def _lerp( self, input: torch.Tensor, end: torch.Tensor, weight: float ) -> torch.Tensor: diff --git a/tests/llmcompressor/modifiers/quantization/test_quantization_ddp.py b/tests/llmcompressor/modifiers/quantization/test_quantization_ddp.py new file mode 100644 index 0000000000..fd2f143220 --- /dev/null +++ b/tests/llmcompressor/modifiers/quantization/test_quantization_ddp.py @@ -0,0 +1,91 @@ +""" +Run with: torchrun --nproc_per_node=2 -m pytest -v +""" + +import os + +import pytest +import torch +import torch.distributed as dist +from compressed_tensors.quantization import QuantizationArgs + +from llmcompressor.observers.min_max import StaticMinMaxObserver +from llmcompressor.utils.dist import wait_for_comms +from tests.testing_utils import requires_gpu + +# initialize process group when running under torchrun +if ( + os.environ.get("RANK") is not None + and torch.cuda.is_available() + and not dist.is_initialized() +): + dist.init_process_group(backend="nccl") + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) + + +def _skip_if_not_distributed(): + if not (dist.is_available() and dist.is_initialized()): + pytest.skip("Requires torchrun --nproc_per_node=2") + + +@pytest.mark.multi_gpu +@requires_gpu(2) +def test_observer_synchronize_reduces_min_max(): + _skip_if_not_distributed() + rank = dist.get_rank() + + args = QuantizationArgs(num_bits=8, type="int", 
symmetric=True, strategy="tensor") + observer = StaticMinMaxObserver(base_name="input", args=args) + + # each rank has different local statistics + observer.past_min_vals = ( + torch.tensor([1.0, 3.0], device="cuda") + if rank == 0 + else torch.tensor([2.0, 1.0], device="cuda") + ) + observer.past_max_vals = ( + torch.tensor([10.0, 20.0], device="cuda") + if rank == 0 + else torch.tensor([15.0, 10.0], device="cuda") + ) + + comms = observer.synchronize() + wait_for_comms(comms) + + # after sync, min should be element-wise minimum, max element-wise maximum + assert torch.equal(observer.past_min_vals, torch.tensor([1.0, 1.0], device="cuda")) + assert torch.equal( + observer.past_max_vals, torch.tensor([15.0, 20.0], device="cuda") + ) + + +@pytest.mark.multi_gpu +@requires_gpu(2) +def test_synced_qparams_are_identical_across_ranks(): + _skip_if_not_distributed() + rank = dist.get_rank() + + args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor") + observer = StaticMinMaxObserver(base_name="input", args=args) + + observer.past_min_vals = ( + torch.tensor([-2.0], device="cuda") + if rank == 0 + else torch.tensor([-5.0], device="cuda") + ) + observer.past_max_vals = ( + torch.tensor([3.0], device="cuda") + if rank == 0 + else torch.tensor([1.0], device="cuda") + ) + + comms = observer.synchronize() + wait_for_comms(comms) + + result = observer.recompute_qparams() + assert result is not None + scale, _ = result + + gathered = [torch.zeros_like(scale) for _ in range(dist.get_world_size())] + dist.all_gather(gathered, scale) + assert torch.equal(gathered[0], gathered[1]) diff --git a/tests/llmcompressor/utils/test_distributed.py b/tests/llmcompressor/utils/test_distributed.py new file mode 100644 index 0000000000..c2efd03b21 --- /dev/null +++ b/tests/llmcompressor/utils/test_distributed.py @@ -0,0 +1,125 @@ +from unittest.mock import MagicMock, patch + +import pytest +import torch +from compressed_tensors.quantization import QuantizationArgs + 
+from llmcompressor.observers.min_max import ( + MemorylessMinMaxObserver, + MinMaxObserver, + StaticMinMaxObserver, +) + + +def _make_observer(cls, **kwargs): + args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor") + return cls(base_name="input", args=args, **kwargs) + + +@pytest.mark.unit +def test_memoryless_synchronize_returns_empty(): + observer = _make_observer(MemorylessMinMaxObserver) + assert observer.synchronize() == [] + + +@pytest.mark.unit +def test_memoryless_recompute_returns_none(): + observer = _make_observer(MemorylessMinMaxObserver) + assert observer.recompute_qparams() is None + assert observer.recompute_global_scale() is None + + +@pytest.mark.unit +def test_static_synchronize_returns_empty_before_observation(): + observer = _make_observer(StaticMinMaxObserver) + assert observer.synchronize() == [] + + +@pytest.mark.unit +@patch("llmcompressor.observers.base.dist") +def test_static_synchronize_issues_all_reduce(mock_dist): + mock_dist.ReduceOp.MIN = "MIN" + mock_dist.ReduceOp.MAX = "MAX" + mock_dist.all_reduce.return_value = MagicMock() + + observer = _make_observer(StaticMinMaxObserver) + observer.past_min_vals = torch.tensor([-1.0]) + observer.past_max_vals = torch.tensor([1.0]) + + comms = observer.synchronize() + assert len(comms) == 2 + assert mock_dist.all_reduce.call_count == 2 + + # verify correct ops + calls = mock_dist.all_reduce.call_args_list + assert calls[0].kwargs["op"] == "MIN" + assert calls[1].kwargs["op"] == "MAX" + + +@pytest.mark.unit +@patch("llmcompressor.observers.base.dist") +def test_static_synchronize_with_global_state(mock_dist): + mock_dist.ReduceOp.MIN = "MIN" + mock_dist.ReduceOp.MAX = "MAX" + mock_dist.all_reduce.return_value = MagicMock() + + observer = _make_observer(StaticMinMaxObserver) + observer.past_min_vals = torch.tensor([-1.0]) + observer.past_max_vals = torch.tensor([1.0]) + observer.past_global_min_vals = torch.tensor([-2.0]) + observer.past_global_max_vals = 
torch.tensor([2.0]) + + comms = observer.synchronize() + assert len(comms) == 4 + assert mock_dist.all_reduce.call_count == 4 + + +@pytest.mark.unit +@patch("llmcompressor.observers.moving_base.dist") +def test_moving_avg_synchronize_issues_all_reduce(mock_dist): + mock_dist.ReduceOp.SUM = "SUM" + mock_dist.get_world_size.return_value = 2 + mock_dist.all_reduce.return_value = MagicMock() + + observer = _make_observer(MinMaxObserver) + observer.past_min_vals = torch.tensor([-1.0]) + observer.past_max_vals = torch.tensor([1.0]) + + comms = observer.synchronize() + assert len(comms) == 2 + + +@pytest.mark.unit +def test_recompute_qparams_from_accumulated_state(): + observer = _make_observer(StaticMinMaxObserver) + observer.past_min_vals = torch.tensor([-5.0]) + observer.past_max_vals = torch.tensor([5.0]) + + result = observer.recompute_qparams() + assert result is not None + scale, zero_point = result + assert scale.numel() > 0 + assert zero_point.numel() > 0 + + +@pytest.mark.unit +def test_recompute_qparams_returns_none_without_state(): + observer = _make_observer(StaticMinMaxObserver) + assert observer.recompute_qparams() is None + + +@pytest.mark.unit +def test_recompute_global_scale_returns_none_without_state(): + observer = _make_observer(StaticMinMaxObserver) + assert observer.recompute_global_scale() is None + + +@pytest.mark.unit +def test_recompute_global_scale_from_accumulated_state(): + observer = _make_observer(StaticMinMaxObserver) + observer.past_global_min_vals = torch.tensor([-10.0]) + observer.past_global_max_vals = torch.tensor([10.0]) + + result = observer.recompute_global_scale() + assert result is not None + assert result.numel() > 0