diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py index 5e37ad3a40..e2dca74c37 100644 --- a/tests/unit_tests/test_model_converter.py +++ b/tests/unit_tests/test_model_converter.py @@ -11,6 +11,7 @@ from torchtitan.components.lora import LoRAConverter from torchtitan.components.quantization.float8 import Float8LinearConverter +from torchtitan.components.quantization.qat import QATConverter from torchtitan.config import ConfigManager from torchtitan.distributed import ParallelDims from torchtitan.protocols.model_converter import ModelConvertersContainer @@ -202,3 +203,129 @@ def test_lora_key_remap_roundtrip(): assert set(rt_sd.keys()) == set(tt_sd.keys()) for k in tt_sd: assert torch.equal(rt_sd[k], tt_sd[k]) + + +@pytest.mark.parametrize( + "scheme, group_size, expected_linear_cls", + [ + ("int4_weight_only", 64, "FakeQuantizedLinear"), + ("intx_weight_only", 64, "FakeQuantizedLinear"), + ("int8_dynamic_act_intx_weight", 64, "FakeQuantizedLinear"), + ("float8_dynamic_act_float8_weight", None, "FakeQuantizedLinear"), + ("float8_dynamic_act_int4_weight", None, "FakeQuantizedLinear"), + ("nvfp4", None, "NVFP4FakeQuantizedLinear"), + ("mx", None, "MXFakeQuantizedLinear"), + ], +) +def test_qat_all_schemes(scheme, group_size, expected_linear_cls): + """Each QAT scheme should replace nn.Linear with the correct fake-quantized + class and preserve weight dtype (fake quantization happens in forward).""" + pytest.importorskip("torchao") + + model = nn.Sequential( + OrderedDict( + [ + ("fc1", nn.Linear(64, 64)), + ("relu", nn.ReLU()), + ("fc2", nn.Linear(64, 64)), + ] + ) + ) + original_dtypes = {name: param.dtype for name, param in model.named_parameters()} + + config_kwargs = {"scheme": scheme} + if group_size is not None: + config_kwargs["group_size"] = group_size + converter = QATConverter(QATConverter.Config(**config_kwargs)) + converter.convert(model) + + # Linear layers should be replaced with the expected class + assert ( 
+ type(model.fc1).__name__ == expected_linear_cls + ), f"scheme={scheme}: expected {expected_linear_cls}, got {type(model.fc1).__name__}" + assert ( + type(model.fc2).__name__ == expected_linear_cls + ), f"scheme={scheme}: expected {expected_linear_cls}, got {type(model.fc2).__name__}" + + # Weight dtype should be preserved + for name, param in model.named_parameters(): + assert ( + param.dtype == original_dtypes[name] + ), f"'{name}' dtype changed from {original_dtypes[name]} to {param.dtype}" + + +def test_qat_unknown_scheme_raises(): + """QATConverter should raise ValueError for unknown schemes.""" + with pytest.raises(ValueError, match="Unknown QAT scheme"): + QATConverter(QATConverter.Config(scheme="not_a_real_scheme")) + + +def test_qat_group_size_warning_for_unsupported_scheme(caplog): + """QATConverter should warn when group_size is set for a scheme that ignores it.""" + pytest.importorskip("torchao") + import logging + + with caplog.at_level(logging.WARNING): + QATConverter( + QATConverter.Config( + scheme="float8_dynamic_act_float8_weight", group_size=64 + ) + ) + assert "does not use group_size" in caplog.text + + +def test_qat_lora_adapter_qat(): + """QAT + LoRA: base and adapter weights are both fake-quantized. 
+ Also tests that group_size > rank errors out.""" + pytest.importorskip("torchao") + from torchao.quantization.qat.linear import FakeQuantizedLinear + + # --- group_size > rank should error --- + model = nn.Sequential( + OrderedDict( + [ + ("fc1", nn.Linear(128, 128)), + ("relu", nn.ReLU()), + ("fc2", nn.Linear(128, 128)), + ] + ) + ) + qat_converter = QATConverter( + QATConverter.Config(scheme="intx_weight_only", group_size=128) + ) + qat_converter.convert(model) + with pytest.raises(ValueError, match="does not divide LoRA rank"): + LoRAConverter(LoRAConverter.Config(rank=8, alpha=16.0)).convert(model) + + # --- Compatible group_size should work --- + model = nn.Sequential( + OrderedDict( + [ + ("fc1", nn.Linear(64, 64)), + ("relu", nn.ReLU()), + ("fc2", nn.Linear(64, 64)), + ] + ) + ) + qat_converter = QATConverter( + QATConverter.Config(scheme="intx_weight_only", group_size=8) + ) + qat_converter.convert(model) + + assert isinstance(model.fc1, FakeQuantizedLinear) + + lora_converter = LoRAConverter(LoRAConverter.Config(rank=8, alpha=16.0)) + lora_converter.convert(model) + + # Base linears are LoRA-wrapped FakeQuantizedLinear + assert isinstance(model.fc1, FakeQuantizedLinear) + # Adapter linears are also FakeQuantizedLinear + assert isinstance(model.fc1.lora_a, FakeQuantizedLinear) + assert isinstance(model.fc1.lora_b, FakeQuantizedLinear) + + # Forward pass should succeed + x = torch.randn(4, 64) + out = model(x) + assert out.shape == (4, 64) + + diff --git a/torchtitan/components/lora.py b/torchtitan/components/lora.py index 93136b1194..1958a868df 100644 --- a/torchtitan/components/lora.py +++ b/torchtitan/components/lora.py @@ -107,6 +107,7 @@ def __init__(self, config: Config, **kwargs): f"LoRA save_format must be 'dcp', 'peft', or 'merged', " f"got '{self.save_format}'" ) + logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}") @staticmethod @@ -148,6 +149,14 @@ def convert(self, model: nn.Module) -> None: 
model.requires_grad_(False) self._replace_linears_with_lora(model) + # If QATConverter was applied before LoRA, apply the same QAT to + # the newly created adapter linears. QATConverter stores its config + # on the model as _qat_scheme / _qat_group_size. + qat_scheme = getattr(model, "_qat_scheme", None) + if qat_scheme is not None: + qat_group_size = getattr(model, "_qat_group_size", 128) + self._apply_adapter_qat(model, qat_scheme, qat_group_size) + # Wire up checkpoint filtering so ModelWrapper knows which keys # are adapter keys and how to save them. model.converter_key_filter = self._is_lora_key # type: ignore[attr-defined] @@ -160,6 +169,41 @@ def convert(self, model: nn.Module) -> None: if self.save_format == "merged": model.converter_export_sd_fn = self._make_merge_fn() # type: ignore[attr-defined] + def _apply_adapter_qat( + self, model: nn.Module, scheme: str, group_size: int + ) -> None: + from torchtitan.components.quantization.qat import _SCHEMES_WITH_GROUP_SIZE + + # Validate group_size against LoRA rank + if scheme in _SCHEMES_WITH_GROUP_SIZE and self.rank % group_size != 0: + raise ValueError( + f"QAT group_size ({group_size}) does not divide LoRA rank " + f"({self.rank}). Use a smaller group_size or larger rank." 
+ ) + + from torchao.quantization import quantize_ + from torchao.quantization.qat import QATConfig + from torchao.quantization.qat.api import QATStep + + from torchtitan.components.quantization.qat import _build_base_config + + base_config = _build_base_config(scheme, group_size) + + def _is_lora_linear(mod: nn.Module, fqn: str) -> bool: + return isinstance(mod, nn.Linear) and ( + fqn.endswith(".lora_a") or fqn.endswith(".lora_b") + ) + + quantize_( + model, + QATConfig(base_config, step=QATStep.PREPARE), + filter_fn=_is_lora_linear, + ) + logger.info( + f"Applied adapter QAT fake quantization " + f"(scheme={scheme}, group_size={group_size})" + ) + def _replace_linears_with_lora(self, module: nn.Module) -> None: for _, child in list(module.named_modules()): if isinstance(child, nn.Linear): diff --git a/torchtitan/components/quantization/qat.py b/torchtitan/components/quantization/qat.py new file mode 100644 index 0000000000..be52646427 --- /dev/null +++ b/torchtitan/components/quantization/qat.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +import torch.nn as nn +from torchtitan.config import Configurable +from torchtitan.tools.logging import logger + +# Supported scheme names. +_SUPPORTED_SCHEMES = ( + "int4_weight_only", + "intx_weight_only", + "int8_dynamic_act_intx_weight", + "float8_dynamic_act_float8_weight", + "float8_dynamic_act_int4_weight", + "nvfp4", + "mx", +) + +# Schemes that accept a group_size parameter. 
+_SCHEMES_WITH_GROUP_SIZE = ( + "int4_weight_only", + "intx_weight_only", + "int8_dynamic_act_intx_weight", +) + + +def _build_base_config(scheme: str, group_size: int): + """Return a torchao PTQ base config for the given scheme name.""" + if scheme == "int4_weight_only": + from torchao.quantization import Int4WeightOnlyConfig + + return Int4WeightOnlyConfig(group_size=group_size) + + elif scheme == "intx_weight_only": + import torch + from torchao.quantization import IntxWeightOnlyConfig + from torchao.quantization.granularity import PerGroup + + int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute] + return IntxWeightOnlyConfig( + weight_dtype=int4_dtype, + granularity=PerGroup(group_size), + ) + + elif scheme == "int8_dynamic_act_intx_weight": + import torch + from torchao.quantization import Int8DynamicActivationIntxWeightConfig + from torchao.quantization.granularity import PerGroup + + int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute] + return Int8DynamicActivationIntxWeightConfig( + weight_dtype=int4_dtype, + weight_granularity=PerGroup(group_size), + ) + + elif scheme == "float8_dynamic_act_float8_weight": + from torchao.quantization import Float8DynamicActivationFloat8WeightConfig + + return Float8DynamicActivationFloat8WeightConfig() + + elif scheme == "float8_dynamic_act_int4_weight": + from torchao.quantization import Float8DynamicActivationInt4WeightConfig + + return Float8DynamicActivationInt4WeightConfig() + + elif scheme == "nvfp4": + from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig + + return NVFP4DynamicActivationNVFP4WeightConfig() + + elif scheme == "mx": + from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig + + return MXDynamicActivationMXWeightConfig() + + else: + raise ValueError( + f"Unknown QAT scheme '{scheme}'. Supported: {_SUPPORTED_SCHEMES}" + ) + + +class QATConverter(Configurable): + """Apply quantization-aware training via torchao's QATConfig. 
+ + Uses ``torchao.quantize_(model, QATConfig(base_config, step="prepare"))`` + to insert fake quantization into ``nn.Linear`` modules. The ``scheme`` + config field selects a torchao PTQ base config, which QATConfig uses to + infer the appropriate fake quantization for both weights and activations. + + Supported schemes: + - ``"int4_weight_only"`` — int4 weight-only fake quantization + - ``"intx_weight_only"`` — intx weight-only fake quantization + - ``"int8_dynamic_act_intx_weight"`` — int8 activation + int4 weight + - ``"float8_dynamic_act_float8_weight"`` — float8 activation + float8 weight + - ``"float8_dynamic_act_int4_weight"`` — float8 activation + int4 weight + - ``"nvfp4"`` — NVFP4 dynamic activation + NVFP4 weight + - ``"mx"`` — MX dynamic activation + MX weight + + When composed with LoRA (QATConverter listed before LoRAConverter in converters), + LoRA will inherit from FakeQuantizedLinear, and LoRAConverter applies the same QAT + scheme to its lora_a/lora_b adapters, so base and adapter weights are both fake-quantized. + """ + + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + scheme: str = "int4_weight_only" + """QAT scheme name. Maps to a torchao PTQ base config. + Supported: 'int4_weight_only', 'intx_weight_only', + 'int8_dynamic_act_intx_weight', 'float8_dynamic_act_float8_weight', + 'float8_dynamic_act_int4_weight', 'nvfp4', 'mx'.""" + + group_size: int = 128 + """Group size for per-group weight quantization. + Used by schemes that support per-group granularity + (int4_weight_only, intx_weight_only, int8_dynamic_act_intx_weight). + Must divide in_features of all Linear layers in the model.""" + + def __init__(self, config: Config, **kwargs): + if config.scheme not in _SUPPORTED_SCHEMES: + raise ValueError( + f"Unknown QAT scheme '{config.scheme}'. 
" + f"Supported: {_SUPPORTED_SCHEMES}" + ) + self.scheme = config.scheme + self.group_size = config.group_size + if config.scheme not in _SCHEMES_WITH_GROUP_SIZE: + logger.warning( + f"QAT scheme '{config.scheme}' does not use group_size, " + f"ignoring group_size={config.group_size}" + ) + logger.info( + f"QAT training active (scheme={self.scheme}, group_size={self.group_size})" + ) + + def convert(self, model: nn.Module) -> None: + from torchao.quantization import quantize_ + from torchao.quantization.qat import QATConfig + from torchao.quantization.qat.api import QATStep + + base_config = _build_base_config(self.scheme, self.group_size) + quantize_(model, QATConfig(base_config, step=QATStep.PREPARE)) + + # Store QAT config on the model so downstream converters (e.g. LoRA) + # can apply the same QAT to newly created modules. + model._qat_scheme = self.scheme # type: ignore[attr-defined] + model._qat_group_size = self.group_size # type: ignore[attr-defined] + + logger.info( + f"Applied QAT fake quantization (scheme={self.scheme}, " + f"group_size={self.group_size})" + ) + + def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None: + pass diff --git a/torchtitan/models/llama3/config_registry.py b/torchtitan/models/llama3/config_registry.py index 5b06bbeea8..ca4eaf61ed 100644 --- a/torchtitan/models/llama3/config_registry.py +++ b/torchtitan/models/llama3/config_registry.py @@ -13,6 +13,7 @@ OptimizersInBackwardContainer, ) from torchtitan.components.quantization.float8 import Float8LinearConverter +from torchtitan.components.quantization.qat import QATConverter from torchtitan.components.validate import Validator from torchtitan.config import ( ActivationCheckpointConfig, @@ -130,6 +131,59 @@ def llama3_debugmodel_lora() -> Trainer.Config: return config +def llama3_debugmodel_qat() -> Trainer.Config: + config = llama3_debugmodel() + config.model_converters = ModelConvertersContainer.Config( + converters=[ + QATConverter.Config(), + ], + ) + return 
config + + +def llama3_debugmodel_qat_lora() -> Trainer.Config: + config = llama3_debugmodel() + # QATConverter must come before LoRAConverter so that LoRA inherits from + # FakeQuantizedLinear; adapters get the same QAT, so group_size must divide the LoRA rank. + config.model_converters = ModelConvertersContainer.Config( + converters=[ + QATConverter.Config(scheme="intx_weight_only", group_size=8), + LoRAConverter.Config( + rank=8, + alpha=16.0, + ), + ], + ) + return config + + +def llama3_debugmodel_qat_lora_merged() -> Trainer.Config: + """QAT + LoRA with merged save — QAT is applied to both base weights and + LoRA adapters, then adapters are folded into base weights at save time.""" + config = llama3_debugmodel() + config.model_converters = ModelConvertersContainer.Config( + converters=[ + QATConverter.Config(scheme="intx_weight_only", group_size=8), + LoRAConverter.Config( + rank=8, + alpha=16.0, + save_format="merged", + ), + ], + ) + config.checkpoint = CheckpointManager.Config( + enable=True, + interval=5, + last_save_model_only=True, + ) + config.training = TrainingConfig( + local_batch_size=8, + seq_len=2048, + steps=20, + ) + return config + + def llama3_8b() -> Trainer.Config: return Trainer.Config( hf_assets_path="./assets/hf/Llama-3.1-8B", diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py index 6cef857cec..1dc34cef9c 100644 --- a/torchtitan/protocols/model_converter.py +++ b/torchtitan/protocols/model_converter.py @@ -87,20 +87,21 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): def _validate_converter_ordering(converters: list[Configurable.Config]): """Validates that converters are in the correct order. - LoRA must come after quantization because quantization replaces nn.Linear - with specialized subclasses (e.g. Float8Linear), and LoRA dynamically - inherits from whatever linear class it wraps. + LoRA must come after quantization and QAT because both replace nn.Linear + with specialized subclasses (e.g. 
Float8Linear, FakeQuantizedLinear), and + LoRA dynamically inherits from whatever linear class it wraps. """ from torchtitan.components.lora import LoRAConverter + from torchtitan.components.quantization.qat import QATConverter seen_lora = False for config in converters: if isinstance(config, LoRAConverter.Config): seen_lora = True - elif isinstance(config, QuantizationConverter.Config) and seen_lora: + elif isinstance(config, (QuantizationConverter.Config, QATConverter.Config)) and seen_lora: raise ValueError( - "LoRA converter must come after quantization converters. " - "Quantization replaces nn.Linear with specialized subclasses, " + "LoRA converter must come after quantization and QAT converters. " + "Quantization/QAT replaces nn.Linear with specialized subclasses, " "and LoRA must wrap the final linear class." )