Skip to content

Commit 0f5cb10

Browse files
committed
Add QAT (quantization-aware training) model converter
ghstack-source-id: 0cc296d Pull Request resolved: #2488
1 parent 695a256 commit 0f5cb10

File tree

4 files changed

+273
-6
lines changed

4 files changed

+273
-6
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from torchtitan.components.lora import LoRAConverter
1313
from torchtitan.components.quantization.float8 import Float8LinearConverter
14+
from torchtitan.components.quantization.qat import QATConverter
1415
from torchtitan.config import ConfigManager
1516
from torchtitan.distributed import ParallelDims
1617
from torchtitan.protocols.model_converter import ModelConvertersContainer
@@ -214,3 +215,88 @@ def test_qlora_base_weights_quantized_adapters_full_precision():
214215
assert (
215216
layer.lora_b.weight.dtype == torch.float32
216217
), f"{name}.lora_b.weight should be float32"
218+
219+
220+
def test_qat_preserves_weight_dtype():
    """QAT converter should not change weight dtype (fake quantization happens in forward)."""
    pytest.importorskip("torchao")

    layers = OrderedDict()
    layers["fc1"] = nn.Linear(64, 64)
    layers["relu"] = nn.ReLU()
    layers["fc2"] = nn.Linear(64, 64)
    model = nn.Sequential(layers)
    dtypes_before = {n: p.dtype for n, p in model.named_parameters()}

    qat = QATConverter(QATConverter.Config(group_size=64))
    qat.convert(model)

    # Fake quantization is applied in forward; parameter storage stays untouched.
    for n, p in model.named_parameters():
        assert (
            p.dtype == dtypes_before[n]
        ), f"'{n}' dtype changed from {dtypes_before[n]} to {p.dtype}"
242+
243+
244+
@pytest.mark.parametrize(
    "scheme, group_size, expected_linear_cls",
    [
        ("int4_weight_only", 64, "FakeQuantizedLinear"),
        ("intx_weight_only", 64, "FakeQuantizedLinear"),
        ("int8_dynamic_act_intx_weight", 64, "FakeQuantizedLinear"),
        ("float8_dynamic_act_float8_weight", None, "FakeQuantizedLinear"),
        ("float8_dynamic_act_int4_weight", None, "FakeQuantizedLinear"),
        ("nvfp4", None, "NVFP4FakeQuantizedLinear"),
        ("mx", None, "MXFakeQuantizedLinear"),
    ],
)
def test_qat_all_schemes(scheme, group_size, expected_linear_cls):
    """Each QAT scheme should replace nn.Linear with the correct fake-quantized class."""
    pytest.importorskip("torchao")

    model = nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(64, 64)),
                ("relu", nn.ReLU()),
                ("fc2", nn.Linear(64, 64)),
            ]
        )
    )

    # group_size is only passed for schemes that accept it.
    cfg_kwargs = {"scheme": scheme}
    if group_size is not None:
        cfg_kwargs["group_size"] = group_size
    QATConverter(QATConverter.Config(**cfg_kwargs)).convert(model)

    # Both linear layers should have been swapped for the expected class.
    for layer in (model.fc1, model.fc2):
        got = type(layer).__name__
        assert (
            got == expected_linear_cls
        ), f"scheme={scheme}: expected {expected_linear_cls}, got {got}"
283+
284+
285+
def test_qat_unknown_scheme_raises():
    """QATConverter should raise ValueError for unknown schemes."""
    bad_config = QATConverter.Config(scheme="not_a_real_scheme")
    with pytest.raises(ValueError, match="Unknown QAT scheme"):
        QATConverter(bad_config)
