[4/N] Add QAT (quantization-aware training) model converter #2488
mori360 wants to merge 17 commits into gh/mori360/4/base
Conversation
- Add QATConverter as a standalone model converter (`torchtitan/components/quantization/qat.py`) that replaces `nn.Linear` with torchao's `FakeQuantizedLinear` to simulate int4/int8 weight quantization during training
- QAT composes naturally with LoRA via the converters list: when QATConverter runs before LoRAConverter, LoRA dynamically inherits from `FakeQuantizedLinear`, so base weights are fake-quantized while LoRA adapters stay full precision
- Add `llama3_debugmodel_qat` (QAT-only) and `llama3_debugmodel_qat_lora` (QAT + LoRA composed) debug configs
### Summary
- Add QATConverter that replaces `nn.Linear` with torchao's `FakeQuantizedLinear` for quantization-aware training
- Fake quantization (int4/int8) is applied in the forward pass so the model learns to compensate for quantization error, while stored weights remain in full precision
- When composed with LoRA (QATConverter listed before LoRAConverter in `converters`), LoRA inherits from `FakeQuantizedLinear`, so base weights are fake-quantized while LoRA adapters stay full precision
- Add `llama3_debugmodel_qat` and `llama3_debugmodel_qat_lora` debug configs

### Test Plan
- `test_qat_preserves_weight_dtype`: verifies all parameter dtypes remain unchanged after QAT conversion
```python
config = llama3_debugmodel()
config.model_converters = ModelConvertersContainer.Config(
    converters=[
        QATConverter.Config(),
```
similar to the comment in [1/N]: order matters
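To make the "order matters" point concrete, here is a minimal sketch of the kind of ordering check the converters container could enforce. The function name and the string keys are illustrative, not the PR's actual API:

```python
# Hypothetical sketch: LoRA must come after QAT (and other quantization
# converters) in the converters list, since LoRA dynamically inherits from
# whatever linear class it wraps. Names are illustrative.

def validate_converter_order(converters: list[str]) -> None:
    if "qat" in converters and "lora" in converters:
        if converters.index("lora") < converters.index("qat"):
            raise ValueError("LoRAConverter must come after QATConverter")

validate_converter_order(["qat", "lora"])  # valid ordering passes silently
```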
```python
from torchao.quantization.quant_primitives import TorchAODType

dtype_map = {
    "int4": TorchAODType.INT4,
```
`TorchAODType` was only needed for torch 2.5 and earlier (which didn't have `torch.int4`); I think you can now just use `torch.int4` directly here.
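The suggestion above could look like the following sketch. The `getattr` fallback is a hedge for older torch builds that do not expose `torch.int4`:

```python
import torch

# Sketch of the reviewer's suggestion: map scheme names to torch dtypes
# directly instead of going through TorchAODType. getattr guards against
# older torch builds (<= 2.5) that lack torch.int4.
dtype_map = {
    "int4": getattr(torch, "int4", None),  # None on older torch builds
    "int8": torch.int8,
}
```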
```python
@dataclass(kw_only=True, slots=True)
class Config(Configurable.Config):
    dtype: Literal["int4", "int8"] = "int4"
```
Should we support activation quantization too? There are preset QAT configs in torchao that do this, e.g. `Int8DynamicActivationIntxWeightConfig` (used primarily for edge) can be passed to `QATConfig` as a base config (docs here, a bit outdated).
Another example that I think will be important is `NVFP4DynamicActivationNVFP4WeightConfig`. I don't think a single `dtype` field will be sufficient to capture that.
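One hedged way to go beyond a single `dtype` field is a scheme registry that maps names to base-config builders. In the real converter each entry would construct the corresponding torchao base config (e.g. `Int8DynamicActivationIntxWeightConfig`); plain strings stand in here so the sketch is self-contained:

```python
# Hypothetical sketch: a scheme registry instead of a single `dtype` field, so
# activation+weight schemes fit alongside weight-only ones. Scheme names and
# the string values are illustrative stand-ins for torchao config classes.
_BASE_CONFIGS = {
    "int4_weight_only": "Int4WeightOnlyConfig",
    "int8_dynamic_act_intx_weight": "Int8DynamicActivationIntxWeightConfig",
    "nvfp4": "NVFP4DynamicActivationNVFP4WeightConfig",
}

def build_base_config(scheme: str) -> str:
    # Fail loudly on unknown schemes instead of silently defaulting.
    if scheme not in _BASE_CONFIGS:
        raise ValueError(f"unknown QAT scheme: {scheme!r}")
    return _BASE_CONFIGS[scheme]
```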
```python
def _replace_recursive(parent: nn.Module) -> None:
    for name, child in list(parent.named_children()):
        if isinstance(child, nn.Linear):
            fq = FakeQuantizedLinear.from_linear(
```
Just for future reference, there's also `NVFP4FakeQuantizedLinear`, which is not a subclass of this. We can add support for it later, but I wanted to bring it up early so we're not baked into this specific linear class.
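For illustration, here is a minimal sketch of the recursive-replacement pattern the snippet above uses, with a stub class standing in for torchao's `FakeQuantizedLinear` (the stub and function names are hypothetical, not the PR's code):

```python
import torch.nn as nn

class FakeQuantLinearStub(nn.Linear):
    """Stand-in for torchao's FakeQuantizedLinear (illustrative only)."""

    @classmethod
    def from_linear(cls, lin: nn.Linear) -> "FakeQuantLinearStub":
        new = cls(lin.in_features, lin.out_features, bias=lin.bias is not None)
        new.weight = lin.weight  # share parameters; real code also carries qconfig
        if lin.bias is not None:
            new.bias = lin.bias
        return new

def replace_recursive(parent: nn.Module) -> None:
    # Recursively swap plain nn.Linear modules for the fake-quantized variant,
    # skipping modules that were already converted.
    for name, child in list(parent.named_children()):
        if isinstance(child, nn.Linear) and not isinstance(child, FakeQuantLinearStub):
            setattr(parent, name, FakeQuantLinearStub.from_linear(child))
        else:
            replace_recursive(child)
```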
```python
    dtype: Literal["int4", "int8"] = "int4"
    """Data type for fake quantization. Supported: 'int4', 'int8'."""

    group_size: int = 256
```
By the way, not all quantization configs have `group_size`; you can see the list of supported QAT configs here.
```diff
-LoRA must come after quantization because quantization replaces nn.Linear
-with specialized subclasses (e.g. Float8Linear), and LoRA dynamically
-inherits from whatever linear class it wraps.
+LoRA must come after quantization and QAT because both replace nn.Linear
```
Does the LoRAConverter already handle the case when the inner linear is FakeQuantizedLinear? I think additionally we will need to make the lora adapters also fake quantized (according to the same configs, e.g. int4). Does this part exist yet?
Currently we follow `Linear -> QAT(Linear) -> LoRA(QAT)`, which only adds adapters on top of the QAT linear, so the current design keeps the LoRA adapters in the same dtype as the original linear.
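The dynamic-inheritance mechanism this thread keeps referring to can be sketched without torch at all: LoRA derives its wrapper class from whatever linear class it receives, so when QAT runs first the wrapper subclasses the fake-quantized linear, while the adapter parameters themselves stay in the original dtype. All class and function names below are illustrative:

```python
# Hypothetical sketch of dynamic inheritance: the LoRA wrapper class is built
# from whatever linear class it wraps, so ordering QAT before LoRA makes the
# wrapper a subclass of the fake-quantized linear.

def make_lora_class(base_cls: type) -> type:
    class LoRALinear(base_cls):
        pass

    LoRALinear.__name__ = f"LoRA{base_cls.__name__}"
    return LoRALinear

class PlainLinear: ...
class FakeQuantizedLinear(PlainLinear): ...

lora_cls = make_lora_class(FakeQuantizedLinear)
assert issubclass(lora_cls, FakeQuantizedLinear)
```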
### Summary
Add a QATConverter that uses torchao's QATConfig to insert fake quantization into `nn.Linear` modules during training. This enables quantization-aware training as a first-class model converter in torchtitan. QAT composes with LoRA by converter ordering: QAT first replaces `nn.Linear` → `FakeQuantizedLinear`, then LoRA inherits from `FakeQuantizedLinear`, giving fake-quantized base weights with full-precision adapters.

- `torchtitan/components/quantization/qat.py` — New QATConverter with support for 7 torchao quantization schemes: `int4_weight_only`, `intx_weight_only`, `int8_dynamic_act_intx_weight`, `float8_dynamic_act_float8_weight`, `float8_dynamic_act_int4_weight`, `nvfp4`, and `mx`. A shared `_build_base_config()` helper maps scheme names to torchao PTQ base configs.
- `torchtitan/components/lora.py` — Add `adapter_qat_scheme` and `adapter_qat_group_size` config fields to LoRAConverter, enabling fake quantization of LoRA adapter weights (`lora_a`/`lora_b`) independently of base model quantization. Validates that `rank % group_size == 0` for schemes that use per-group granularity.
- `torchtitan/protocols/model_converter.py` — Extend converter ordering validation to enforce that QATConverter (like quantization converters) must come before LoRAConverter, since QAT replaces `nn.Linear` with `FakeQuantizedLinear` and LoRA must wrap the final linear class.
- `torchtitan/models/llama3/config_registry.py` — Add `llama3_debugmodel_qat` and `llama3_debugmodel_qat_lora` debug configs for testing QAT alone and QAT+LoRA composition.
- `tests/unit_tests/test_model_converter.py` — Add tests for dtype preservation, all 7 QAT schemes, unknown-scheme errors, and the `group_size` warning for unsupported schemes.

### Test Plan
- `test_qat_preserves_weight_dtype`: verifies all parameter dtypes remain unchanged after QAT conversion