Skip to content

Commit fa02102

Browse files
committed
Add QAT (quantization-aware training) model converter
ghstack-source-id: b3672c9 Pull Request resolved: #2488
1 parent deac878 commit fa02102

File tree

4 files changed

+273
-6
lines changed

4 files changed

+273
-6
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from torchtitan.components.lora import LoRAConverter
1313
from torchtitan.components.quantization.float8 import Float8LinearConverter
14+
from torchtitan.components.quantization.qat import QATConverter
1415
from torchtitan.config import ConfigManager
1516
from torchtitan.distributed import ParallelDims
1617
from torchtitan.protocols.model_converter import ModelConvertersContainer
@@ -202,3 +203,88 @@ def test_lora_key_remap_roundtrip():
202203
assert set(rt_sd.keys()) == set(tt_sd.keys())
203204
for k in tt_sd:
204205
assert torch.equal(rt_sd[k], tt_sd[k])
206+
207+
208+
def test_qat_preserves_weight_dtype():
    """QAT converter should not change weight dtype (fake quantization happens in forward)."""
    pytest.importorskip("torchao")

    layers = [
        ("fc1", nn.Linear(64, 64)),
        ("relu", nn.ReLU()),
        ("fc2", nn.Linear(64, 64)),
    ]
    model = nn.Sequential(OrderedDict(layers))
    dtypes_before = {name: param.dtype for name, param in model.named_parameters()}

    qat = QATConverter(QATConverter.Config(group_size=64))
    qat.convert(model)

    # All parameters must keep their original dtype: QAT only wraps the
    # forward pass with fake quantization, it never casts the weights.
    for name, param in model.named_parameters():
        assert (
            param.dtype == dtypes_before[name]
        ), f"'{name}' dtype changed from {dtypes_before[name]} to {param.dtype}"
230+
231+
232+
@pytest.mark.parametrize(
    "scheme, group_size, expected_linear_cls",
    [
        ("int4_weight_only", 64, "FakeQuantizedLinear"),
        ("intx_weight_only", 64, "FakeQuantizedLinear"),
        ("int8_dynamic_act_intx_weight", 64, "FakeQuantizedLinear"),
        ("float8_dynamic_act_float8_weight", None, "FakeQuantizedLinear"),
        ("float8_dynamic_act_int4_weight", None, "FakeQuantizedLinear"),
        ("nvfp4", None, "NVFP4FakeQuantizedLinear"),
        ("mx", None, "MXFakeQuantizedLinear"),
    ],
)
def test_qat_all_schemes(scheme, group_size, expected_linear_cls):
    """Each QAT scheme should replace nn.Linear with the correct fake-quantized class."""
    pytest.importorskip("torchao")

    model = nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(64, 64)),
                ("relu", nn.ReLU()),
                ("fc2", nn.Linear(64, 64)),
            ]
        )
    )

    # group_size is only forwarded for schemes that accept it; the rest use
    # the config default (and emit a warning if explicitly overridden).
    cfg_kwargs = {"scheme": scheme}
    if group_size is not None:
        cfg_kwargs["group_size"] = group_size
    QATConverter(QATConverter.Config(**cfg_kwargs)).convert(model)

    # Both Linear layers should now be the expected fake-quantized class.
    for linear in (model.fc1, model.fc2):
        actual_cls = type(linear).__name__
        assert (
            actual_cls == expected_linear_cls
        ), f"scheme={scheme}: expected {expected_linear_cls}, got {actual_cls}"
271+
272+
273+
def test_qat_unknown_scheme_raises():
    """QATConverter should raise ValueError for unknown schemes."""
    bad_config = QATConverter.Config(scheme="not_a_real_scheme")
    with pytest.raises(ValueError, match="Unknown QAT scheme"):
        QATConverter(bad_config)
277+
278+
279+
def test_qat_group_size_warning_for_unsupported_scheme(caplog):
    """QATConverter should warn when group_size is set for a scheme that ignores it."""
    pytest.importorskip("torchao")
    import logging

    # float8 schemes have no per-group granularity, so an explicit
    # group_size must be flagged rather than silently dropped.
    config = QATConverter.Config(
        scheme="float8_dynamic_act_float8_weight", group_size=64
    )
    with caplog.at_level(logging.WARNING):
        QATConverter(config)
    assert "does not use group_size" in caplog.text
