Skip to content

Commit 34ef778

Browse files
committed
Add QAT (quantization-aware training) model converter
ghstack-source-id: 84aff39 Pull Request resolved: #2488
1 parent 0a1f1aa commit 34ef778

File tree

5 files changed

+339
-6
lines changed

5 files changed

+339
-6
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from torchtitan.components.lora import LoRAConverter
1313
from torchtitan.components.quantization.float8 import Float8LinearConverter
14+
from torchtitan.components.quantization.qat import QATConverter
1415
from torchtitan.config import ConfigManager
1516
from torchtitan.distributed import ParallelDims
1617
from torchtitan.protocols.model_converter import ModelConvertersContainer
@@ -202,3 +203,88 @@ def test_lora_key_remap_roundtrip():
202203
assert set(rt_sd.keys()) == set(tt_sd.keys())
203204
for k in tt_sd:
204205
assert torch.equal(rt_sd[k], tt_sd[k])
206+
207+
208+
def test_qat_preserves_weight_dtype():
    """Converting must leave parameter dtypes alone — fake quantization lives in forward()."""
    pytest.importorskip("torchao")

    # Small two-layer MLP; 64 features so group_size=64 divides in_features evenly.
    layers = [
        ("fc1", nn.Linear(64, 64)),
        ("relu", nn.ReLU()),
        ("fc2", nn.Linear(64, 64)),
    ]
    model = nn.Sequential(OrderedDict(layers))
    dtypes_before = {n: p.dtype for n, p in model.named_parameters()}

    QATConverter(QATConverter.Config(group_size=64)).convert(model)

    for name, param in model.named_parameters():
        expected = dtypes_before[name]
        assert (
            param.dtype == expected
        ), f"'{name}' dtype changed from {expected} to {param.dtype}"
230+
231+
232+
@pytest.mark.parametrize(
    "scheme, group_size, expected_linear_cls",
    [
        ("int4_weight_only", 64, "FakeQuantizedLinear"),
        ("intx_weight_only", 64, "FakeQuantizedLinear"),
        ("int8_dynamic_act_intx_weight", 64, "FakeQuantizedLinear"),
        ("float8_dynamic_act_float8_weight", None, "FakeQuantizedLinear"),
        ("float8_dynamic_act_int4_weight", None, "FakeQuantizedLinear"),
        ("nvfp4", None, "NVFP4FakeQuantizedLinear"),
        ("mx", None, "MXFakeQuantizedLinear"),
    ],
)
def test_qat_all_schemes(scheme, group_size, expected_linear_cls):
    """Every supported scheme should swap nn.Linear for its fake-quantized counterpart."""
    pytest.importorskip("torchao")

    model = nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(64, 64)),
                ("relu", nn.ReLU()),
                ("fc2", nn.Linear(64, 64)),
            ]
        )
    )

    # group_size is omitted for schemes that don't take one, exercising the default.
    kwargs = {"scheme": scheme}
    if group_size is not None:
        kwargs["group_size"] = group_size
    QATConverter(QATConverter.Config(**kwargs)).convert(model)

    # Both Linear layers must now be the expected fake-quantized class.
    for layer in (model.fc1, model.fc2):
        got = type(layer).__name__
        assert (
            got == expected_linear_cls
        ), f"scheme={scheme}: expected {expected_linear_cls}, got {got}"
271+
272+
273+
def test_qat_unknown_scheme_raises():
    """An unrecognized scheme name should fail fast at construction time."""
    bad_config = QATConverter.Config(scheme="not_a_real_scheme")
    with pytest.raises(ValueError, match="Unknown QAT scheme"):
        QATConverter(bad_config)
277+
278+
279+
def test_qat_group_size_warning_for_unsupported_scheme(caplog):
    """Setting group_size for a scheme that ignores it should emit a warning."""
    pytest.importorskip("torchao")
    import logging

    # float8 schemes have no per-group granularity, so group_size is ignored.
    cfg = QATConverter.Config(
        scheme="float8_dynamic_act_float8_weight", group_size=64
    )
    with caplog.at_level(logging.WARNING):
        QATConverter(cfg)
    assert "does not use group_size" in caplog.text

