
[4/N] Add QAT (quantization-aware training) model converter#2488

Draft
mori360 wants to merge 17 commits into gh/mori360/4/base from gh/mori360/4/head

Conversation


@mori360 mori360 commented Mar 5, 2026

Summary

Add a QATConverter that uses torchao's QATConfig to insert fake quantization into nn.Linear modules during training. This enables quantization-aware training as a first-class model converter in torchtitan. QAT composes with LoRA via converter ordering: QAT first replaces nn.Linear → FakeQuantizedLinear, then LoRA inherits from FakeQuantizedLinear, giving fake-quantized base weights with full-precision adapters.

  • torchtitan/components/quantization/qat.py — New QATConverter with support for 7 torchao quantization schemes: int4_weight_only, intx_weight_only, int8_dynamic_act_intx_weight, float8_dynamic_act_float8_weight, float8_dynamic_act_int4_weight, nvfp4, and mx. A shared _build_base_config() helper maps scheme names to torchao PTQ base configs.
  • torchtitan/components/lora.py — Add adapter_qat_scheme and adapter_qat_group_size config fields to LoRAConverter, enabling fake quantization of LoRA adapter weights (lora_a/lora_b) independently of base model quantization. Validates that rank % group_size == 0 for schemes that use per-group granularity.
  • torchtitan/protocols/model_converter.py — Extend converter ordering validation to enforce that QATConverter (like quantization converters) must come before LoRAConverter, since QAT replaces nn.Linear with FakeQuantizedLinear and LoRA must wrap the final linear class.
  • torchtitan/models/llama3/config_registry.py — Add llama3_debugmodel_qat and llama3_debugmodel_qat_lora debug configs for testing QAT alone and QAT+LoRA composition.
  • tests/unit_tests/test_model_converter.py — Add tests for dtype preservation, all 7 QAT schemes, unknown scheme error, and group_size warning for unsupported schemes.
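The ordering constraint described in the model_converter.py bullet can be sketched in plain Python. This is an illustrative sketch only — the function and converter names here are hypothetical, not torchtitan's actual API:

```python
# Sketch of the ordering rule: QAT/quantization converters must precede LoRA,
# because LoRA subclasses whatever linear class earlier converters left behind.
QUANT_LIKE = {"qat", "float8", "mx"}  # hypothetical converter names

def validate_converter_order(converters: list[str]) -> None:
    seen_lora = False
    for name in converters:
        if name == "lora":
            seen_lora = True
        elif name in QUANT_LIKE and seen_lora:
            # a quantization-style converter after LoRA would replace the
            # linear class out from under the already-built LoRA wrapper
            raise ValueError(f"{name!r} must come before 'lora' in converters")

validate_converter_order(["qat", "lora"])  # ok: QAT first, LoRA wraps it
```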

Test Plan

  • test_qat_preserves_weight_dtype: verifies all parameter dtypes remain unchanged after QAT conversion

Stack from ghstack (oldest at bottom):

@meta-cla meta-cla bot added the CLA Signed label Mar 5, 2026
mori360 added a commit that referenced this pull request Mar 5, 2026
ghstack-source-id: 29c6878
Pull Request resolved: #2488
@mori360 mori360 changed the title Add QAT (quantization-aware training) model converter [4/N] Add QAT (quantization-aware training) model converter Mar 5, 2026
@mori360 mori360 marked this pull request as draft March 5, 2026 01:18
  - Add QATConverter as a standalone model converter (torchtitan/components/quantization/qat.py) that replaces nn.Linear with torchao's FakeQuantizedLinear to simulate int4/int8 weight quantization during training
  - QAT composes naturally with LoRA via the converters list: when QATConverter runs before LoRAConverter, LoRA dynamically inherits from FakeQuantizedLinear so base weights are fake-quantized while LoRA adapters stay full-precision
  - Add llama3_debugmodel_qat (QAT-only) and llama3_debugmodel_qat_lora (QAT + LoRA composed) debug configs




[ghstack-poisoned]
mori360 added a commit that referenced this pull request Mar 5, 2026
ghstack-source-id: 0f213e5
Pull Request resolved: #2488
mori360 added a commit that referenced this pull request Mar 5, 2026
ghstack-source-id: 2ff472d
Pull Request resolved: #2488
mori360 added a commit that referenced this pull request Mar 5, 2026
ghstack-source-id: d7e680b
Pull Request resolved: #2488
### Summary

  - Add QATConverter that replaces nn.Linear with torchao's FakeQuantizedLinear for quantization-aware training
  - Fake quantization (int4/int8) is applied in the forward pass so the model learns to compensate for quantization error, while stored weights remain in full precision
  - When composed with LoRA (QATConverter listed before LoRAConverter in converters), LoRA inherits from FakeQuantizedLinear so base weights are fake-quantized while LoRA adapters stay full-precision
  - Add llama3_debugmodel_qat and llama3_debugmodel_qat_lora debug configs

### Test Plan

- test_qat_preserves_weight_dtype: verifies all parameter dtypes remain unchanged after QAT conversion
[ghstack-poisoned]
mori360 added a commit that referenced this pull request Mar 6, 2026
ghstack-source-id: 1aa91c6
Pull Request resolved: #2488
mori360 added a commit that referenced this pull request Mar 6, 2026
ghstack-source-id: 3c51ad0
Pull Request resolved: #2488
config = llama3_debugmodel()
config.model_converters = ModelConvertersContainer.Config(
    converters=[
        QATConverter.Config(),
similar to the comment in [1/N]: order matters
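The ordering this comment refers to, sketched out as a full config fragment — hedged: the `LoRAConverter.Config()` usage here is assumed from the rest of the stack, not shown in this hunk:

```python
config.model_converters = ModelConvertersContainer.Config(
    converters=[
        QATConverter.Config(),   # first: replaces nn.Linear with FakeQuantizedLinear
        LoRAConverter.Config(),  # second: wraps the fake-quantized linear class
    ]
)
```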

@tianyu-l tianyu-l requested a review from andrewor14 March 8, 2026 22:43
mori360 added a commit that referenced this pull request Mar 12, 2026
ghstack-source-id: b139735
Pull Request resolved: #2488
mori360 added a commit that referenced this pull request Mar 12, 2026
ghstack-source-id: 1d82b1d
Pull Request resolved: #2488
mori360 added a commit that referenced this pull request Mar 12, 2026
ghstack-source-id: 6ff4bf8
Pull Request resolved: #2488
mori360 added a commit that referenced this pull request Mar 12, 2026
ghstack-source-id: bbdd2fc
Pull Request resolved: #2488
@mori360 mori360 mentioned this pull request Mar 13, 2026
@andrewor14 andrewor14 left a comment

Looks great!

from torchao.quantization.quant_primitives import TorchAODType

dtype_map = {
    "int4": TorchAODType.INT4,

This was only for torch 2.5 and before (which didn't have torch.int4), I think now you can just use torch.int4 directly here


@dataclass(kw_only=True, slots=True)
class Config(Configurable.Config):
    dtype: Literal["int4", "int8"] = "int4"

Should we support activation quantization too? There are pre-set QAT configs in torchao that do this, e.g. Int8DynamicActivationIntxWeightConfig (used primarily for edge) can be passed to QATConfig as a base config (docs here, a bit outdated).

Another example that I think will be important is NVFP4DynamicActivationNVFP4WeightConfig. I don't think a single dtype field will be sufficient in capturing that

def _replace_recursive(parent: nn.Module) -> None:
    for name, child in list(parent.named_children()):
        if isinstance(child, nn.Linear):
            fq = FakeQuantizedLinear.from_linear(

Just for future reference but there's also NVFP4FakeQuantizedLinear that's not a subclass of this. We can add support for this later but just wanted to bring it up early so we're not baked into this specific linear class

dtype: Literal["int4", "int8"] = "int4"
"""Data type for fake quantization. Supported: 'int4', 'int8'."""

group_size: int = 256

By the way not all quantization configs have group_size, you can see the list of supported QAT configs here
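Relatedly, the PR summary mentions a rank % group_size == 0 check for schemes with per-group granularity. A minimal sketch of such a check — the helper name and the set of per-group schemes here are illustrative, not the actual torchtitan code:

```python
# Hypothetical check mirroring the per-group-granularity validation the PR
# summary describes; scheme names are drawn from the 7 schemes listed there.
PER_GROUP_SCHEMES = {
    "int4_weight_only",
    "intx_weight_only",
    "int8_dynamic_act_intx_weight",
}

def check_adapter_group_size(scheme: str, rank: int, group_size: int) -> None:
    if scheme not in PER_GROUP_SCHEMES:
        return  # the real code warns that group_size is ignored for such schemes
    if rank % group_size != 0:
        raise ValueError(
            f"LoRA rank {rank} is not divisible by group_size {group_size} "
            f"required by per-group scheme {scheme!r}"
        )

check_adapter_group_size("int4_weight_only", rank=256, group_size=64)  # ok
```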

LoRA must come after quantization because quantization replaces nn.Linear
with specialized subclasses (e.g. Float8Linear), and LoRA dynamically
inherits from whatever linear class it wraps.
LoRA must come after quantization and QAT because both replace nn.Linear

Does the LoRAConverter already handle the case when the inner linear is FakeQuantizedLinear? I think additionally we will need to make the lora adapters also fake quantized (according to the same configs, e.g. int4). Does this part exist yet?

@mori360 (author) replied:

Currently we follow: Linear -> QAT(Linear) -> LoRA(QAT), which only adds adapters to the QAT linear, so the current design keeps the LoRA adapters in the same dtype as the original linear.
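The Linear -> QAT(Linear) -> LoRA(QAT) chain described above can be illustrated with plain-Python stand-ins. The classes below are dummies (not nn.Linear or torchao's FakeQuantizedLinear); the `type()` call mirrors how the LoRA converter dynamically subclasses whatever linear class it finds:

```python
class Linear:                        # stand-in for nn.Linear
    pass

class FakeQuantizedLinear(Linear):   # stand-in for torchao's FakeQuantizedLinear
    pass

def make_lora_class(base_cls):
    """LoRA dynamically subclasses whatever linear class it wraps."""
    return type(f"LoRA{base_cls.__name__}", (base_cls,), {})

# QAT runs first, so LoRA ends up wrapping FakeQuantizedLinear:
# base weights are fake-quantized, adapters ride on top in full precision.
LoRALinear = make_lora_class(FakeQuantizedLinear)
assert issubclass(LoRALinear, FakeQuantizedLinear)
```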

mori360 added 6 commits March 13, 2026 14:35

Labels

ciflow/8gpu, CLA Signed


3 participants