Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
901c97a
Add QAT (quantization-aware training) model converter
mori360 Mar 5, 2026
bcc2b2c
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 5, 2026
858b1d4
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 5, 2026
828ee0c
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 5, 2026
0728515
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 6, 2026
8beee82
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 6, 2026
a248b79
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 12, 2026
3ee65be
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 12, 2026
59029b2
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 12, 2026
27c4951
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 12, 2026
4300ee6
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 13, 2026
1a2826f
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 13, 2026
1206b0b
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
c1ad90a
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
afa28f0
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
5e623d1
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
5a39621
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions tests/unit_tests/test_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from torchtitan.components.lora import LoRAConverter
from torchtitan.components.quantization.float8 import Float8LinearConverter
from torchtitan.components.quantization.qat import QATConverter
from torchtitan.config import ConfigManager
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.model_converter import ModelConvertersContainer
Expand Down Expand Up @@ -202,3 +203,129 @@ def test_lora_key_remap_roundtrip():
assert set(rt_sd.keys()) == set(tt_sd.keys())
for k in tt_sd:
assert torch.equal(rt_sd[k], tt_sd[k])


@pytest.mark.parametrize(
    "scheme, group_size, expected_linear_cls",
    [
        ("int4_weight_only", 64, "FakeQuantizedLinear"),
        ("intx_weight_only", 64, "FakeQuantizedLinear"),
        ("int8_dynamic_act_intx_weight", 64, "FakeQuantizedLinear"),
        ("float8_dynamic_act_float8_weight", None, "FakeQuantizedLinear"),
        ("float8_dynamic_act_int4_weight", None, "FakeQuantizedLinear"),
        ("nvfp4", None, "NVFP4FakeQuantizedLinear"),
        ("mx", None, "MXFakeQuantizedLinear"),
    ],
)
def test_qat_all_schemes(scheme, group_size, expected_linear_cls):
    """Every supported QAT scheme swaps nn.Linear for the expected
    fake-quantized class while leaving parameter dtypes untouched
    (the fake quantization only happens during forward)."""
    pytest.importorskip("torchao")

    model = nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(64, 64)),
                ("relu", nn.ReLU()),
                ("fc2", nn.Linear(64, 64)),
            ]
        )
    )
    dtypes_before = {name: param.dtype for name, param in model.named_parameters()}

    # Only pass group_size when the parametrization provides one, so the
    # schemes that ignore it exercise the default config path.
    config_kwargs = {"scheme": scheme}
    if group_size is not None:
        config_kwargs["group_size"] = group_size
    QATConverter(QATConverter.Config(**config_kwargs)).convert(model)

    # Both linear layers must be replaced with the scheme's linear class.
    for layer in (model.fc1, model.fc2):
        got = type(layer).__name__
        assert (
            got == expected_linear_cls
        ), f"scheme={scheme}: expected {expected_linear_cls}, got {got}"

    # Conversion must not change any parameter dtype.
    for name, param in model.named_parameters():
        assert (
            param.dtype == dtypes_before[name]
        ), f"'{name}' dtype changed from {dtypes_before[name]} to {param.dtype}"


def test_qat_unknown_scheme_raises():
    """Constructing a QATConverter with an unrecognized scheme name fails fast."""
    bad_config = QATConverter.Config(scheme="not_a_real_scheme")
    with pytest.raises(ValueError, match="Unknown QAT scheme"):
        QATConverter(bad_config)


def test_qat_group_size_warning_for_unsupported_scheme(caplog):
    """Setting group_size for a scheme that ignores it should log a warning."""
    pytest.importorskip("torchao")
    import logging

    config = QATConverter.Config(
        scheme="float8_dynamic_act_float8_weight", group_size=64
    )
    with caplog.at_level(logging.WARNING):
        QATConverter(config)

    assert "does not use group_size" in caplog.text


def test_qat_lora_adapter_qat():
    """QAT composed with LoRA: both base and adapter weights end up
    fake-quantized, and a group_size larger than the LoRA rank is rejected."""
    pytest.importorskip("torchao")
    from torchao.quantization.qat.linear import FakeQuantizedLinear

    def make_model(dim):
        # fc1 -> relu -> fc2, all square layers of width `dim`.
        return nn.Sequential(
            OrderedDict(
                [
                    ("fc1", nn.Linear(dim, dim)),
                    ("relu", nn.ReLU()),
                    ("fc2", nn.Linear(dim, dim)),
                ]
            )
        )

    # --- group_size > rank should error ---
    model = make_model(128)
    QATConverter(
        QATConverter.Config(scheme="intx_weight_only", group_size=128)
    ).convert(model)
    with pytest.raises(ValueError, match="does not divide LoRA rank"):
        LoRAConverter(LoRAConverter.Config(rank=8, alpha=16.0)).convert(model)

    # --- Compatible group_size should work ---
    model = make_model(64)
    QATConverter(
        QATConverter.Config(scheme="intx_weight_only", group_size=8)
    ).convert(model)

    assert isinstance(model.fc1, FakeQuantizedLinear)

    LoRAConverter(LoRAConverter.Config(rank=8, alpha=16.0)).convert(model)

    # Base linears are LoRA-wrapped FakeQuantizedLinear
    assert isinstance(model.fc1, FakeQuantizedLinear)
    # Adapter linears are also FakeQuantizedLinear
    assert isinstance(model.fc1.lora_a, FakeQuantizedLinear)
    assert isinstance(model.fc1.lora_b, FakeQuantizedLinear)

    # Forward pass should succeed
    out = model(torch.randn(4, 64))
    assert out.shape == (4, 64)


