Skip to content

Commit 1e5030a

Browse files
committed
Add QAT (quantization-aware training) model converter
ghstack-source-id: 29c6878 Pull Request resolved: #2488
1 parent 32fc92b commit 1e5030a

File tree

3 files changed

+112
-3
lines changed

3 files changed

+112
-3
lines changed

torchtitan/components/lora.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class LoRAConverter(Configurable):
8484
"""Apply LoRA adapters to all Linear layers in a model."""
8585

8686
@dataclass(kw_only=True, slots=True)
87-
class Config(Configurable.Config):
87+
class LoRAConfig(Configurable.Config):
8888
rank: int = 8
8989
"""Rank of the LoRA matrices (lora_a: in_features x rank, lora_b: rank x out_features)."""
9090

@@ -104,7 +104,10 @@ class Config(Configurable.Config):
104104
"""Scaler block size for NF4 quantization. Default 128 works with debugmodel on 8 GPUs.
105105
The default torchao value (256) may be too large for sharded tensors."""
106106

107-
def __init__(self, config: Config, **kwargs):
107+
# Alias for backwards compatibility
108+
Config = LoRAConfig
109+
110+
def __init__(self, config: LoRAConfig, **kwargs):
108111
self.rank = config.rank
109112
self.alpha = config.alpha
110113
self.save_adapter_only = config.save_adapter_only
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from dataclasses import dataclass
8+
from typing import Literal
9+
10+
import torch
11+
import torch.nn as nn
12+
from torchtitan.config import Configurable
13+
from torchtitan.tools.logging import logger
14+
15+
16+
class QATConverter(Configurable):
    """Replace nn.Linear with FakeQuantizedLinear for quantization-aware training.

    Uses torchao's FakeQuantizedLinear to simulate int4 weight quantization during
    training. The fake quantization is applied in the forward pass so the model
    learns to compensate for quantization error.

    When composed with LoRA (QATConverter listed before LoRAConverter in converters),
    LoRA will inherit from FakeQuantizedLinear so base weights are fake-quantized
    while LoRA adapters stay full-precision.
    """

    @dataclass(kw_only=True, slots=True)
    class QATConfig(Configurable.Config):
        dtype: Literal["int4", "int8"] = "int4"
        """Data type for fake quantization. Supported: 'int4', 'int8'."""

        group_size: int = 256
        """Group size for per-group weight quantization.
        Must divide in_features of all Linear layers in the model."""

    # Alias for backwards compatibility
    Config = QATConfig

    def __init__(self, config: QATConfig, **kwargs):
        """Store validated QAT settings from ``config``.

        Raises:
            ValueError: if ``config.dtype`` is not a supported dtype string, or
                ``config.group_size`` is not positive.
        """
        # Fail fast with a clear message here instead of a bare KeyError later
        # inside convert() when the dtype string is looked up.
        if config.dtype not in ("int4", "int8"):
            raise ValueError(
                f"Unsupported QAT dtype: {config.dtype!r} (expected 'int4' or 'int8')"
            )
        if config.group_size <= 0:
            raise ValueError(
                f"QAT group_size must be positive, got {config.group_size}"
            )
        self.dtype = config.dtype
        self.group_size = config.group_size
        logger.info(
            f"QAT training active (dtype={self.dtype}, group_size={self.group_size})"
        )

    def convert(self, model: nn.Module) -> None:
        """Recursively swap every nn.Linear in ``model`` for FakeQuantizedLinear.

        The replacement modules fake-quantize weights per-group (symmetric) in
        the forward pass; parameters stay full-precision.
        """
        # Deferred import: torchao is only needed when this converter is used.
        from torchao.quantization.qat import FakeQuantizedLinear, IntxFakeQuantizeConfig

        dtype_map = {
            "int4": torch.int4,
            "int8": torch.int8,
        }
        torch_dtype = dtype_map[self.dtype]

        weight_config = IntxFakeQuantizeConfig(
            dtype=torch_dtype,
            group_size=self.group_size,
            is_symmetric=True,
        )

        def _replace_recursive(parent: nn.Module) -> None:
            # Depth-first walk: Linear leaves are swapped on their parent;
            # every other module is descended into. list() snapshots children
            # so replacement during iteration is safe.
            for name, child in list(parent.named_children()):
                if isinstance(child, nn.Linear):
                    fq = FakeQuantizedLinear.from_linear(
                        child, weight_config=weight_config
                    )
                    setattr(parent, name, fq)
                else:
                    _replace_recursive(child)

        _replace_recursive(model)
        logger.info(
            "Swapped to FakeQuantizedLinear layers "
            f"(dtype={self.dtype}, group_size={self.group_size})"
        )

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        """No-op: fake quantization requires no work after the optimizer step."""
        pass

torchtitan/models/llama3/config_registry.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OptimizersInBackwardContainer,
1414
)
1515
from torchtitan.components.quantization.float8 import Float8LinearConverter
16+
from torchtitan.components.quantization.qat import QATConverter
1617
from torchtitan.components.validate import Validator
1718
from torchtitan.config import (
1819
ActivationCheckpointConfig,
@@ -131,7 +132,7 @@ def llama3_debugmodel_lora() -> Trainer.Config:
131132

132133

133134
def llama3_debugmodel_qlora() -> Trainer.Config:
134-
config = llama3_debugmodel_lora()
135+
config = llama3_debugmodel()
135136
config.model_converters = ModelConvertersContainer.Config(
136137
converters=[
137138
LoRAConverter.Config(
@@ -144,6 +145,32 @@ def llama3_debugmodel_qlora() -> Trainer.Config:
144145
return config
145146

146147

148+
def llama3_debugmodel_qat() -> Trainer.Config:
    """Debug-model trainer config with quantization-aware training enabled."""
    config = llama3_debugmodel()
    qat = QATConverter.Config()
    config.model_converters = ModelConvertersContainer.Config(converters=[qat])
    return config
156+
157+
158+
def llama3_debugmodel_qat_lora() -> Trainer.Config:
    """Debug-model trainer config combining QAT with LoRA adapters."""
    config = llama3_debugmodel()
    # QATConverter must come before LoRAConverter so that LoRA inherits from
    # FakeQuantizedLinear, giving fake-quantized base weights + full-precision adapters.
    converters = [
        QATConverter.Config(),
        LoRAConverter.Config(rank=8, alpha=16.0),
    ]
    config.model_converters = ModelConvertersContainer.Config(converters=converters)
    return config
172+
173+
147174
def llama3_8b() -> Trainer.Config:
148175
return Trainer.Config(
149176
hf_assets_path="./assets/hf/Llama-3.1-8B",

0 commit comments

Comments
 (0)