Skip to content
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
901c97a
Add QAT (quantization-aware training) model converter
mori360 Mar 5, 2026
bcc2b2c
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 5, 2026
858b1d4
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 5, 2026
828ee0c
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 5, 2026
0728515
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 6, 2026
8beee82
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 6, 2026
a248b79
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 12, 2026
3ee65be
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 12, 2026
59029b2
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 12, 2026
27c4951
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 12, 2026
4300ee6
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 13, 2026
1a2826f
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 13, 2026
1206b0b
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
c1ad90a
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
afa28f0
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
5e623d1
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
5a39621
Update on "[4/N] Add QAT (quantization-aware training) model converter"
mori360 Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/unit_tests/test_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from torchtitan.components.lora import LoRAConverter
from torchtitan.components.quantization.float8 import Float8LinearConverter
from torchtitan.components.quantization.qat import QATConverter
from torchtitan.config import ConfigManager
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.model_converter import ModelConvertersContainer
Expand Down Expand Up @@ -214,3 +215,27 @@ def test_qlora_base_weights_quantized_adapters_full_precision():
assert (
layer.lora_b.weight.dtype == torch.float32
), f"{name}.lora_b.weight should be float32"


def test_qat_preserves_weight_dtype():
    """QAT converter should not change weight dtype (fake quantization happens in forward)."""
    pytest.importorskip("torchao")

    # Build a tiny named Sequential so parameter names are stable.
    layers = OrderedDict()
    layers["fc1"] = nn.Linear(64, 64)
    layers["relu"] = nn.ReLU()
    layers["fc2"] = nn.Linear(64, 64)
    model = nn.Sequential(layers)

    # Snapshot dtypes before conversion, keyed by qualified parameter name.
    dtypes_before = {n: p.dtype for n, p in model.named_parameters()}

    QATConverter(QATConverter.Config(group_size=64)).convert(model)

    # Fake quantization must not touch parameter storage dtypes.
    for name, param in model.named_parameters():
        expected = dtypes_before[name]
        assert (
            param.dtype == expected
        ), f"'{name}' dtype changed from {expected} to {param.dtype}"
77 changes: 77 additions & 0 deletions torchtitan/components/quantization/qat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Literal

import torch
import torch.nn as nn
from torchtitan.config import Configurable
from torchtitan.tools.logging import logger