44 changes: 44 additions & 0 deletions torchtitan/components/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(self, config: Config, **kwargs):
f"LoRA save_format must be 'dcp', 'peft', or 'merged', "
f"got '{self.save_format}'"
)

logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")

@staticmethod
Expand Down Expand Up @@ -148,6 +149,14 @@ def convert(self, model: nn.Module) -> None:
model.requires_grad_(False)
self._replace_linears_with_lora(model)

# If QATConverter was applied before LoRA, apply the same QAT to
# the newly created adapter linears. QATConverter stores its config
# on the model as _qat_scheme / _qat_group_size.
qat_scheme = getattr(model, "_qat_scheme", None)
if qat_scheme is not None:
qat_group_size = getattr(model, "_qat_group_size", 128)
self._apply_adapter_qat(model, qat_scheme, qat_group_size)

# Wire up checkpoint filtering so ModelWrapper knows which keys
# are adapter keys and how to save them.
model.converter_key_filter = self._is_lora_key # type: ignore[attr-defined]
Expand All @@ -160,6 +169,41 @@ def convert(self, model: nn.Module) -> None:
if self.save_format == "merged":
model.converter_export_sd_fn = self._make_merge_fn() # type: ignore[attr-defined]

def _apply_adapter_qat(
    self, model: nn.Module, scheme: str, group_size: int
) -> None:
    """Fake-quantize the freshly created LoRA adapter linears with the same
    QAT scheme that was applied to the base model.

    Args:
        model: model already converted by both QAT and LoRA replacement.
        scheme: QAT scheme name stored on the model by QATConverter.
        group_size: per-group quantization size stored by QATConverter.

    Raises:
        ValueError: for per-group schemes when ``group_size`` does not
            evenly divide the LoRA rank.
    """
    from torchtitan.components.quantization.qat import _SCHEMES_WITH_GROUP_SIZE

    # Per-group schemes require group_size to divide the LoRA rank —
    # presumably because an adapter linear's in_features equal the rank;
    # TODO(review): confirm against the adapter construction.
    if scheme in _SCHEMES_WITH_GROUP_SIZE and self.rank % group_size != 0:
        raise ValueError(
            f"QAT group_size ({group_size}) does not divide LoRA rank "
            f"({self.rank}). Use a smaller group_size or larger rank."
        )

    from torchao.quantization import quantize_
    from torchao.quantization.qat import QATConfig
    from torchao.quantization.qat.api import QATStep

    from torchtitan.components.quantization.qat import _build_base_config

    ptq_config = _build_base_config(scheme, group_size)

    def _adapter_filter(mod: nn.Module, fqn: str) -> bool:
        # Match only the adapter linears created by LoRA replacement,
        # not the already-converted base linears.
        return isinstance(mod, nn.Linear) and fqn.endswith(
            (".lora_a", ".lora_b")
        )

    quantize_(
        model,
        QATConfig(ptq_config, step=QATStep.PREPARE),
        filter_fn=_adapter_filter,
    )
    logger.info(
        f"Applied adapter QAT fake quantization "
        f"(scheme={scheme}, group_size={group_size})"
    )

def _replace_linears_with_lora(self, module: nn.Module) -> None:
for _, child in list(module.named_modules()):
if isinstance(child, nn.Linear):
Expand Down
159 changes: 159 additions & 0 deletions torchtitan/components/quantization/qat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch.nn as nn
from torchtitan.config import Configurable
from torchtitan.tools.logging import logger

# Supported scheme names.
_SUPPORTED_SCHEMES = (
"int4_weight_only",
"intx_weight_only",
"int8_dynamic_act_intx_weight",
"float8_dynamic_act_float8_weight",
"float8_dynamic_act_int4_weight",
"nvfp4",
"mx",
)

# Schemes that accept a group_size parameter.
_SCHEMES_WITH_GROUP_SIZE = (
"int4_weight_only",
"intx_weight_only",
"int8_dynamic_act_intx_weight",
)


def _build_base_config(scheme: str, group_size: int):
"""Return a torchao PTQ base config for the given scheme name."""
if scheme == "int4_weight_only":
from torchao.quantization import Int4WeightOnlyConfig

return Int4WeightOnlyConfig(group_size=group_size)

