Skip to content

Commit 732ec7d

Browse files
committed
Add QAT (quantization-aware training) model converter
ghstack-source-id: 2ff472d Pull Request resolved: #2488
1 parent ffe6dba commit 732ec7d

File tree

3 files changed

+122
-0
lines changed

3 files changed

+122
-0
lines changed

tests/unit_tests/test_model_converter.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from torchtitan.components.lora import LoRAConverter
1111
from torchtitan.components.quantization.float8 import Float8LinearConverter
12+
from torchtitan.components.quantization.qat import QATConverter
1213
from torchtitan.config import ConfigManager
1314
from torchtitan.distributed import ParallelDims
1415
from torchtitan.protocols.model_converter import ModelConvertersContainer
@@ -187,3 +188,21 @@ def test_qlora_base_weights_quantized_adapters_full_precision():
187188
assert (
188189
layer.lora_b.weight.dtype == torch.float32
189190
), f"{name}.lora_b.weight should be float32"
191+
192+
193+
def test_qat_preserves_weight_dtype():
    """QAT converter should not change weight dtype (fake quantization happens in forward)."""
    pytest.importorskip("torchao")

    model = SimpleModel()
    # Snapshot every parameter's dtype before conversion so we can compare after.
    expected = {name: param.dtype for name, param in model.named_parameters()}

    QATConverter(QATConverter.Config(group_size=64)).convert(model)

    for name, param in model.named_parameters():
        assert (
            param.dtype == expected[name]
        ), f"'{name}' dtype changed from {expected[name]} to {param.dtype}"
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass
8+
from typing import Literal
9+
10+
import torch
11+
import torch.nn as nn
12+
from torchtitan.config import Configurable
13+
from torchtitan.tools.logging import logger
14+
15+
16+
class QATConverter(Configurable):
    """Replace nn.Linear with FakeQuantizedLinear for quantization-aware training.

    Uses torchao's FakeQuantizedLinear to simulate int4 weight quantization during
    training. The fake quantization is applied in the forward pass so the model
    learns to compensate for quantization error.

    When composed with LoRA (QATConverter listed before LoRAConverter in converters),
    LoRA will inherit from FakeQuantizedLinear so base weights are fake-quantized
    while LoRA adapters stay full-precision.
    """

    # Config dtype string -> torch dtype torchao's fake-quantize config expects.
    _DTYPE_MAP = {
        "int4": torch.int4,
        "int8": torch.int8,
    }

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        dtype: Literal["int4", "int8"] = "int4"
        """Data type for fake quantization. Supported: 'int4', 'int8'."""

        group_size: int = 256
        """Group size for per-group weight quantization.
        Must divide in_features of all Linear layers in the model."""

    def __init__(self, config: Config, **kwargs):
        """Validate the config and record quantization settings.

        Raises:
            ValueError: if ``config.dtype`` is unsupported or ``config.group_size``
                is not positive.
        """
        # Validate eagerly: Literal is not enforced at runtime and configs may
        # come from TOML/CLI, so a bad value should fail here with a clear
        # message instead of a bare KeyError inside convert().
        if config.dtype not in self._DTYPE_MAP:
            raise ValueError(
                f"Unsupported QAT dtype '{config.dtype}'; "
                f"expected one of {sorted(self._DTYPE_MAP)}"
            )
        if config.group_size <= 0:
            raise ValueError(
                f"group_size must be a positive integer, got {config.group_size}"
            )
        self.dtype = config.dtype
        self.group_size = config.group_size
        logger.info(
            f"QAT training active (dtype={self.dtype}, group_size={self.group_size})"
        )

    def convert(self, model: nn.Module) -> None:
        """Swap every ``nn.Linear`` in ``model`` (in place) for a FakeQuantizedLinear.

        Idempotent: layers that are already FakeQuantizedLinear are left alone,
        so calling convert() twice does not double-wrap.

        Raises:
            ValueError: if ``group_size`` does not divide some Linear's in_features
                (per-group quantization groups along the weight's input dimension).
        """
        # Imported lazily so torchao is only required when QAT is enabled.
        from torchao.quantization.qat import FakeQuantizedLinear, IntxFakeQuantizeConfig

        weight_config = IntxFakeQuantizeConfig(
            dtype=self._DTYPE_MAP[self.dtype],
            group_size=self.group_size,
            is_symmetric=True,
        )

        def _replace_recursive(parent: nn.Module, prefix: str) -> None:
            # Replace each eligible child on its parent; recurse into containers.
            for name, child in list(parent.named_children()):
                fqn = f"{prefix}.{name}" if prefix else name
                if isinstance(child, FakeQuantizedLinear):
                    # Already converted (FakeQuantizedLinear subclasses nn.Linear,
                    # so without this guard a second convert() would re-wrap it).
                    continue
                if isinstance(child, nn.Linear):
                    if child.in_features % self.group_size != 0:
                        raise ValueError(
                            f"group_size={self.group_size} does not divide "
                            f"in_features={child.in_features} of Linear '{fqn}'"
                        )
                    fq = FakeQuantizedLinear.from_linear(
                        child, weight_config=weight_config
                    )
                    setattr(parent, name, fq)
                else:
                    _replace_recursive(child, fqn)

        _replace_recursive(model, "")
        logger.info(
            "Swapped to FakeQuantizedLinear layers "
            f"(dtype={self.dtype}, group_size={self.group_size})"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        """No-op: QAT requires no adjustment after the optimizer step."""
        pass

torchtitan/models/llama3/config_registry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OptimizersInBackwardContainer,
1414
)
1515
from torchtitan.components.quantization.float8 import Float8LinearConverter
16+
from torchtitan.components.quantization.qat import QATConverter
1617
from torchtitan.components.validate import Validator
1718
from torchtitan.config import (
1819
ActivationCheckpointConfig,
@@ -144,6 +145,32 @@ def llama3_debugmodel_qlora() -> Trainer.Config:
144145
return config
145146

146147

148+
def llama3_debugmodel_qat() -> Trainer.Config:
    """Debug-model config with QAT (int4 fake quantization) enabled."""
    cfg = llama3_debugmodel()
    cfg.model_converters = ModelConvertersContainer.Config(
        converters=[QATConverter.Config()],
    )
    return cfg
156+
157+
158+
def llama3_debugmodel_qat_lora() -> Trainer.Config:
    """Debug-model config combining QAT with LoRA adapters."""
    cfg = llama3_debugmodel()
    # QATConverter must come before LoRAConverter so that LoRA inherits from
    # FakeQuantizedLinear, giving fake-quantized base weights + full-precision adapters.
    cfg.model_converters = ModelConvertersContainer.Config(
        converters=[
            QATConverter.Config(),
            LoRAConverter.Config(rank=8, alpha=16.0),
        ],
    )
    return cfg
172+
173+
147174
def llama3_8b() -> Trainer.Config:
148175
return Trainer.Config(
149176
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments
 (0)