Skip to content

Commit 9501b3d

Browse files
committed
Add QAT (quantization-aware training) model converter
ghstack-source-id: b139735 Pull Request resolved: #2488
1 parent 6b775ca commit 9501b3d

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from torchtitan.components.lora import LoRAConverter
1111
from torchtitan.components.quantization.float8 import Float8LinearConverter
12+
from torchtitan.components.quantization.qat import QATConverter
1213
from torchtitan.config import ConfigManager
1314
from torchtitan.distributed import ParallelDims
1415
from torchtitan.protocols.model_converter import ModelConvertersContainer
@@ -198,3 +199,19 @@ def test_qlora_base_weights_quantized_adapters_full_precision():
198199
assert (
199200
layer.lora_b.weight.dtype == torch.float32
200201
), f"{name}.lora_b.weight should be float32"
202+
203+
204+
def test_qat_preserves_weight_dtype():
    """QAT converter should not change weight dtype (fake quantization happens in forward)."""
    pytest.importorskip("torchao")

    model = SimpleModel()
    # Snapshot each parameter's dtype before conversion.
    original_dtypes = {}
    for name, param in model.named_parameters():
        original_dtypes[name] = param.dtype

    qat = QATConverter(QATConverter.Config(group_size=64))
    qat.convert(model)

    # Every parameter must keep its pre-conversion dtype.
    for name, param in model.named_parameters():
        expected = original_dtypes[name]
        assert (
            param.dtype == expected
        ), f"'{name}' dtype changed from {original_dtypes[name]} to {param.dtype}"
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass
8+
from typing import Literal
9+
10+
import torch
11+
import torch.nn as nn
12+
from torchtitan.config import Configurable
13+
from torchtitan.tools.logging import logger
14+
15+
16+
class QATConverter(Configurable):
    """Replace nn.Linear with FakeQuantizedLinear for quantization-aware training.

    Uses torchao's FakeQuantizedLinear to simulate int4 weight quantization during
    training. The fake quantization is applied in the forward pass so the model
    learns to compensate for quantization error.

    When composed with LoRA (QATConverter listed before LoRAConverter in converters),
    LoRA will inherit from FakeQuantizedLinear so base weights are fake-quantized
    while LoRA adapters stay full-precision.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        dtype: Literal["int4", "int8"] = "int4"
        """Data type for fake quantization. Supported: 'int4', 'int8'."""

        group_size: int = 256
        """Group size for per-group weight quantization.
        Must divide in_features of all Linear layers in the model."""

    def __init__(self, config: Config, **kwargs):
        """Validate and record the QAT settings.

        Raises:
            ValueError: if ``config.dtype`` is not one of the supported values,
                or ``config.group_size`` is not positive. Failing fast here
                gives a clear error instead of a bare ``KeyError`` later in
                :meth:`convert` (the Literal annotation is not enforced at
                runtime, e.g. when the value comes from a config file).
        """
        if config.dtype not in ("int4", "int8"):
            raise ValueError(
                f"Unsupported QAT dtype '{config.dtype}'; expected 'int4' or 'int8'"
            )
        if config.group_size <= 0:
            raise ValueError(
                f"group_size must be a positive integer, got {config.group_size}"
            )
        self.dtype = config.dtype
        self.group_size = config.group_size
        logger.info(
            f"QAT training active (dtype={self.dtype}, group_size={self.group_size})"
        )

    def convert(self, model: nn.Module) -> None:
        """Swap every nn.Linear in ``model`` (in place) for FakeQuantizedLinear.

        torchao is imported lazily so it stays an optional dependency for
        runs that never enable QAT.
        """
        from torchao.quantization.qat import FakeQuantizedLinear, IntxFakeQuantizeConfig
        from torchao.quantization.quant_primitives import TorchAODType

        # int4 has no native torch dtype, so torchao's TorchAODType.INT4
        # stands in for it; int8 maps to the regular torch dtype.
        dtype_map = {
            "int4": TorchAODType.INT4,
            "int8": torch.int8,
        }
        torch_dtype = dtype_map[self.dtype]

        weight_config = IntxFakeQuantizeConfig(
            dtype=torch_dtype,
            group_size=self.group_size,
            is_symmetric=True,
        )

        def _replace_recursive(parent: nn.Module) -> None:
            # Replace direct nn.Linear children; recurse into everything else.
            # NOTE: isinstance also matches nn.Linear subclasses, which is why
            # QATConverter must run before LoRAConverter (see class docstring).
            # list() snapshots children so replacement during iteration is safe.
            for name, child in list(parent.named_children()):
                if isinstance(child, nn.Linear):
                    fq = FakeQuantizedLinear.from_linear(
                        child, weight_config=weight_config
                    )
                    setattr(parent, name, fq)
                else:
                    _replace_recursive(child)

        _replace_recursive(model)
        logger.info(
            "Swapped to FakeQuantizedLinear layers "
            f"(dtype={self.dtype}, group_size={self.group_size})"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        """No-op: fake quantization needs no post-optimizer weight sync."""
        pass

torchtitan/models/llama3/config_registry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OptimizersInBackwardContainer,
1414
)
1515
from torchtitan.components.quantization.float8 import Float8LinearConverter
16+
from torchtitan.components.quantization.qat import QATConverter
1617
from torchtitan.components.validate import Validator
1718
from torchtitan.config import (
1819
ActivationCheckpointConfig,
@@ -144,6 +145,32 @@ def llama3_debugmodel_qlora() -> Trainer.Config:
144145
return config
145146

146147

148+
def llama3_debugmodel_qat() -> Trainer.Config:
    """Debug-model trainer config with QAT enabled (default converter settings)."""
    cfg = llama3_debugmodel()
    qat_config = QATConverter.Config()
    cfg.model_converters = ModelConvertersContainer.Config(converters=[qat_config])
    return cfg
156+
157+
158+
def llama3_debugmodel_qat_lora() -> Trainer.Config:
    """Debug-model trainer config combining QAT with LoRA adapters."""
    cfg = llama3_debugmodel()
    # Ordering matters: QATConverter must come before LoRAConverter so that
    # LoRA inherits from FakeQuantizedLinear, giving fake-quantized base
    # weights + full-precision adapters.
    converter_configs = [
        QATConverter.Config(),
        LoRAConverter.Config(rank=8, alpha=16.0),
    ]
    cfg.model_converters = ModelConvertersContainer.Config(
        converters=converter_configs,
    )
    return cfg
172+
173+
147174
def llama3_8b() -> Trainer.Config:
148175
return Trainer.Config(
149176
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments
 (0)