class QATConverter(Configurable):
    """Replace nn.Linear with FakeQuantizedLinear for quantization-aware training.

    Uses torchao's FakeQuantizedLinear to simulate int4 weight quantization during
    training. The fake quantization is applied in the forward pass so the model
    learns to compensate for quantization error.

    When composed with LoRA (QATConverter listed before LoRAConverter in converters),
    LoRA will inherit from FakeQuantizedLinear so base weights are fake-quantized
    while LoRA adapters stay full-precision.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        # TODO(review): a single dtype field cannot express activation
        # quantization (e.g. torchao's Int8DynamicActivationIntxWeightConfig
        # or NVFP4DynamicActivationNVFP4WeightConfig); revisit this schema
        # when adding activation-quantized QAT recipes.
        dtype: Literal["int4", "int8"] = "int4"
        """Data type for fake quantization. Supported: 'int4', 'int8'."""

        # TODO(review): not every supported torchao QAT config takes a
        # group_size; make this optional for configs that do not use it.
        group_size: int = 256
        """Group size for per-group weight quantization.
        Must divide in_features of all Linear layers in the model."""

def __init__(self, config: Config, **kwargs):
    """Store the fake-quantization settings from *config*.

    Args:
        config: a ``QATConverter.Config`` carrying ``dtype`` and ``group_size``.
        **kwargs: unused; presumably accepted to match a shared converter
            constructor signature -- TODO confirm against the container.

    NOTE(review): ``Configurable.__init__`` is never invoked here -- confirm
    the base class does not require it.
    """
    self.dtype = config.dtype
    self.group_size = config.group_size
    logger.info(
        f"QAT training active (dtype={self.dtype}, group_size={self.group_size})"
    )

def convert(self, model: nn.Module) -> None:
    """Swap every nn.Linear in *model* (recursively, in place) for
    torchao's FakeQuantizedLinear with per-group symmetric weight
    fake quantization.

    Args:
        model: module tree to convert; modified in place.

    Raises:
        KeyError: if ``self.dtype`` is not one of the supported keys.
    """
    # Imported lazily so torchao stays an optional dependency.
    from torchao.quantization.qat import FakeQuantizedLinear, IntxFakeQuantizeConfig

    # Use torch dtypes uniformly. TorchAODType.INT4 was a shim for
    # torch <= 2.5, which lacked torch.int4; the previous mixed mapping
    # (enum for int4, torch dtype for int8) was inconsistent.
    dtype_map = {
        "int4": torch.int4,
        "int8": torch.int8,
    }
    torch_dtype = dtype_map[self.dtype]

    weight_config = IntxFakeQuantizeConfig(
        dtype=torch_dtype,
        group_size=self.group_size,
        is_symmetric=True,
    )

    # NOTE(review): torchao also ships NVFP4FakeQuantizedLinear, which is
    # NOT a subclass of FakeQuantizedLinear -- keep that in mind before
    # baking more logic into this specific linear class.
    def _replace_recursive(parent: nn.Module) -> None:
        # Materialize the child list first: we reassign attributes on
        # `parent` while iterating.
        for name, child in list(parent.named_children()):
            if isinstance(child, nn.Linear):
                fq = FakeQuantizedLinear.from_linear(
                    child, weight_config=weight_config
                )
                setattr(parent, name, fq)
            else:
                _replace_recursive(child)

    _replace_recursive(model)
    logger.info(
        "Swapped to FakeQuantizedLinear layers "
        f"(dtype={self.dtype}, group_size={self.group_size})"
    )

def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
    """Intentionally a no-op: this converter does no per-step maintenance
    after the optimizer update (presumably present only to satisfy the
    converter protocol -- confirm against ModelConvertersContainer)."""
    pass
27 changes: 27 additions & 0 deletions torchtitan/models/llama3/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
OptimizersInBackwardContainer,
)
from torchtitan.components.quantization.float8 import Float8LinearConverter
from torchtitan.components.quantization.qat import QATConverter
from torchtitan.components.validate import Validator
from torchtitan.config import (
ActivationCheckpointConfig,
Expand Down Expand Up @@ -144,6 +145,32 @@ def llama3_debugmodel_qlora() -> Trainer.Config:
return config


def llama3_debugmodel_qat() -> Trainer.Config:
    """Debug-model llama3 trainer config with the QAT converter enabled.

    Returns:
        A copy of the debug-model config whose model_converters list
        contains a single default QATConverter.Config.
    """
    config = llama3_debugmodel()
    # NOTE(review): converter order matters -- if LoRA is ever added here,
    # QAT must stay first so LoRA wraps the fake-quantized linear class.
    config.model_converters = ModelConvertersContainer.Config(
        converters=[
            QATConverter.Config(),
        ],
    )
    return config


def llama3_debugmodel_qat_lora() -> Trainer.Config:
    """Debug-model llama3 trainer config combining QAT with LoRA adapters."""
    # Ordering is load-bearing: QATConverter must precede LoRAConverter so
    # LoRA inherits from FakeQuantizedLinear, giving fake-quantized base
    # weights plus full-precision adapters.
    qat = QATConverter.Config()
    lora = LoRAConverter.Config(rank=8, alpha=16.0)

    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[qat, lora],
    )
    return config


def llama3_8b() -> Trainer.Config:
return Trainer.Config(
hf_assets_path="./assets/hf/Llama-3.1-8B",
Expand Down
13 changes: 7 additions & 6 deletions torchtitan/protocols/model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,21 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
def _validate_converter_ordering(converters: list[Configurable.Config]):
    """Validates that converters are in the correct order.

    LoRA must come after quantization and QAT because both replace nn.Linear
    with specialized subclasses (e.g. Float8Linear, FakeQuantizedLinear), and
    LoRA dynamically inherits from whatever linear class it wraps.

    NOTE(review): with the Linear -> QAT(Linear) -> LoRA(QAT) composition,
    only the base weights are fake-quantized; the LoRA adapters keep the
    dtype of the original linear and are not themselves fake-quantized.

    Raises:
        ValueError: if a quantization/QAT converter config appears after a
            LoRA converter config.
    """
    # Imported locally, presumably to avoid an import cycle with the
    # components package -- confirm before hoisting to module level.
    from torchtitan.components.lora import LoRAConverter
    from torchtitan.components.quantization.qat import QATConverter

    seen_lora = False
    for config in converters:
        if isinstance(config, LoRAConverter.Config):
            seen_lora = True
        elif (
            isinstance(config, (QuantizationConverter.Config, QATConverter.Config))
            and seen_lora
        ):
            raise ValueError(
                "LoRA converter must come after quantization and QAT converters. "
                "Quantization/QAT replaces nn.Linear with specialized subclasses, "
                "and LoRA must wrap the final linear class."
            )

Expand Down
Loading