elif scheme == "intx_weight_only":
import torch
from torchao.quantization import IntxWeightOnlyConfig
from torchao.quantization.granularity import PerGroup

int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute]
return IntxWeightOnlyConfig(
weight_dtype=int4_dtype,
granularity=PerGroup(group_size),
)

elif scheme == "int8_dynamic_act_intx_weight":
import torch
from torchao.quantization import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.granularity import PerGroup

int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute]
return Int8DynamicActivationIntxWeightConfig(
weight_dtype=int4_dtype,
weight_granularity=PerGroup(group_size),
)

elif scheme == "float8_dynamic_act_float8_weight":
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

return Float8DynamicActivationFloat8WeightConfig()

elif scheme == "float8_dynamic_act_int4_weight":
from torchao.quantization import Float8DynamicActivationInt4WeightConfig

return Float8DynamicActivationInt4WeightConfig()

elif scheme == "nvfp4":
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

return NVFP4DynamicActivationNVFP4WeightConfig()

elif scheme == "mx":
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

return MXDynamicActivationMXWeightConfig()

else:
raise ValueError(
f"Unknown QAT scheme '{scheme}'. Supported: {_SUPPORTED_SCHEMES}"
)


class QATConverter(Configurable):
    """Apply quantization-aware training via torchao's QATConfig.

    Uses ``torchao.quantize_(model, QATConfig(base_config, step="prepare"))``
    to insert fake quantization into every ``nn.Linear``. The ``scheme``
    config field selects a torchao PTQ base config, from which QATConfig
    derives the fake quantization for both weights and activations.

    Supported schemes:
    - ``"int4_weight_only"`` — int4 weight-only fake quantization
    - ``"intx_weight_only"`` — intx weight-only fake quantization
    - ``"int8_dynamic_act_intx_weight"`` — int8 activation + int4 weight
    - ``"float8_dynamic_act_float8_weight"`` — float8 activation + float8 weight
    - ``"float8_dynamic_act_int4_weight"`` — float8 activation + int4 weight
    - ``"nvfp4"`` — NVFP4 dynamic activation + NVFP4 weight
    - ``"mx"`` — MX dynamic activation + MX weight

    When composed with LoRA (QATConverter listed before LoRAConverter in
    converters), ``convert`` stores the scheme/group_size on the model so the
    LoRA converter can fake-quantize the newly created adapter linears with
    the same scheme (see ``LoRAConverter._apply_adapter_qat``); base linears
    are LoRA-wrapped FakeQuantizedLinear modules.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        scheme: str = "int4_weight_only"
        """QAT scheme name. Maps to a torchao PTQ base config.
        Supported: 'int4_weight_only', 'intx_weight_only',
        'int8_dynamic_act_intx_weight', 'float8_dynamic_act_float8_weight',
        'float8_dynamic_act_int4_weight', 'nvfp4', 'mx'."""

        group_size: int = 128
        """Group size for per-group weight quantization.
        Used by schemes that support per-group granularity
        (int4_weight_only, intx_weight_only, int8_dynamic_act_intx_weight).
        Must divide in_features of all Linear layers in the model."""

    def __init__(self, config: Config, **kwargs):
        """Validate the configured scheme and record it for ``convert``.

        Raises:
            ValueError: if ``config.scheme`` is not a supported scheme.
        """
        scheme = config.scheme
        if scheme not in _SUPPORTED_SCHEMES:
            raise ValueError(
                f"Unknown QAT scheme '{scheme}'. "
                f"Supported: {_SUPPORTED_SCHEMES}"
            )
        self.scheme = scheme
        self.group_size = config.group_size
        # Note: this fires even when group_size was left at its default,
        # since the dataclass cannot tell an explicit 128 from the default.
        if scheme not in _SCHEMES_WITH_GROUP_SIZE:
            logger.warning(
                f"QAT scheme '{scheme}' does not use group_size, "
                f"ignoring group_size={config.group_size}"
            )
        logger.info(
            f"QAT training active (scheme={self.scheme}, group_size={self.group_size})"
        )

    def convert(self, model: nn.Module) -> None:
        """Insert fake quantization into ``model``'s linears, in place."""
        from torchao.quantization import quantize_
        from torchao.quantization.qat import QATConfig
        from torchao.quantization.qat.api import QATStep

        ptq_config = _build_base_config(self.scheme, self.group_size)
        quantize_(model, QATConfig(ptq_config, step=QATStep.PREPARE))

        # Record the QAT settings on the model so downstream converters
        # (e.g. LoRA) can apply identical QAT to modules they create later.
        model._qat_scheme = self.scheme  # type: ignore[attr-defined]
        model._qat_group_size = self.group_size  # type: ignore[attr-defined]

        logger.info(
            f"Applied QAT fake quantization (scheme={self.scheme}, "
            f"group_size={self.group_size})"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        """No-op: QAT prepare requires no post-optimizer processing."""
        pass
Loading
Loading