[1/N] Add lora adapters for finetune #2484
Conversation
### Summary
To support LoRA (Low-Rank Adaptation) in torchtitan, this PR is the first step: it adds LoRA adapters into the model and enables LoRA finetuning.
LoRALinear - A dynamically generated Linear layer subclass that adds trainable low-rank adapters (lora_a, lora_b) while keeping base weights frozen. Uses the decomposition output = base_output + (α/r) * B(A(x)), where r is the rank and α is the scaling factor. Compatible with any nn.Linear subclass (including Float8Linear).
LoRAConverter - A model converter that:
- Recursively replaces all nn.Linear layers with their LoRA equivalents
- Freezes base model weights (requires_grad=False)
- Overrides init_weights() to properly initialize LoRA adapters on the target device
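The decomposition above can be illustrated with a minimal sketch. The names `lora_a`, `lora_b`, `rank`, and `alpha` follow the PR description, but this fixed subclass is only an illustration; the actual implementation generates subclasses dynamically (see Details below):

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Linear):
    """Illustrative LoRA layer: output = base_output + (alpha/rank) * B(A(x))."""

    def init_lora(self, rank: int = 8, alpha: float = 16.0) -> None:
        self.rank, self.alpha = rank, alpha
        self.lora_a = nn.Linear(self.in_features, rank, bias=False)   # A: down-projection
        self.lora_b = nn.Linear(rank, self.out_features, bias=False)  # B: up-projection
        nn.init.zeros_(self.lora_b.weight)  # zero-init B => starts identical to base model
        self.weight.requires_grad_(False)   # freeze base weights
        if self.bias is not None:
            self.bias.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = super().forward(x)
        return base_out + (self.alpha / self.rank) * self.lora_b(self.lora_a(x))
```

With `lora_b` zero-initialized, the adapted layer initially produces exactly the base output, so finetuning starts from the pretrained behavior.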
### Details
1. Enable LoRA via ModelConverter
LoRA is integrated using the existing ModelConverter protocol, making it consistent with other model transformations (e.g., quantization). Users can enable LoRA by adding "lora" to the model.converters config:
```
[model]
converters = ["lora"]
[lora]
rank = 8
alpha = 16.0
dropout = 0.0
```
2. In-place Linear Replacement with LoRALinear Subclass
The implementation uses dynamic class creation via create_lora_linear() to wrap any nn.Linear subclass with LoRA adapters:
- Dynamic subclassing: `create_lora_linear(parent_cls)` creates a new class that inherits from the original Linear class, preserving custom forward logic (e.g., RowwiseParallel, ColwiseParallel)
- In-place modification: instead of creating new modules, we create a subclass in place, preserving existing weights
- Frozen base weights: all base model parameters are frozen (requires_grad=False); only LoRA adapters are trainable
- Composability with quantization: float8 low-precision quantization is already available as another converter. Conversions are applied in order: first to `Float8Linear`, then the LoRA converter turns `Float8Linear` into `LoRAFloat8Linear` via the `create_lora_linear` mechanism described above (i.e., `converters = ["quantize.linear.float8", "lora"]`)

The LoRA architecture is quite simple: it only adds lora_out to base_out. As long as the subclass computes `base_out = super().forward(x)`, it doesn't touch the parent forward, so the two converters compose cleanly.
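A hedged sketch of the dynamic-subclassing idea follows; the helper names, cache, and `apply_lora_` entry point are illustrative assumptions, not the exact torchtitan API:

```python
import torch
import torch.nn as nn

_lora_class_cache: dict[type, type] = {}


def create_lora_linear(parent_cls: type[nn.Linear]) -> type[nn.Linear]:
    """Build (and cache) a LoRA subclass of an arbitrary nn.Linear subclass."""
    if parent_cls in _lora_class_cache:
        return _lora_class_cache[parent_cls]

    class LoRALinear(parent_cls):  # type: ignore[valid-type, misc]
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            base_out = super().forward(x)  # parent forward logic untouched
            return base_out + (self.alpha / self.rank) * self.lora_b(self.lora_a(x))

    LoRALinear.__name__ = f"LoRA{parent_cls.__name__}"
    _lora_class_cache[parent_cls] = LoRALinear
    return LoRALinear


def apply_lora_(module: nn.Linear, rank: int = 8, alpha: float = 16.0) -> None:
    """Swap the module's class in place, keeping its existing (frozen) weights."""
    module.__class__ = create_lora_linear(type(module))
    module.rank, module.alpha = rank, alpha
    module.lora_a = nn.Linear(module.in_features, rank, bias=False)
    module.lora_b = nn.Linear(rank, module.out_features, bias=False)
    nn.init.zeros_(module.lora_b.weight)
    module.weight.requires_grad_(False)
```

Because only `super().forward(x)` is called for the base path, the same mechanism wraps `Float8Linear` (or any other Linear subclass) without knowing its internals.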
### Test Plan
Tested with the Llama3 debug model on 8 GPUs:
- LoRA has lower memory cost
- LoRA works with float8

Without the LoRA converter:
```
[rank0]:[titan] 2026-02-02 14:41:32,656 - root - INFO - step: 1 loss: 8.2000 grad_norm: 1.3955 memory: 1.01GiB(1.07%) tps: 5,479 tflops: 0.39 mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:41:33,139 - root - INFO - step: 5 loss: 5.3874 grad_norm: 2.3446 memory: 1.06GiB(1.11%) tps: 343,441 tflops: 24.59 mfu: 2.49%
[rank0]:[titan] 2026-02-02 14:41:33,396 - root - INFO - step: 10 loss: 4.1315 grad_norm: 1.8140 memory: 1.06GiB(1.11%) tps: 320,740 tflops: 22.96 mfu: 2.32%
```
With the LoRA converter:
```
[rank0]:[titan] 2026-02-02 14:42:44,379 - root - INFO - step: 1 loss: 8.1225 grad_norm: 0.3293 memory: 0.98GiB(1.03%) tps: 6,222 tflops: 0.46 mfu: 0.05%
[rank0]:[titan] 2026-02-02 14:42:44,687 - root - INFO - step: 5 loss: 8.0054 grad_norm: 0.5108 memory: 1.02GiB(1.07%) tps: 274,052 tflops: 20.05 mfu: 2.03%
[rank0]:[titan] 2026-02-02 14:42:44,988 - root - INFO - step: 10 loss: 7.8191 grad_norm: 0.7809 memory: 1.02GiB(1.07%) tps: 266,976 tflops: 19.53 mfu: 1.98%
```
With the LoRA converter + float8:
```
[rank0]:[titan] 2026-02-02 14:43:38,608 - root - INFO - step: 1 loss: 8.1689 grad_norm: 0.2351 memory: 1.00GiB(1.05%) tps: 5,067 tflops: 0.37 mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:43:39,152 - root - INFO - step: 5 loss: 8.0590 grad_norm: 0.3869 memory: 1.04GiB(1.10%) tps: 134,235 tflops: 9.82 mfu: 0.99%
[rank0]:[titan] 2026-02-02 14:43:39,792 - root - INFO - step: 10 loss: 7.9155 grad_norm: 0.6083 memory: 1.04GiB(1.10%) tps: 122,904 tflops: 8.99 mfu: 0.91%
```
### Future work
There are some TODOs not included in this PR that are left as future work:
1. Composability with TP: the current plan does not work with TP. Here are two approaches we may try later, which need further discussion:
- Keep the LoRA adapters as Replicate so they won't be sharded by TP
- Generate a layer_plan for the LoRA adapters so they are processed by TP as well
2. Support supervised finetuning
torchtitan/components/lora.py (outdated)
```
if original_init_weights is not None and callable(original_init_weights):
    original_init_weights(*args, **kwargs)
for sub_module in module.modules():
    if type(sub_module) in _lora_class_cache.values():
```
It seems we need this because nn.Linear doesn't have init_weights; instead, in torchtitan we directly operate on the weight and bias of any nn.Linear submodule.
I believe @fegin will refactor this to create our own Linear module with init_weights capability. So let's leave a TODO here to remove this new_model_init_weights.
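The patching pattern under discussion might look roughly like this. This is a sketch only: `patch_init_weights` and the adapter re-init choices are assumptions, and the quoted code checks membership in `_lora_class_cache.values()` where this sketch uses `isinstance`:

```python
import torch.nn as nn


def patch_init_weights(model: nn.Module, lora_classes: tuple[type, ...]) -> None:
    original_init_weights = getattr(model, "init_weights", None)

    def new_model_init_weights(*args, **kwargs) -> None:
        # Run the model's own init first (it may touch base weights/biases)...
        if callable(original_init_weights):
            original_init_weights(*args, **kwargs)
        # ...then (re)initialize the LoRA adapters the base init knows nothing about.
        for sub_module in model.modules():
            if isinstance(sub_module, lora_classes):
                nn.init.kaiming_uniform_(sub_module.lora_a.weight)
                nn.init.zeros_(sub_module.lora_b.weight)

    model.init_weights = new_model_init_weights
```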
```
def llama3_debugmodel_lora() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
```
Among all the converters (quantization, LoRA, etc.), order matters. For now let's create a `__post_init__` in `ModelConvertersContainer.Config` to validate the order.
Note that it can't catch a wrong order if we first init the main config and then modify a field.
@fegin Validation is becoming an issue. If we don't go with chz, maybe we should at least freeze the config and require users to use dataclass.replace to trigger post-init checks.
### Summary
- Add LoRAConverter model converter that applies low-rank adaptation (LoRA) to all nn.Linear layers in a model
- LoRA uses dynamic subclass creation (`__class__` swap) to wrap existing Linear layers, preserving compatibility with any nn.Linear subclass (e.g., FakeQuantizedLinear)
- Base model weights are frozen (requires_grad=False); only LoRA adapter weights (lora_a, lora_b) are trainable
- Patches init_weights on the model to reinitialize LoRA adapters during the meta-device → real-device init flow
- Add llama3_debugmodel_lora debug config and unit tests for freeze/trainability and training correctness #2515
### Test Plan
- test_lora_freeze_and_trainability: checks that LoRA params have requires_grad=True and base params (original weights/biases) have requires_grad=False.
- test_lora_trains_base_frozen: after training, asserts every base param is bitwise identical to its snapshot, and asserts at least one LoRA param has changed.
[ghstack-poisoned]
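The freeze/trainability invariant from the test plan can be expressed as a small helper. The adapter naming convention (`lora_a`/`lora_b` in parameter names) is assumed from the PR text, not taken from the actual torchtitan test code:

```python
import torch.nn as nn


def check_freeze_and_trainability(model: nn.Module) -> None:
    """Assert LoRA adapter params are trainable and all base params are frozen."""
    for name, param in model.named_parameters():
        if "lora_a" in name or "lora_b" in name:
            assert param.requires_grad, f"adapter {name} must be trainable"
        else:
            assert not param.requires_grad, f"base param {name} must be frozen"
```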
```
def init_weights(self, **kwargs) -> None:
    super().init_weights(**kwargs)  # pyrefly: ignore [not-callable]
```
Remove new_model_init_weights after refactoring Linear @fegin
We plan to hold this PR until the parallelism refactor PR next week, and fix the TP composability issue at that time. (EP will be supported in a later PR.)
### Summary - Add LoRAConverter model converter that applies low-rank adaptation (LoRA) to all nn.Linear layers in a model - LoRA uses dynamic subclass creation (__class__ swap) to wrap existing Linear layers, preserving compatibility with any nn.Linear subclass (e.g., FakeQuantizedLinear) - Base model weights are frozen (requires_grad=False), only LoRA adapter weights (lora_a, lora_b) are trainable - Patches init_weights on the model to reinitialize LoRA adapters during the meta-device → real-device init flow - Add llama3_debugmodel_lora debug config and unit tests for freeze/trainability and training correctness #2515 ### Test Plan - test_lora_freeze_and_trainability: check LoRA params have requires_grad=True, base params (original weights/biases) have requires_grad=False. - test_lora_trains_base_frozen, after training, asserts every base param is bitwise identical to its snapshot. Asserts at least one LoRA param has changed. ### Details 1. Enable LoRA via ModelConverter LoRA is integrated using the existing ModelConverter protocol, making it consistent with other model transformations (e.g., quantization). Users can enable LoRA by adding "lora" to the model.converters config: ``` [model] converters = ["lora"] [lora] rank = 8 alpha = 16.0 dropout = 0.0 ``` 2. In-place Linear Replacement with LoRALinear Subclass The implementation uses dynamic class creation via create_lora_linear() to wrap any nn.Linear subclass with LoRA adapters: - Dynamic subclassing: `create_lora_linear(parent_cls) `creates a new class that inherits from the original Linear class, preserving custom forward logic (e.g., RowwiseParallel, ColwiseParallel) - In-place modification: Instead of creating new modules, we create subclass preserving existing weights - Frozen base weights: All base model parameters are frozen (requires_grad=False), only LoRA adapters are trainable - Composability with quantization: we currently has float8 low-precision quantization as another choice in converters. 
The LoRA architecture itself is simple: it only adds `lora_out` to `base_out`. As long as the subclass computes `base_out = super().forward(x)`, it never touches the parent's `forward`, so the two compose cleanly.

### Test Plan

Tested with the Llama3 debug model on 8 GPUs:
- LoRA has a lower memory cost
- LoRA works together with float8

without lora converter
```
[rank0]:[titan] 2026-02-02 14:41:32,656 - root - INFO - step:  1  loss:  8.2000  grad_norm:  1.3955  memory:  1.01GiB(1.07%)  tps: 5,479  tflops: 0.39  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:41:33,139 - root - INFO - step:  5  loss:  5.3874  grad_norm:  2.3446  memory:  1.06GiB(1.11%)  tps: 343,441  tflops: 24.59  mfu: 2.49%
[rank0]:[titan] 2026-02-02 14:41:33,396 - root - INFO - step: 10  loss:  4.1315  grad_norm:  1.8140  memory:  1.06GiB(1.11%)  tps: 320,740  tflops: 22.96  mfu: 2.32%
```
with lora converter
```
[rank0]:[titan] 2026-02-02 14:42:44,379 - root - INFO - step:  1  loss:  8.1225  grad_norm:  0.3293  memory:  0.98GiB(1.03%)  tps: 6,222  tflops: 0.46  mfu: 0.05%
[rank0]:[titan] 2026-02-02 14:42:44,687 - root - INFO - step:  5  loss:  8.0054  grad_norm:  0.5108  memory:  1.02GiB(1.07%)  tps: 274,052  tflops: 20.05  mfu: 2.03%
[rank0]:[titan] 2026-02-02 14:42:44,988 - root - INFO - step: 10  loss:  7.8191  grad_norm:  0.7809  memory:  1.02GiB(1.07%)  tps: 266,976  tflops: 19.53  mfu: 1.98%
```
with lora converter + float8
```
[rank0]:[titan] 2026-02-02 14:43:38,608 - root - INFO - step:  1  loss:  8.1689  grad_norm:  0.2351  memory:  1.00GiB(1.05%)  tps: 5,067  tflops: 0.37  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:43:39,152 - root - INFO - step:  5  loss:  8.0590  grad_norm:  0.3869  memory:  1.04GiB(1.10%)  tps: 134,235  tflops: 9.82  mfu: 0.99%
[rank0]:[titan] 2026-02-02 14:43:39,792 - root - INFO - step: 10  loss:  7.9155  grad_norm:  0.6083  memory:  1.04GiB(1.10%)  tps: 122,904  tflops: 8.99  mfu: 0.91%
```

### Future work

Some TODO items are not included in this PR and are left as future work:
1. Composability with TP. The current approach does not work with TP; there are 2 plans we may try later that need further discussion:
   - Keep the LoRA adapters as `Replicate` so they won't be sharded by TP
   - Generate a `layer_plan` for the LoRA adapters so they are processed by TP as well
2. Support supervised finetuning

[ghstack-poisoned]
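As a concrete illustration of the `create_lora_linear` pattern described above, here is a minimal, hypothetical sketch: the names (`create_lora_linear`, `lora_a`, `lora_b`, `init_lora`) mirror the PR, but this is not the actual implementation. Because the subclass only calls `super().forward(x)` and adds the scaled adapter output, it composes with any `nn.Linear` subclass.

```python
# Hedged sketch of in-place LoRA subclassing; not the PR's real code.
import math

import torch
import torch.nn as nn


def create_lora_linear(parent_cls, rank=8, alpha=16.0):
    class LoRALinear(parent_cls):
        def init_lora(self):
            # Freeze the base weights; only the adapters stay trainable.
            for p in self.parameters():
                p.requires_grad = False
            self.lora_a = nn.Linear(self.in_features, rank, bias=False)
            self.lora_b = nn.Linear(rank, self.out_features, bias=False)
            nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
            self.scaling = alpha / rank

        def forward(self, x):
            # The parent forward is untouched, so e.g. a Float8Linear
            # parent keeps its custom matmul logic.
            base_out = super().forward(x)
            return base_out + self.scaling * self.lora_b(self.lora_a(x))

    return LoRALinear


# In-place conversion: swap the class, keep the existing weights.
layer = nn.Linear(16, 32)
layer.__class__ = create_lora_linear(nn.Linear)
layer.init_lora()
x = torch.randn(4, 16)
out = layer(x)  # equals the base output while lora_b is zero-initialized
```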
…ave and multi-source load"

### Summary

- **`merge_adapter` config**: Replace `save_format: str` with `merge_adapter: bool` on `LoRAConverter.Config`. When `True`, adapters are folded into base weights (`base + alpha/rank * B @ A`) at the end of training. When `False` (default), adapter weights are saved separately — use `checkpoint.last_save_in_hf=True` to save in HuggingFace PEFT format.
- **`finalize()` lifecycle on the ModelConverter protocol**: Add an end-of-training hook called before the last checkpoint save. `ModelConvertersContainer` runs finalize in reverse converter order (LoRA merge before quantization CONVERT). A `converter_finalize_fn` closure is attached to each model part during `convert()` so the checkpoint system can invoke it. All existing converters (Float8Linear, Float8GroupedMM, MXFP8) get no-op `finalize()` implementations.
- **Checkpoint integration**: `ModelWrapper` gains converter-aware methods (`state_dict_to_save`, `export_state_dict`, `base_state_dict`, `has_converter_keys`, `converter_save_last_fn`, `converter_load_additional_fn`) to support adapter-only checkpointing. `dcp_load` accepts `checkpoint_ids: list[str]` for loading from multiple sources (base model + adapter weights). A new `additional_load_paths` config field enables multi-source loading. The PEFT save path in `_save_last_step` is gated on `last_save_in_hf`.
- **LoRA PEFT save/load**: `_make_peft_save_fn()` writes `adapter_model.safetensors` + `adapter_config.json` with HF PEFT key naming. `_make_peft_load_fn()` loads and remaps keys back. `remap_lora_keys_to_hf()` / `remap_lora_keys_from_hf()` handle the bidirectional key translation.
- **Configs**: Add `llama3_8b_lora` (rank=128, alpha=32, `last_save_in_hf=True`). Update `llama3_debugmodel_lora` with checkpoint settings for proper resumption.
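The `merge_adapter=True` fold can be checked numerically. This is a hedged sketch assuming `nn.Linear`-style weight layouts (`lora_a.weight` is rank×in, `lora_b.weight` is out×rank); it is not the PR's merging code.

```python
# Verify that base + alpha/rank * B @ A reproduces base + adapter output.
import torch

torch.manual_seed(0)
rank, alpha = 4, 16.0
w = torch.randn(32, 16)    # frozen base weight (out x in)
a = torch.randn(rank, 16)  # lora_a.weight (rank x in)
b = torch.randn(32, rank)  # lora_b.weight (out x rank)

# Fold the adapter into the base weight.
merged = w + (alpha / rank) * (b @ a)

# A plain linear with `merged` matches base output + scaled adapter output.
x = torch.randn(3, 16)
lhs = x @ merged.t()
rhs = x @ w.t() + (alpha / rank) * ((x @ a.t()) @ b.t())
assert torch.allclose(lhs, rhs, atol=1e-4)
```

After the fold, the checkpoint contains ordinary `weight` tensors, so the merged model loads like a plain base model with no LoRA machinery.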
### Test plan

- [ ] `pytest tests/unit_tests/test_checkpoint.py -x` — new `TestModelWrapperConverterKeys` tests (strict vs partial planner)
- [ ] `pytest tests/unit_tests/test_model_converter.py -x` — new `test_lora_key_remap_roundtrip`
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora` — verify LoRA training runs end-to-end
- [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora` — verify the 8B LoRA config with PEFT save

* #2484

[ghstack-poisoned]
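One plausible shape for the bidirectional key translation mentioned above is a pure-string remap; the exact prefixes and signatures of the real `remap_lora_keys_to_hf()` / `remap_lora_keys_from_hf()` may differ, so treat this as an assumption-laden sketch.

```python
# Hypothetical key remapping between torchtitan-style and HF-PEFT-style names.
def remap_lora_keys_to_hf(state_dict):
    remapped = {}
    for key, value in state_dict.items():
        # e.g. "layers.0.attention.wq.lora_a.weight"
        #   -> "base_model.model.layers.0.attention.wq.lora_A.weight"
        key = key.replace(".lora_a.", ".lora_A.").replace(".lora_b.", ".lora_B.")
        remapped[f"base_model.model.{key}"] = value
    return remapped


def remap_lora_keys_from_hf(state_dict):
    remapped = {}
    for key, value in state_dict.items():
        key = key.removeprefix("base_model.model.")
        key = key.replace(".lora_A.", ".lora_a.").replace(".lora_B.", ".lora_b.")
        remapped[key] = value
    return remapped


sd = {"layers.0.attention.wq.lora_a.weight": 0}
hf_sd = remap_lora_keys_to_hf(sd)
assert remap_lora_keys_from_hf(hf_sd) == sd  # roundtrip is lossless
```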
Related: Support LoRA at TorchTitan #2515