torchtitan/components/quantization/qat.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass
8+
9+
import torch.nn as nn
10+
from torchtitan.config import Configurable
11+
from torchtitan.tools.logging import logger
12+
13+
# Supported scheme names.
14+
_SUPPORTED_SCHEMES = (
15+
"int4_weight_only",
16+
"intx_weight_only",
17+
"int8_dynamic_act_intx_weight",
18+
"float8_dynamic_act_float8_weight",
19+
"float8_dynamic_act_int4_weight",
20+
"nvfp4",
21+
"mx",
22+
)
23+
24+
# Schemes that accept a group_size parameter.
25+
_SCHEMES_WITH_GROUP_SIZE = (
26+
"int4_weight_only",
27+
"intx_weight_only",
28+
"int8_dynamic_act_intx_weight",
29+
)
30+
31+
32+
def _build_base_config(scheme: str, group_size: int):
33+
"""Return a torchao PTQ base config for the given scheme name."""
34+
if scheme == "int4_weight_only":
35+
from torchao.quantization import Int4WeightOnlyConfig
36+
37+
return Int4WeightOnlyConfig(group_size=group_size)
38+
39+
elif scheme == "intx_weight_only":
40+
import torch
41+
from torchao.quantization import IntxWeightOnlyConfig
42+
from torchao.quantization.granularity import PerGroup
43+
44+
int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute]
45+
return IntxWeightOnlyConfig(
46+
weight_dtype=int4_dtype,
47+
granularity=PerGroup(group_size),
48+
)
49+
50+
elif scheme == "int8_dynamic_act_intx_weight":
51+
import torch
52+
from torchao.quantization import Int8DynamicActivationIntxWeightConfig
53+
from torchao.quantization.granularity import PerGroup
54+
55+
int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute]
56+
return Int8DynamicActivationIntxWeightConfig(
57+
weight_dtype=int4_dtype,
58+
weight_granularity=PerGroup(group_size),
59+
)
60+
61+
elif scheme == "float8_dynamic_act_float8_weight":
62+
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
63+
64+
return Float8DynamicActivationFloat8WeightConfig()
65+
66+
elif scheme == "float8_dynamic_act_int4_weight":
67+
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
68+
69+
return Float8DynamicActivationInt4WeightConfig()
70+
71+
elif scheme == "nvfp4":
72+
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
73+
74+
return NVFP4DynamicActivationNVFP4WeightConfig()
75+
76+
elif scheme == "mx":
77+
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
78+
79+
return MXDynamicActivationMXWeightConfig()
80+
81+
else:
82+
raise ValueError(
83+
f"Unknown QAT scheme '{scheme}'. Supported: {_SUPPORTED_SCHEMES}"
84+
)
85+
86+
87+
class QATConverter(Configurable):
    """Apply quantization-aware training via torchao's QATConfig.

    Uses ``torchao.quantize_(model, QATConfig(base_config, step="prepare"))``
    to insert fake quantization into ``nn.Linear`` modules. The ``scheme``
    config field selects a torchao PTQ base config, which QATConfig uses to
    infer the appropriate fake quantization for both weights and activations.

    Supported schemes:
    - ``"int4_weight_only"`` — int4 weight-only fake quantization
    - ``"intx_weight_only"`` — intx weight-only fake quantization
    - ``"int8_dynamic_act_intx_weight"`` — int8 activation + int4 weight
    - ``"float8_dynamic_act_float8_weight"`` — float8 activation + float8 weight
    - ``"float8_dynamic_act_int4_weight"`` — float8 activation + int4 weight
    - ``"nvfp4"`` — NVFP4 dynamic activation + NVFP4 weight
    - ``"mx"`` — MX dynamic activation + MX weight

    When composed with LoRA (QATConverter listed before LoRAConverter in converters),
    LoRA will inherit from FakeQuantizedLinear so base weights are fake-quantized
    while LoRA adapters stay full-precision.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        scheme: str = "int4_weight_only"
        """QAT scheme name. Maps to a torchao PTQ base config.
        Supported: 'int4_weight_only', 'intx_weight_only',
        'int8_dynamic_act_intx_weight', 'float8_dynamic_act_float8_weight',
        'float8_dynamic_act_int4_weight', 'nvfp4', 'mx'."""

        group_size: int = 256
        """Group size for per-group weight quantization.
        Used by schemes that support per-group granularity
        (int4_weight_only, intx_weight_only, int8_dynamic_act_intx_weight).
        Must divide in_features of all Linear layers in the model."""

    def __init__(self, config: Config, **kwargs):
        # Validate eagerly so a bad scheme fails at construction, not mid-training.
        scheme = config.scheme
        if scheme not in _SUPPORTED_SCHEMES:
            raise ValueError(
                f"Unknown QAT scheme '{scheme}'. "
                f"Supported: {_SUPPORTED_SCHEMES}"
            )
        self.scheme = scheme
        self.group_size = config.group_size
        # Surface a configured group_size that the chosen scheme will ignore.
        if scheme not in _SCHEMES_WITH_GROUP_SIZE:
            logger.warning(
                f"QAT scheme '{scheme}' does not use group_size, "
                f"ignoring group_size={config.group_size}"
            )
        logger.info(
            f"QAT training active (scheme={self.scheme}, group_size={self.group_size})"
        )

    def convert(self, model: nn.Module) -> None:
        """Swap ``nn.Linear`` modules in ``model`` for fake-quantized variants in place."""
        from torchao.quantization import quantize_
        from torchao.quantization.qat import QATConfig
        from torchao.quantization.qat.api import QATStep

        qat_config = QATConfig(
            _build_base_config(self.scheme, self.group_size),
            step=QATStep.PREPARE,
        )
        quantize_(model, qat_config)
        logger.info(
            f"Applied QAT fake quantization (scheme={self.scheme}, "
            f"group_size={self.group_size})"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        """No-op: fake quantization needs no per-step maintenance."""
        pass

torchtitan/models/llama3/config_registry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OptimizersInBackwardContainer,
1414
)
1515
from torchtitan.components.quantization.float8 import Float8LinearConverter
16+
from torchtitan.components.quantization.qat import QATConverter
1617
from torchtitan.components.validate import Validator
1718
from torchtitan.config import (
1819
ActivationCheckpointConfig,
@@ -130,6 +131,32 @@ def llama3_debugmodel_lora() -> Trainer.Config:
130131
return config
131132

132133

134+
def llama3_debugmodel_qat() -> Trainer.Config:
    """Debug-model config with QAT fake quantization enabled."""
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[QATConverter.Config()],
    )
    return config
142+
143+
144+
def llama3_debugmodel_qat_lora() -> Trainer.Config:
    """Debug-model config combining QAT fake quantization with LoRA adapters."""
    config = llama3_debugmodel()
    # QATConverter must come before LoRAConverter so that LoRA inherits from
    # FakeQuantizedLinear, giving fake-quantized base weights + full-precision adapters.
    qat_cfg = QATConverter.Config()
    lora_cfg = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
    )
    config.model_converters = ModelConvertersContainer.Config(
        converters=[qat_cfg, lora_cfg],
    )
    return config
158+
159+
133160
def llama3_8b() -> Trainer.Config:
134161
return Trainer.Config(
135162
hf_assets_path="./assets/hf/Llama-3.1-8B",

torchtitan/protocols/model_converter.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,21 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
8787
def _validate_converter_ordering(converters: list[Configurable.Config]):
    """Validates that converters are in the correct order.

    LoRA must come after quantization and QAT because both replace nn.Linear
    with specialized subclasses (e.g. Float8Linear, FakeQuantizedLinear), and
    LoRA dynamically inherits from whatever linear class it wraps.
    """
    # Imported here to avoid a circular import at module load time.
    from torchtitan.components.lora import LoRAConverter
    from torchtitan.components.quantization.qat import QATConverter

    linear_swapping_configs = (QuantizationConverter.Config, QATConverter.Config)
    seen_lora = False
    for cfg in converters:
        if isinstance(cfg, LoRAConverter.Config):
            seen_lora = True
            continue
        # A linear-swapping converter after LoRA would leave LoRA wrapping
        # the wrong (pre-swap) linear class.
        if seen_lora and isinstance(cfg, linear_swapping_configs):
            raise ValueError(
                "LoRA converter must come after quantization and QAT converters. "
                "Quantization/QAT replaces nn.Linear with specialized subclasses, "
                "and LoRA must wrap the final linear class."
            )
106107

0 commit comments

Comments
 (0)