289+
290+
291+
def test_qat_group_size_warning_for_unsupported_scheme(caplog):
    """QATConverter should warn when group_size is set for a scheme that ignores it."""
    pytest.importorskip("torchao")
    import logging

    cfg = QATConverter.Config(
        scheme="float8_dynamic_act_float8_weight", group_size=64
    )
    with caplog.at_level(logging.WARNING):
        QATConverter(cfg)
    assert "does not use group_size" in caplog.text
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass
8+
9+
import torch.nn as nn
10+
from torchtitan.config import Configurable
11+
from torchtitan.tools.logging import logger
12+
13+
# Supported scheme names.
14+
_SUPPORTED_SCHEMES = (
15+
"int4_weight_only",
16+
"intx_weight_only",
17+
"int8_dynamic_act_intx_weight",
18+
"float8_dynamic_act_float8_weight",
19+
"float8_dynamic_act_int4_weight",
20+
"nvfp4",
21+
"mx",
22+
)
23+
24+
# Schemes that accept a group_size parameter.
25+
_SCHEMES_WITH_GROUP_SIZE = (
26+
"int4_weight_only",
27+
"intx_weight_only",
28+
"int8_dynamic_act_intx_weight",
29+
)
30+
31+
32+
def _build_base_config(scheme: str, group_size: int):
33+
"""Return a torchao PTQ base config for the given scheme name."""
34+
if scheme == "int4_weight_only":
35+
from torchao.quantization import Int4WeightOnlyConfig
36+
37+
return Int4WeightOnlyConfig(group_size=group_size)
38+
39+
elif scheme == "intx_weight_only":
40+
import torch
41+
from torchao.quantization import IntxWeightOnlyConfig
42+
from torchao.quantization.granularity import PerGroup
43+
44+
int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute]
45+
return IntxWeightOnlyConfig(
46+
weight_dtype=int4_dtype,
47+
granularity=PerGroup(group_size),
48+
)
49+
50+
elif scheme == "int8_dynamic_act_intx_weight":
51+
import torch
52+
from torchao.quantization import Int8DynamicActivationIntxWeightConfig
53+
from torchao.quantization.granularity import PerGroup
54+
55+
int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute]
56+
return Int8DynamicActivationIntxWeightConfig(
57+
weight_dtype=int4_dtype,
58+
weight_granularity=PerGroup(group_size),
59+
)
60+
61+
elif scheme == "float8_dynamic_act_float8_weight":
62+
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
63+
64+
return Float8DynamicActivationFloat8WeightConfig()
65+
66+
elif scheme == "float8_dynamic_act_int4_weight":
67+
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
68+
69+
return Float8DynamicActivationInt4WeightConfig()
70+
71+
elif scheme == "nvfp4":
72+
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
73+
74+
return NVFP4DynamicActivationNVFP4WeightConfig()
75+
76+
elif scheme == "mx":
77+
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
78+
79+
return MXDynamicActivationMXWeightConfig()
80+
81+
else:
82+
raise ValueError(
83+
f"Unknown QAT scheme '{scheme}'. Supported: {_SUPPORTED_SCHEMES}"
84+
)
85+
86+
87+
class QATConverter(Configurable):
    """Apply quantization-aware training via torchao's QATConfig.

    Uses ``torchao.quantize_(model, QATConfig(base_config, step="prepare"))``
    to insert fake quantization into ``nn.Linear`` modules. The ``scheme``
    config field selects a torchao PTQ base config, which QATConfig uses to
    infer the appropriate fake quantization for both weights and activations.

    Supported schemes:
    - ``"int4_weight_only"`` — int4 weight-only fake quantization
    - ``"intx_weight_only"`` — intx weight-only fake quantization
    - ``"int8_dynamic_act_intx_weight"`` — int8 activation + int4 weight
    - ``"float8_dynamic_act_float8_weight"`` — float8 activation + float8 weight
    - ``"float8_dynamic_act_int4_weight"`` — float8 activation + int4 weight
    - ``"nvfp4"`` — NVFP4 dynamic activation + NVFP4 weight
    - ``"mx"`` — MX dynamic activation + MX weight

    When composed with LoRA (QATConverter listed before LoRAConverter in converters),
    LoRA will inherit from FakeQuantizedLinear so base weights are fake-quantized
    while LoRA adapters stay full-precision.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        scheme: str = "int4_weight_only"
        """QAT scheme name. Maps to a torchao PTQ base config.
        Supported: 'int4_weight_only', 'intx_weight_only',
        'int8_dynamic_act_intx_weight', 'float8_dynamic_act_float8_weight',
        'float8_dynamic_act_int4_weight', 'nvfp4', 'mx'."""

        group_size: int = 256
        """Group size for per-group weight quantization.
        Used by schemes that support per-group granularity
        (int4_weight_only, intx_weight_only, int8_dynamic_act_intx_weight).
        Must divide in_features of all Linear layers in the model."""

    def __init__(self, config: Config, **kwargs):
        # Validate the scheme eagerly so a misconfiguration fails at
        # construction time, before any model conversion happens. This check
        # does not import torchao, so it works even when torchao is absent.
        if config.scheme not in _SUPPORTED_SCHEMES:
            raise ValueError(
                f"Unknown QAT scheme '{config.scheme}'. "
                f"Supported: {_SUPPORTED_SCHEMES}"
            )
        self.scheme = config.scheme
        self.group_size = config.group_size
        # NOTE(review): this warning fires whenever the scheme ignores
        # group_size, even if group_size was left at its default (256) and
        # never explicitly set by the user — confirm this is intended.
        if config.scheme not in _SCHEMES_WITH_GROUP_SIZE:
            logger.warning(
                f"QAT scheme '{config.scheme}' does not use group_size, "
                f"ignoring group_size={config.group_size}"
            )
        logger.info(
            f"QAT training active (scheme={self.scheme}, group_size={self.group_size})"
        )

    def convert(self, model: nn.Module) -> None:
        """Replace nn.Linear modules in *model* (in place) with torchao
        fake-quantized equivalents for the configured scheme.

        torchao imports are deferred to keep it an optional dependency:
        only converters that are actually used require it at runtime.
        """
        from torchao.quantization import quantize_
        from torchao.quantization.qat import QATConfig
        from torchao.quantization.qat.api import QATStep

        base_config = _build_base_config(self.scheme, self.group_size)
        # "prepare" step inserts fake quantization for training; the
        # corresponding "convert" step (real quantization) is not done here.
        quantize_(model, QATConfig(base_config, step=QATStep.PREPARE))
        logger.info(
            f"Applied QAT fake quantization (scheme={self.scheme}, "
            f"group_size={self.group_size})"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        # Intentional no-op: fake quantization needs no per-step maintenance
        # after the optimizer step (unlike e.g. float8 amax recomputation).
        pass

torchtitan/models/llama3/config_registry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OptimizersInBackwardContainer,
1414
)
1515
from torchtitan.components.quantization.float8 import Float8LinearConverter
16+
from torchtitan.components.quantization.qat import QATConverter
1617
from torchtitan.components.validate import Validator
1718
from torchtitan.config import (
1819
ActivationCheckpointConfig,
@@ -144,6 +145,32 @@ def llama3_debugmodel_qlora() -> Trainer.Config:
144145
return config
145146

146147

148+
def llama3_debugmodel_qat() -> Trainer.Config:
    """Debug-model trainer config with the QAT converter enabled (defaults)."""
    cfg = llama3_debugmodel()
    cfg.model_converters = ModelConvertersContainer.Config(
        converters=[QATConverter.Config()],
    )
    return cfg
156+
157+
158+
def llama3_debugmodel_qat_lora() -> Trainer.Config:
    """Debug-model trainer config composing QAT with LoRA adapters."""
    cfg = llama3_debugmodel()
    # QATConverter must come before LoRAConverter so that LoRA inherits from
    # FakeQuantizedLinear, giving fake-quantized base weights + full-precision adapters.
    qat = QATConverter.Config()
    lora = LoRAConverter.Config(rank=8, alpha=16.0)
    cfg.model_converters = ModelConvertersContainer.Config(
        converters=[qat, lora],
    )
    return cfg
172+
173+
147174
def llama3_8b() -> Trainer.Config:
148175
return Trainer.Config(
149176
hf_assets_path="./assets/hf/Llama-3.1-8B",

torchtitan/protocols/model_converter.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,21 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
8787
def _validate_converter_ordering(converters: list[Configurable.Config]):
    """Validates that converters are in the correct order.

    LoRA must come after quantization and QAT because both replace nn.Linear
    with specialized subclasses (e.g. Float8Linear, FakeQuantizedLinear), and
    LoRA dynamically inherits from whatever linear class it wraps.
    """
    # Local imports avoid a circular dependency between the protocol module
    # and the converter implementations.
    from torchtitan.components.lora import LoRAConverter
    from torchtitan.components.quantization.qat import QATConverter

    seen_lora = False
    for cfg in converters:
        if isinstance(cfg, LoRAConverter.Config):
            seen_lora = True
            continue
        is_linear_swapping = isinstance(
            cfg, (QuantizationConverter.Config, QATConverter.Config)
        )
        if is_linear_swapping and seen_lora:
            raise ValueError(
                "LoRA converter must come after quantization and QAT converters. "
                "Quantization/QAT replaces nn.Linear with specialized subclasses, "
                "and LoRA must wrap the final linear class."
            )
106107

0 commit comments

Comments
 (0)