torchtitan/components/lora.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ class Config(Configurable.Config):
9898
"merged" folds adapters into base weights (base + alpha/rank * B @ A)
9999
and saves a standard checkpoint with no LoRA keys."""
100100

101+
adapter_qat_scheme: str = ""
102+
"""QAT scheme for adapter weights. Empty = no adapter QAT.
103+
Must match a supported QATConverter scheme."""
104+
105+
adapter_qat_group_size: int = 8
106+
"""Group size for adapter weight quantization.
107+
Must divide rank (i.e. rank % group_size == 0).
108+
Only used by schemes that support per-group granularity."""
109+
101110
def __init__(self, config: Config, **kwargs):
102111
self.rank = config.rank
103112
self.alpha = config.alpha
@@ -107,6 +116,33 @@ def __init__(self, config: Config, **kwargs):
107116
f"LoRA save_format must be 'dcp', 'peft', or 'merged', "
108117
f"got '{self.save_format}'"
109118
)
119+
120+
self.adapter_qat_scheme = config.adapter_qat_scheme
121+
self.adapter_qat_group_size = config.adapter_qat_group_size
122+
if self.adapter_qat_scheme:
123+
from torchtitan.components.quantization.qat import (
124+
_SCHEMES_WITH_GROUP_SIZE,
125+
_SUPPORTED_SCHEMES,
126+
)
127+
128+
if self.adapter_qat_scheme not in _SUPPORTED_SCHEMES:
129+
raise ValueError(
130+
f"Unknown adapter QAT scheme '{self.adapter_qat_scheme}'. "
131+
f"Supported: {_SUPPORTED_SCHEMES}"
132+
)
133+
if self.adapter_qat_scheme in _SCHEMES_WITH_GROUP_SIZE:
134+
if self.rank % self.adapter_qat_group_size != 0:
135+
raise ValueError(
136+
f"LoRA rank ({self.rank}) must be divisible by "
137+
f"adapter_qat_group_size ({self.adapter_qat_group_size})"
138+
)
139+
else:
140+
logger.warning(
141+
f"Adapter QAT scheme '{self.adapter_qat_scheme}' does not use "
142+
f"group_size, ignoring adapter_qat_group_size="
143+
f"{self.adapter_qat_group_size}"
144+
)
145+
110146
logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")
111147

112148
@staticmethod
@@ -148,6 +184,9 @@ def convert(self, model: nn.Module) -> None:
148184
model.requires_grad_(False)
149185
self._replace_linears_with_lora(model)
150186

187+
if self.adapter_qat_scheme:
188+
self._apply_adapter_qat(model)
189+
151190
# Wire up checkpoint filtering so ModelWrapper knows which keys
152191
# are adapter keys and how to save them.
153192
model.converter_key_filter = self._is_lora_key # type: ignore[attr-defined]
@@ -160,6 +199,33 @@ def convert(self, model: nn.Module) -> None:
160199
if self.save_format == "merged":
161200
model.converter_export_sd_fn = self._make_merge_fn() # type: ignore[attr-defined]
162201

202+
def _apply_adapter_qat(self, model: nn.Module) -> None:
    """Insert fake quantization into the LoRA adapter linears only.

    torchao's prepare step wraps just the modules selected by the filter
    below; base weights are left untouched.
    """
    from torchao.quantization import quantize_
    from torchao.quantization.qat import QATConfig
    from torchao.quantization.qat.api import QATStep

    from torchtitan.components.quantization.qat import _build_base_config

    def adapters_only(mod: nn.Module, fqn: str) -> bool:
        # Select only Linear modules whose qualified name marks them as adapters.
        if not isinstance(mod, nn.Linear):
            return False
        return fqn.endswith(".lora_a") or fqn.endswith(".lora_b")

    prepare_cfg = QATConfig(
        _build_base_config(self.adapter_qat_scheme, self.adapter_qat_group_size),
        step=QATStep.PREPARE,
    )
    quantize_(model, prepare_cfg, filter_fn=adapters_only)
    logger.info(
        f"Applied adapter QAT fake quantization "
        f"(scheme={self.adapter_qat_scheme}, "
        f"group_size={self.adapter_qat_group_size})"
    )
228+
163229
def _replace_linears_with_lora(self, module: nn.Module) -> None:
164230
for _, child in list(module.named_modules()):
165231
if isinstance(child, nn.Linear):
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass
8+
9+
import torch.nn as nn
10+
from torchtitan.config import Configurable
11+
from torchtitan.tools.logging import logger
12+
13+
# Supported scheme names.
14+
_SUPPORTED_SCHEMES = (
15+
"int4_weight_only",
16+
"intx_weight_only",
17+
"int8_dynamic_act_intx_weight",
18+
"float8_dynamic_act_float8_weight",
19+
"float8_dynamic_act_int4_weight",
20+
"nvfp4",
21+
"mx",
22+
)
23+
24+
# Schemes that accept a group_size parameter.
25+
_SCHEMES_WITH_GROUP_SIZE = (
26+
"int4_weight_only",
27+
"intx_weight_only",
28+
"int8_dynamic_act_intx_weight",
29+
)
30+
31+
32+
def _build_base_config(scheme: str, group_size: int):
33+
"""Return a torchao PTQ base config for the given scheme name."""
34+
if scheme == "int4_weight_only":
35+
from torchao.quantization import Int4WeightOnlyConfig
36+
37+
return Int4WeightOnlyConfig(group_size=group_size)
38+
39+
elif scheme == "intx_weight_only":
40+
import torch
41+
from torchao.quantization import IntxWeightOnlyConfig
42+
from torchao.quantization.granularity import PerGroup
43+
44+
int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute]
45+
return IntxWeightOnlyConfig(
46+
weight_dtype=int4_dtype,
47+
granularity=PerGroup(group_size),
48+
)
49+
50+
elif scheme == "int8_dynamic_act_intx_weight":
51+
import torch
52+
from torchao.quantization import Int8DynamicActivationIntxWeightConfig
53+
from torchao.quantization.granularity import PerGroup
54+
55+
int4_dtype = torch.int4 # pyrefly: ignore[missing-attribute]
56+
return Int8DynamicActivationIntxWeightConfig(
57+
weight_dtype=int4_dtype,
58+
weight_granularity=PerGroup(group_size),
59+
)
60+
61+
elif scheme == "float8_dynamic_act_float8_weight":
62+
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
63+
64+
return Float8DynamicActivationFloat8WeightConfig()
65+
66+
elif scheme == "float8_dynamic_act_int4_weight":
67+
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
68+
69+
return Float8DynamicActivationInt4WeightConfig()
70+
71+
elif scheme == "nvfp4":
72+
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
73+
74+
return NVFP4DynamicActivationNVFP4WeightConfig()
75+
76+
elif scheme == "mx":
77+
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
78+
79+
return MXDynamicActivationMXWeightConfig()
80+
81+
else:
82+
raise ValueError(
83+
f"Unknown QAT scheme '{scheme}'. Supported: {_SUPPORTED_SCHEMES}"
84+
)
85+
86+
87+
class QATConverter(Configurable):
    """Model converter that enables quantization-aware training via torchao.

    ``convert`` calls ``torchao.quantize_(model, QATConfig(base_config,
    step="prepare"))``, which swaps each ``nn.Linear`` for a fake-quantized
    variant. The ``scheme`` config field picks the torchao PTQ base config;
    QATConfig derives the matching fake quantization for weights and
    activations from it.

    Supported schemes:
    - ``"int4_weight_only"`` — int4 weight-only fake quantization
    - ``"intx_weight_only"`` — intx weight-only fake quantization
    - ``"int8_dynamic_act_intx_weight"`` — int8 activation + int4 weight
    - ``"float8_dynamic_act_float8_weight"`` — float8 activation + float8 weight
    - ``"float8_dynamic_act_int4_weight"`` — float8 activation + int4 weight
    - ``"nvfp4"`` — NVFP4 dynamic activation + NVFP4 weight
    - ``"mx"`` — MX dynamic activation + MX weight

    When composed with LoRA (QATConverter listed before LoRAConverter in
    converters), LoRA will inherit from FakeQuantizedLinear so base weights
    are fake-quantized while LoRA adapters stay full-precision.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        scheme: str = "int4_weight_only"
        """QAT scheme name. Maps to a torchao PTQ base config.
        Supported: 'int4_weight_only', 'intx_weight_only',
        'int8_dynamic_act_intx_weight', 'float8_dynamic_act_float8_weight',
        'float8_dynamic_act_int4_weight', 'nvfp4', 'mx'."""

        group_size: int = 256
        """Group size for per-group weight quantization.
        Used by schemes that support per-group granularity
        (int4_weight_only, intx_weight_only, int8_dynamic_act_intx_weight).
        Must divide in_features of all Linear layers in the model."""

    def __init__(self, config: Config, **kwargs):
        # Validate the scheme before recording anything.
        if config.scheme not in _SUPPORTED_SCHEMES:
            raise ValueError(
                f"Unknown QAT scheme '{config.scheme}'. "
                f"Supported: {_SUPPORTED_SCHEMES}"
            )
        if config.scheme not in _SCHEMES_WITH_GROUP_SIZE:
            # group_size is meaningless for this scheme; tell the user it is dropped.
            logger.warning(
                f"QAT scheme '{config.scheme}' does not use group_size, "
                f"ignoring group_size={config.group_size}"
            )
        self.scheme = config.scheme
        self.group_size = config.group_size
        logger.info(
            f"QAT training active (scheme={self.scheme}, group_size={self.group_size})"
        )

    def convert(self, model: nn.Module) -> None:
        """Insert fake-quantization wrappers into the model's Linear modules in place."""
        # torchao is imported lazily so the converter is only a hard dependency
        # when QAT is actually enabled.
        from torchao.quantization import quantize_
        from torchao.quantization.qat import QATConfig
        from torchao.quantization.qat.api import QATStep

        qat_config = QATConfig(
            _build_base_config(self.scheme, self.group_size),
            step=QATStep.PREPARE,
        )
        quantize_(model, qat_config)
        logger.info(
            f"Applied QAT fake quantization (scheme={self.scheme}, "
            f"group_size={self.group_size})"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        """No-op: QAT needs no per-step parameter post-processing."""
        pass

torchtitan/models/llama3/config_registry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OptimizersInBackwardContainer,
1414
)
1515
from torchtitan.components.quantization.float8 import Float8LinearConverter
16+
from torchtitan.components.quantization.qat import QATConverter
1617
from torchtitan.components.validate import Validator
1718
from torchtitan.config import (
1819
ActivationCheckpointConfig,
@@ -130,6 +131,32 @@ def llama3_debugmodel_lora() -> Trainer.Config:
130131
return config
131132

132133

134+
def llama3_debugmodel_qat() -> Trainer.Config:
    """Debug-model recipe with QAT fake quantization enabled (default scheme)."""
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[QATConverter.Config()],
    )
    return config
142+
143+
144+
def llama3_debugmodel_qat_lora() -> Trainer.Config:
    """Debug-model recipe combining QAT base weights with LoRA adapters."""
    config = llama3_debugmodel()
    # QATConverter must come before LoRAConverter so that LoRA inherits from
    # FakeQuantizedLinear, giving fake-quantized base weights + full-precision
    # adapters.
    qat_cfg = QATConverter.Config()
    lora_cfg = LoRAConverter.Config(rank=8, alpha=16.0)
    config.model_converters = ModelConvertersContainer.Config(
        converters=[qat_cfg, lora_cfg],
    )
    return config
158+
159+
133160
def llama3_8b() -> Trainer.Config:
134161
return Trainer.Config(
135162
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments
 (0)