-
Notifications
You must be signed in to change notification settings - Fork 749
[4/N] Add QAT (quantization-aware training) model converter #2488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/mori360/4/base
Are you sure you want to change the base?
Changes from 11 commits
901c97a
bcc2b2c
858b1d4
828ee0c
0728515
8beee82
a248b79
3ee65be
59029b2
27c4951
4300ee6
1a2826f
1206b0b
c1ad90a
afa28f0
5e623d1
5a39621
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Literal | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from torchtitan.config import Configurable | ||
| from torchtitan.tools.logging import logger | ||
|
|
||
|
|
||
| class QATConverter(Configurable): | ||
| """Replace nn.Linear with FakeQuantizedLinear for quantization-aware training. | ||
|
|
||
| Uses torchao's FakeQuantizedLinear to simulate int4 weight quantization during | ||
| training. The fake quantization is applied in the forward pass so the model | ||
| learns to compensate for quantization error. | ||
|
|
||
| When composed with LoRA (QATConverter listed before LoRAConverter in converters), | ||
| LoRA will inherit from FakeQuantizedLinear so base weights are fake-quantized | ||
| while LoRA adapters stay full-precision. | ||
| """ | ||
|
|
||
| @dataclass(kw_only=True, slots=True) | ||
| class Config(Configurable.Config): | ||
| dtype: Literal["int4", "int8"] = "int4" | ||
| """Data type for fake quantization. Supported: 'int4', 'int8'.""" | ||
|
|
||
| group_size: int = 256 | ||
|
||
| """Group size for per-group weight quantization. | ||
| Must divide in_features of all Linear layers in the model.""" | ||
|
|
||
| def __init__(self, config: Config, **kwargs): | ||
| self.dtype = config.dtype | ||
| self.group_size = config.group_size | ||
| logger.info( | ||
| f"QAT training active (dtype={self.dtype}, group_size={self.group_size})" | ||
| ) | ||
|
|
||
| def convert(self, model: nn.Module) -> None: | ||
| from torchao.quantization.qat import FakeQuantizedLinear, IntxFakeQuantizeConfig | ||
| from torchao.quantization.quant_primitives import TorchAODType | ||
|
|
||
| dtype_map = { | ||
| "int4": TorchAODType.INT4, | ||
|
||
| "int8": torch.int8, | ||
| } | ||
| torch_dtype = dtype_map[self.dtype] | ||
|
|
||
| weight_config = IntxFakeQuantizeConfig( | ||
| dtype=torch_dtype, | ||
| group_size=self.group_size, | ||
| is_symmetric=True, | ||
| ) | ||
|
|
||
| def _replace_recursive(parent: nn.Module) -> None: | ||
| for name, child in list(parent.named_children()): | ||
| if isinstance(child, nn.Linear): | ||
| fq = FakeQuantizedLinear.from_linear( | ||
|
||
| child, weight_config=weight_config | ||
| ) | ||
| setattr(parent, name, fq) | ||
| else: | ||
| _replace_recursive(child) | ||
|
|
||
| _replace_recursive(model) | ||
| logger.info( | ||
| "Swapped to FakeQuantizedLinear layers " | ||
| f"(dtype={self.dtype}, group_size={self.group_size})" | ||
| ) | ||
|
|
||
| def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None: | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| OptimizersInBackwardContainer, | ||
| ) | ||
| from torchtitan.components.quantization.float8 import Float8LinearConverter | ||
| from torchtitan.components.quantization.qat import QATConverter | ||
| from torchtitan.components.validate import Validator | ||
| from torchtitan.config import ( | ||
| ActivationCheckpointConfig, | ||
|
|
@@ -144,6 +145,32 @@ def llama3_debugmodel_qlora() -> Trainer.Config: | |
| return config | ||
|
|
||
|
|
||
| def llama3_debugmodel_qat() -> Trainer.Config: | ||
| config = llama3_debugmodel() | ||
| config.model_converters = ModelConvertersContainer.Config( | ||
| converters=[ | ||
| QATConverter.Config(), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to the comment in [1/N]: order matters |
||
| ], | ||
| ) | ||
| return config | ||
|
|
||
|
|
||
| def llama3_debugmodel_qat_lora() -> Trainer.Config: | ||
| config = llama3_debugmodel() | ||
| # QATConverter must come before LoRAConverter so that LoRA inherits from | ||
| # FakeQuantizedLinear, giving fake-quantized base weights + full-precision adapters. | ||
| config.model_converters = ModelConvertersContainer.Config( | ||
| converters=[ | ||
| QATConverter.Config(), | ||
| LoRAConverter.Config( | ||
| rank=8, | ||
| alpha=16.0, | ||
| ), | ||
| ], | ||
| ) | ||
| return config | ||
|
|
||
|
|
||
| def llama3_8b() -> Trainer.Config: | ||
| return Trainer.Config( | ||
| hf_assets_path="./assets/hf/Llama-3.1-8B", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,20 +87,21 @@ def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): | |
| def _validate_converter_ordering(converters: list[Configurable.Config]): | ||
| """Validates that converters are in the correct order. | ||
|
|
||
| LoRA must come after quantization because quantization replaces nn.Linear | ||
| with specialized subclasses (e.g. Float8Linear), and LoRA dynamically | ||
| inherits from whatever linear class it wraps. | ||
| LoRA must come after quantization and QAT because both replace nn.Linear | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Currently we follow: Linear -> QAT(Linear) -> LoRA(QAT), which only adds adapters to the QAT Linear, so the current design keeps the LoRA adapters in the same dtype as the original linear. |
||
| with specialized subclasses (e.g. Float8Linear, FakeQuantizedLinear), and | ||
| LoRA dynamically inherits from whatever linear class it wraps. | ||
| """ | ||
| from torchtitan.components.lora import LoRAConverter | ||
| from torchtitan.components.quantization.qat import QATConverter | ||
|
|
||
| seen_lora = False | ||
| for config in converters: | ||
| if isinstance(config, LoRAConverter.Config): | ||
| seen_lora = True | ||
| elif isinstance(config, QuantizationConverter.Config) and seen_lora: | ||
| elif isinstance(config, (QuantizationConverter.Config, QATConverter.Config)) and seen_lora: | ||
| raise ValueError( | ||
| "LoRA converter must come after quantization converters. " | ||
| "Quantization replaces nn.Linear with specialized subclasses, " | ||
| "LoRA converter must come after quantization and QAT converters. " | ||
| "Quantization/QAT replaces nn.Linear with specialized subclasses, " | ||
| "and LoRA must wrap the final linear class." | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we support activation quantization too? There are pre-set QAT configs in torchao that do this, e.g.
Int8DynamicActivationIntxWeightConfig (used primarily for edge) can be passed to QATConfig as a base config (docs here, a bit outdated). Another example that I think will be important is
NVFP4DynamicActivationNVFP4WeightConfig. I don't think a single dtype field will be sufficient in capturing that