
[1/N] Add lora adapters for finetune#2484

Open
mori360 wants to merge 9 commits into gh/mori360/1/base from gh/mori360/1/head

Conversation

mori360 (Contributor) commented Mar 4, 2026

Summary

  • Add a LoRAConverter model converter that applies low-rank adaptation (LoRA) to all nn.Linear layers in a model
  • LoRA uses dynamic subclass creation (__class__ swap) to wrap existing Linear layers, preserving compatibility with any nn.Linear subclass (e.g., FakeQuantizedLinear)
  • Base model weights are frozen (requires_grad=False); only LoRA adapter weights (lora_a, lora_b) are trainable
  • Patches init_weights on the model to reinitialize LoRA adapters during the meta-device → real-device init flow
  • Add a llama3_debugmodel_lora debug config and unit tests for freeze/trainability and training correctness

Tracking issue: Support LoRA at TorchTitan #2515

Test Plan

  • test_lora_freeze_and_trainability: checks that LoRA params have requires_grad=True and base params (original weights/biases) have requires_grad=False.
  • test_lora_trains_base_frozen: after training, asserts every base param is bitwise identical to its pre-training snapshot, and that at least one LoRA param has changed.

Details

1. Enable LoRA via ModelConverter

LoRA is integrated using the existing ModelConverter protocol, making it consistent with other model transformations (e.g., quantization). Users can enable LoRA by adding "lora" to the model.converters config:

```
[model]
converters = ["lora"]

[lora]
rank = 8
alpha = 16.0
dropout = 0.0
```
2. In-place Linear Replacement with LoRALinear Subclass

The implementation uses dynamic class creation via create_lora_linear() to wrap any nn.Linear subclass with LoRA adapters:

  • Dynamic subclassing: `create_lora_linear(parent_cls)` creates a new class that inherits from the original Linear class, preserving custom forward logic (e.g., RowwiseParallel, ColwiseParallel)
  • In-place modification: instead of creating new modules, we create a subclass that preserves the existing weights
  • Frozen base weights: all base model parameters are frozen (requires_grad=False); only LoRA adapters are trainable
  • Composability with quantization: float8 low-precision quantization is currently another available converter. Conversion runs in order: first to `Float8Linear`, then the lora converter turns `Float8Linear` into `LoRAFloat8Linear` via `create_lora_linear` as described above (`converters = ["quantize.linear.float8", "lora"]`).
    The LoRA architecture is simple: it only adds lora_out to base_out. As long as the subclass computes `base_out = super().forward(x)`, it never touches the parent's forward, so the two converters compose.
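The dynamic-subclass + `__class__`-swap mechanism described above can be sketched without torch. Note this is an illustrative stand-in, not the PR's actual code: `Linear`, `make_lora_subclass`, and `apply_lora` are hypothetical names, and scalars stand in for weight tensors so the class-swap pattern is visible on its own.

```python
# Framework-free sketch of the dynamic-subclass + __class__-swap pattern.
# All names here are illustrative, not the PR's API.

class Linear:
    """Stand-in for nn.Linear (or any subclass of it)."""
    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return self.weight * x  # stand-in for the real matmul


_lora_class_cache = {}


def make_lora_subclass(parent_cls):
    """Create (and cache) a subclass whose forward adds a LoRA-style
    delta on top of whatever the parent's forward computes."""
    if parent_cls not in _lora_class_cache:

        class LoRAWrapped(parent_cls):
            def forward(self, x):
                base_out = super().forward(x)  # parent logic untouched
                lora_out = self.scaling * (self.lora_b * (self.lora_a * x))
                return base_out + lora_out

        LoRAWrapped.__name__ = f"LoRA{parent_cls.__name__}"
        _lora_class_cache[parent_cls] = LoRAWrapped
    return _lora_class_cache[parent_cls]


def apply_lora(module, scaling=0.5):
    """Swap the instance's class in place, keeping the existing weights."""
    module.__class__ = make_lora_subclass(type(module))
    module.lora_a, module.lora_b, module.scaling = 1.0, 0.0, scaling
    return module


lin = apply_lora(Linear(weight=3.0))
print(type(lin).__name__)  # LoRALinear
print(lin.forward(2.0))    # lora_b starts at 0, so output equals base: 6.0
```

Because the wrapper only adds a delta after `super().forward(x)`, the same pattern works whether the parent is a plain Linear or something like Float8Linear, which is why the converters compose.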

Test Plan

Tested with the Llama3 debug model on 8 GPUs:

  • LoRA has lower memory cost
  • LoRA works with float8

without lora converter

```
[rank0]:[titan] 2026-02-02 14:41:32,656 - root - INFO - step:  1  loss:  8.2000  grad_norm:  1.3955  memory:  1.01GiB(1.07%)  tps: 5,479  tflops: 0.39  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:41:33,139 - root - INFO - step:  5  loss:  5.3874  grad_norm:  2.3446  memory:  1.06GiB(1.11%)  tps: 343,441  tflops: 24.59  mfu: 2.49%
[rank0]:[titan] 2026-02-02 14:41:33,396 - root - INFO - step: 10  loss:  4.1315  grad_norm:  1.8140  memory:  1.06GiB(1.11%)  tps: 320,740  tflops: 22.96  mfu: 2.32%
```

with lora converter

```
[rank0]:[titan] 2026-02-02 14:42:44,379 - root - INFO - step:  1  loss:  8.1225  grad_norm:  0.3293  memory:  0.98GiB(1.03%)  tps: 6,222  tflops: 0.46  mfu: 0.05%
[rank0]:[titan] 2026-02-02 14:42:44,687 - root - INFO - step:  5  loss:  8.0054  grad_norm:  0.5108  memory:  1.02GiB(1.07%)  tps: 274,052  tflops: 20.05  mfu: 2.03%
[rank0]:[titan] 2026-02-02 14:42:44,988 - root - INFO - step: 10  loss:  7.8191  grad_norm:  0.7809  memory:  1.02GiB(1.07%)  tps: 266,976  tflops: 19.53  mfu: 1.98%
```

with lora converter + float8

```
[rank0]:[titan] 2026-02-02 14:43:38,608 - root - INFO - step:  1  loss:  8.1689  grad_norm:  0.2351  memory:  1.00GiB(1.05%)  tps: 5,067  tflops: 0.37  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:43:39,152 - root - INFO - step:  5  loss:  8.0590  grad_norm:  0.3869  memory:  1.04GiB(1.10%)  tps: 134,235  tflops: 9.82  mfu: 0.99%
[rank0]:[titan] 2026-02-02 14:43:39,792 - root - INFO - step: 10  loss:  7.9155  grad_norm:  0.6083  memory:  1.04GiB(1.10%)  tps: 122,904  tflops: 8.99  mfu: 0.91%
```

Future work

Some TODO items are not included in this PR and are left as future work:

1. Composability with TP: the current plan does not work with TP. There are 2 options we may try later, pending further discussion:
  • Keep the lora_adapters as Replicate so that they won't be sharded by TP
  • Generate a layer_plan for the lora adapters so that they are processed by TP as well
2. Support supervised finetuning

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 4, 2026
@mori360 mori360 changed the title add lora adapters [1/2] Add lora adapters for finetune Mar 4, 2026
mori360 added 2 commits March 5, 2026 11:41
@mori360 mori360 changed the title [1/2] Add lora adapters for finetune [1/N] Add lora adapters for finetune Mar 6, 2026
@mori360 mori360 added this to the New Feature, Model, Misc milestone Mar 6, 2026
@mori360 mori360 self-assigned this Mar 6, 2026
```python
if original_init_weights is not None and callable(original_init_weights):
    original_init_weights(*args, **kwargs)
for sub_module in module.modules():
    if type(sub_module) in _lora_class_cache.values():
```
Contributor comment:
It seems we need this because nn.Linear doesn't have init_weights; instead, in torchtitan we directly operate on the weight and bias of every nn.Linear submodule.

I believe @fegin will refactor this to create our own Linear module with init_weights capability. So let's leave a TODO here to remove this new_model_init_weights.


```python
def llama3_debugmodel_lora() -> Trainer.Config:
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
```
Contributor comment:
Among all the converters (quantization, lora, etc.), order matters. For now, let's create a __post_init__ in ModelConvertersContainer.Config to validate the order.

Note that it can't catch a wrong order if we first init the main config and then modify a field.

@fegin Validation is becoming an issue. If we don't go with chz, maybe we should at least freeze the config and require users to use dataclass.replace to trigger post-init checks.
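The order check suggested above could look roughly like the sketch below. This is a hypothetical illustration, not the actual ModelConvertersContainer.Config: the class name, field, and the set of quantization converter strings are assumptions based on this PR's description.

```python
from dataclasses import dataclass, field

@dataclass
class ModelConvertersConfig:
    """Illustrative stand-in for ModelConvertersContainer.Config."""
    converters: list = field(default_factory=list)

    def __post_init__(self):
        # "lora" must come after any quantization converter, because the
        # lora converter wraps whatever Linear class exists at that point
        # (e.g., Float8Linear -> LoRAFloat8Linear).
        if "lora" in self.converters:
            lora_idx = self.converters.index("lora")
            for quant in ("quantize.linear.float8",):
                if quant in self.converters and self.converters.index(quant) > lora_idx:
                    raise ValueError(
                        f"'{quant}' must come before 'lora' in converters"
                    )


ModelConvertersConfig(["quantize.linear.float8", "lora"])  # valid order
```

As the comment notes, `__post_init__` only runs at construction time, so mutating `converters` on an existing instance would bypass this check unless the config is frozen and rebuilt via replace.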

mori360 added 4 commits March 12, 2026 12:45
### Summary
  - Add LoRAConverter model converter that applies low-rank adaptation (LoRA) to all nn.Linear layers in a model
  - LoRA uses dynamic subclass creation (__class__ swap) to wrap existing Linear layers, preserving compatibility with any nn.Linear subclass (e.g.,
  FakeQuantizedLinear)
  - Base model weights are frozen (requires_grad=False), only LoRA adapter weights (lora_a, lora_b) are trainable
  - Patches init_weights on the model to reinitialize LoRA adapters during the meta-device → real-device init flow
  - Add llama3_debugmodel_lora debug config and unit tests for freeze/trainability and training correctness
#2515 

### Test Plan
 - test_lora_freeze_and_trainability: check LoRA params have requires_grad=True, base params (original weights/biases) have requires_grad=False.
 - test_lora_trains_base_frozen, after training, asserts every base param is bitwise identical to its snapshot. Asserts at least one LoRA param has changed.


### Details
1. Enable LoRA via ModelConverter
LoRA is integrated using the existing ModelConverter protocol, making it consistent with other model transformations (e.g., quantization). Users can enable LoRA by adding "lora" to the model.converters config:
```
[model]
converters = ["lora"]

[lora]
rank = 8
alpha = 16.0
dropout = 0.0
```

2. In-place Linear Replacement with LoRALinear Subclass
The implementation uses dynamic class creation via create_lora_linear() to wrap any nn.Linear subclass with LoRA adapters:

- Dynamic subclassing: `create_lora_linear(parent_cls) `creates a new class that inherits from the original Linear class, preserving custom forward logic (e.g., RowwiseParallel, ColwiseParallel)
- In-place modification: Instead of creating new modules, we create subclass preserving existing weights
- Frozen base weights: All base model parameters are frozen (requires_grad=False), only LoRA adapters are trainable
- Composability with quantization: we currently has float8 low-precision quantization as another choice in converters. Here we will conduct convert in order that convert to `Float8Linear` first, then in lora converter to convert from `Float8Linear` to `LoRAFloat8Linear` by `create_lora_linear` described above. # converters = ["quantize.linear.float8", "lora"]
    LoRA architecture is quite simple, that only adds lora_out to the base_out. So that as long as it has `base_out = super().forward(x)`, lora subclass won't touch the parent forward, so we could compose them together

### Test Plan
Tested with Llama3 debug model on 8 GPU
- lora has loess memory cost
- lora could work with float8

without lora converter
```
[rank0]:[titan] 2026-02-02 14:41:32,656 - root - INFO - step:  1  loss:  8.2000  grad_norm:  1.3955  memory:  1.01GiB(1.07%)  tps: 5,479  tflops: 0.39  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:41:33,139 - root - INFO - step:  5  loss:  5.3874  grad_norm:  2.3446  memory:  1.06GiB(1.11%)  tps: 343,441  tflops: 24.59  mfu: 2.49%
[rank0]:[titan] 2026-02-02 14:41:33,396 - root - INFO - step: 10  loss:  4.1315  grad_norm:  1.8140  memory:  1.06GiB(1.11%)  tps: 320,740  tflops: 22.96  mfu: 2.32%
```


with lora converter
```
[rank0]:[titan] 2026-02-02 14:42:44,379 - root - INFO - step:  1  loss:  8.1225  grad_norm:  0.3293  memory:  0.98GiB(1.03%)  tps: 6,222  tflops: 0.46  mfu: 0.05%
[rank0]:[titan] 2026-02-02 14:42:44,687 - root - INFO - step:  5  loss:  8.0054  grad_norm:  0.5108  memory:  1.02GiB(1.07%)  tps: 274,052  tflops: 20.05  mfu: 2.03%
[rank0]:[titan] 2026-02-02 14:42:44,988 - root - INFO - step: 10  loss:  7.8191  grad_norm:  0.7809  memory:  1.02GiB(1.07%)  tps: 266,976  tflops: 19.53  mfu: 1.98%
```

with lora converter + float8
```
[rank0]:[titan] 2026-02-02 14:43:38,608 - root - INFO - step:  1  loss:  8.1689  grad_norm:  0.2351  memory:  1.00GiB(1.05%)  tps: 5,067  tflops: 0.37  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:43:39,152 - root - INFO - step:  5  loss:  8.0590  grad_norm:  0.3869  memory:  1.04GiB(1.10%)  tps: 134,235  tflops: 9.82  mfu: 0.99%
[rank0]:[titan] 2026-02-02 14:43:39,792 - root - INFO - step: 10  loss:  7.9155  grad_norm:  0.6083  memory:  1.04GiB(1.10%)  tps: 122,904  tflops: 8.99  mfu: 0.91%
```


### Future work
There are some todo works that not included in this PR and will be future work
1. Composable with TP, the current plan does not work with TP, here are 2 plans we may try later and need further discussion:
- Keep lora_adapters as Replicate so that they won't be shardded by TP
- Generate layer_plan for lora adapters so that they will be processed by TP as well
2. Support supervised finetuning




[ghstack-poisoned]
### Summary
  - Add LoRAConverter model converter that applies low-rank adaptation (LoRA) to all nn.Linear layers in a model
  - LoRA uses dynamic subclass creation (__class__ swap) to wrap existing Linear layers, preserving compatibility with any nn.Linear subclass (e.g.,
  FakeQuantizedLinear)
  - Base model weights are frozen (requires_grad=False), only LoRA adapter weights (lora_a, lora_b) are trainable
  - Patches init_weights on the model to reinitialize LoRA adapters during the meta-device → real-device init flow
  - Add llama3_debugmodel_lora debug config and unit tests for freeze/trainability and training correctness
#2515 

### Test Plan
 - test_lora_freeze_and_trainability: check LoRA params have requires_grad=True, base params (original weights/biases) have requires_grad=False.
 - test_lora_trains_base_frozen, after training, asserts every base param is bitwise identical to its snapshot. Asserts at least one LoRA param has changed.


### Details
1. Enable LoRA via ModelConverter
LoRA is integrated using the existing ModelConverter protocol, making it consistent with other model transformations (e.g., quantization). Users can enable LoRA by adding "lora" to the model.converters config:
```
[model]
converters = ["lora"]

[lora]
rank = 8
alpha = 16.0
dropout = 0.0
```

2. In-place Linear Replacement with LoRALinear Subclass
The implementation uses dynamic class creation via create_lora_linear() to wrap any nn.Linear subclass with LoRA adapters:

- Dynamic subclassing: `create_lora_linear(parent_cls) `creates a new class that inherits from the original Linear class, preserving custom forward logic (e.g., RowwiseParallel, ColwiseParallel)
- In-place modification: Instead of creating new modules, we create subclass preserving existing weights
- Frozen base weights: All base model parameters are frozen (requires_grad=False), only LoRA adapters are trainable
- Composability with quantization: we currently has float8 low-precision quantization as another choice in converters. Here we will conduct convert in order that convert to `Float8Linear` first, then in lora converter to convert from `Float8Linear` to `LoRAFloat8Linear` by `create_lora_linear` described above. # converters = ["quantize.linear.float8", "lora"]
    LoRA architecture is quite simple, that only adds lora_out to the base_out. So that as long as it has `base_out = super().forward(x)`, lora subclass won't touch the parent forward, so we could compose them together

### Test Plan
Tested with Llama3 debug model on 8 GPU
- lora has loess memory cost
- lora could work with float8

without lora converter
```
[rank0]:[titan] 2026-02-02 14:41:32,656 - root - INFO - step:  1  loss:  8.2000  grad_norm:  1.3955  memory:  1.01GiB(1.07%)  tps: 5,479  tflops: 0.39  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:41:33,139 - root - INFO - step:  5  loss:  5.3874  grad_norm:  2.3446  memory:  1.06GiB(1.11%)  tps: 343,441  tflops: 24.59  mfu: 2.49%
[rank0]:[titan] 2026-02-02 14:41:33,396 - root - INFO - step: 10  loss:  4.1315  grad_norm:  1.8140  memory:  1.06GiB(1.11%)  tps: 320,740  tflops: 22.96  mfu: 2.32%
```


with lora converter
```
[rank0]:[titan] 2026-02-02 14:42:44,379 - root - INFO - step:  1  loss:  8.1225  grad_norm:  0.3293  memory:  0.98GiB(1.03%)  tps: 6,222  tflops: 0.46  mfu: 0.05%
[rank0]:[titan] 2026-02-02 14:42:44,687 - root - INFO - step:  5  loss:  8.0054  grad_norm:  0.5108  memory:  1.02GiB(1.07%)  tps: 274,052  tflops: 20.05  mfu: 2.03%
[rank0]:[titan] 2026-02-02 14:42:44,988 - root - INFO - step: 10  loss:  7.8191  grad_norm:  0.7809  memory:  1.02GiB(1.07%)  tps: 266,976  tflops: 19.53  mfu: 1.98%
```

with lora converter + float8
```
[rank0]:[titan] 2026-02-02 14:43:38,608 - root - INFO - step:  1  loss:  8.1689  grad_norm:  0.2351  memory:  1.00GiB(1.05%)  tps: 5,067  tflops: 0.37  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:43:39,152 - root - INFO - step:  5  loss:  8.0590  grad_norm:  0.3869  memory:  1.04GiB(1.10%)  tps: 134,235  tflops: 9.82  mfu: 0.99%
[rank0]:[titan] 2026-02-02 14:43:39,792 - root - INFO - step: 10  loss:  7.9155  grad_norm:  0.6083  memory:  1.04GiB(1.10%)  tps: 122,904  tflops: 8.99  mfu: 0.91%
```


### Future work
There are some todo works that not included in this PR and will be future work
1. Composable with TP, the current plan does not work with TP, here are 2 plans we may try later and need further discussion:
- Keep lora_adapters as Replicate so that they won't be shardded by TP
- Generate layer_plan for lora adapters so that they will be processed by TP as well
2. Support supervised finetuning




[ghstack-poisoned]
### Summary
  - Add LoRAConverter model converter that applies low-rank adaptation (LoRA) to all nn.Linear layers in a model
  - LoRA uses dynamic subclass creation (__class__ swap) to wrap existing Linear layers, preserving compatibility with any nn.Linear subclass (e.g.,
  FakeQuantizedLinear)
  - Base model weights are frozen (requires_grad=False), only LoRA adapter weights (lora_a, lora_b) are trainable
  - Patches init_weights on the model to reinitialize LoRA adapters during the meta-device → real-device init flow
  - Add llama3_debugmodel_lora debug config and unit tests for freeze/trainability and training correctness
#2515 

### Test Plan
 - test_lora_freeze_and_trainability: check LoRA params have requires_grad=True, base params (original weights/biases) have requires_grad=False.
 - test_lora_trains_base_frozen, after training, asserts every base param is bitwise identical to its snapshot. Asserts at least one LoRA param has changed.


### Details
1. Enable LoRA via ModelConverter
LoRA is integrated using the existing ModelConverter protocol, making it consistent with other model transformations (e.g., quantization). Users can enable LoRA by adding "lora" to the model.converters config:
```
[model]
converters = ["lora"]

[lora]
rank = 8
alpha = 16.0
dropout = 0.0
```

2. In-place Linear Replacement with LoRALinear Subclass
The implementation uses dynamic class creation via create_lora_linear() to wrap any nn.Linear subclass with LoRA adapters:

- Dynamic subclassing: `create_lora_linear(parent_cls) `creates a new class that inherits from the original Linear class, preserving custom forward logic (e.g., RowwiseParallel, ColwiseParallel)
- In-place modification: Instead of creating new modules, we create subclass preserving existing weights
- Frozen base weights: All base model parameters are frozen (requires_grad=False), only LoRA adapters are trainable
- Composability with quantization: we currently has float8 low-precision quantization as another choice in converters. Here we will conduct convert in order that convert to `Float8Linear` first, then in lora converter to convert from `Float8Linear` to `LoRAFloat8Linear` by `create_lora_linear` described above. # converters = ["quantize.linear.float8", "lora"]
    LoRA architecture is quite simple, that only adds lora_out to the base_out. So that as long as it has `base_out = super().forward(x)`, lora subclass won't touch the parent forward, so we could compose them together

### Test Plan
Tested with Llama3 debug model on 8 GPU
- lora has loess memory cost
- lora could work with float8

without lora converter
```
[rank0]:[titan] 2026-02-02 14:41:32,656 - root - INFO - step:  1  loss:  8.2000  grad_norm:  1.3955  memory:  1.01GiB(1.07%)  tps: 5,479  tflops: 0.39  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:41:33,139 - root - INFO - step:  5  loss:  5.3874  grad_norm:  2.3446  memory:  1.06GiB(1.11%)  tps: 343,441  tflops: 24.59  mfu: 2.49%
[rank0]:[titan] 2026-02-02 14:41:33,396 - root - INFO - step: 10  loss:  4.1315  grad_norm:  1.8140  memory:  1.06GiB(1.11%)  tps: 320,740  tflops: 22.96  mfu: 2.32%
```


with lora converter
```
[rank0]:[titan] 2026-02-02 14:42:44,379 - root - INFO - step:  1  loss:  8.1225  grad_norm:  0.3293  memory:  0.98GiB(1.03%)  tps: 6,222  tflops: 0.46  mfu: 0.05%
[rank0]:[titan] 2026-02-02 14:42:44,687 - root - INFO - step:  5  loss:  8.0054  grad_norm:  0.5108  memory:  1.02GiB(1.07%)  tps: 274,052  tflops: 20.05  mfu: 2.03%
[rank0]:[titan] 2026-02-02 14:42:44,988 - root - INFO - step: 10  loss:  7.8191  grad_norm:  0.7809  memory:  1.02GiB(1.07%)  tps: 266,976  tflops: 19.53  mfu: 1.98%
```

with lora converter + float8
```
[rank0]:[titan] 2026-02-02 14:43:38,608 - root - INFO - step:  1  loss:  8.1689  grad_norm:  0.2351  memory:  1.00GiB(1.05%)  tps: 5,067  tflops: 0.37  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:43:39,152 - root - INFO - step:  5  loss:  8.0590  grad_norm:  0.3869  memory:  1.04GiB(1.10%)  tps: 134,235  tflops: 9.82  mfu: 0.99%
[rank0]:[titan] 2026-02-02 14:43:39,792 - root - INFO - step: 10  loss:  7.9155  grad_norm:  0.6083  memory:  1.04GiB(1.10%)  tps: 122,904  tflops: 8.99  mfu: 0.91%
```


### Future work
There are some todo works that not included in this PR and will be future work
1. Composable with TP, the current plan does not work with TP, here are 2 plans we may try later and need further discussion:
- Keep lora_adapters as Replicate so that they won't be shardded by TP
- Generate layer_plan for lora adapters so that they will be processed by TP as well
2. Support supervised finetuning




[ghstack-poisoned]
### Summary
  - Add LoRAConverter model converter that applies low-rank adaptation (LoRA) to all nn.Linear layers in a model
  - LoRA uses dynamic subclass creation (__class__ swap) to wrap existing Linear layers, preserving compatibility with any nn.Linear subclass (e.g.,
  FakeQuantizedLinear)
  - Base model weights are frozen (requires_grad=False), only LoRA adapter weights (lora_a, lora_b) are trainable
  - Patches init_weights on the model to reinitialize LoRA adapters during the meta-device → real-device init flow
  - Add llama3_debugmodel_lora debug config and unit tests for freeze/trainability and training correctness
#2515 

### Test Plan
 - test_lora_freeze_and_trainability: check LoRA params have requires_grad=True, base params (original weights/biases) have requires_grad=False.
 - test_lora_trains_base_frozen, after training, asserts every base param is bitwise identical to its snapshot. Asserts at least one LoRA param has changed.


### Details
1. Enable LoRA via ModelConverter
LoRA is integrated using the existing ModelConverter protocol, making it consistent with other model transformations (e.g., quantization). Users can enable LoRA by adding "lora" to the model.converters config:
```
[model]
converters = ["lora"]

[lora]
rank = 8
alpha = 16.0
dropout = 0.0
```

2. In-place Linear Replacement with LoRALinear Subclass
The implementation uses dynamic class creation via create_lora_linear() to wrap any nn.Linear subclass with LoRA adapters:

- Dynamic subclassing: `create_lora_linear(parent_cls) `creates a new class that inherits from the original Linear class, preserving custom forward logic (e.g., RowwiseParallel, ColwiseParallel)
- In-place modification: Instead of creating new modules, we create subclass preserving existing weights
- Frozen base weights: All base model parameters are frozen (requires_grad=False), only LoRA adapters are trainable
- Composability with quantization: we currently has float8 low-precision quantization as another choice in converters. Here we will conduct convert in order that convert to `Float8Linear` first, then in lora converter to convert from `Float8Linear` to `LoRAFloat8Linear` by `create_lora_linear` described above. # converters = ["quantize.linear.float8", "lora"]
    LoRA architecture is quite simple, that only adds lora_out to the base_out. So that as long as it has `base_out = super().forward(x)`, lora subclass won't touch the parent forward, so we could compose them together

### Benchmark
Tested with the Llama3 debug model on 8 GPUs:
- LoRA has lower memory cost
- LoRA works with float8

without lora converter
```
[rank0]:[titan] 2026-02-02 14:41:32,656 - root - INFO - step:  1  loss:  8.2000  grad_norm:  1.3955  memory:  1.01GiB(1.07%)  tps: 5,479  tflops: 0.39  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:41:33,139 - root - INFO - step:  5  loss:  5.3874  grad_norm:  2.3446  memory:  1.06GiB(1.11%)  tps: 343,441  tflops: 24.59  mfu: 2.49%
[rank0]:[titan] 2026-02-02 14:41:33,396 - root - INFO - step: 10  loss:  4.1315  grad_norm:  1.8140  memory:  1.06GiB(1.11%)  tps: 320,740  tflops: 22.96  mfu: 2.32%
```


with lora converter
```
[rank0]:[titan] 2026-02-02 14:42:44,379 - root - INFO - step:  1  loss:  8.1225  grad_norm:  0.3293  memory:  0.98GiB(1.03%)  tps: 6,222  tflops: 0.46  mfu: 0.05%
[rank0]:[titan] 2026-02-02 14:42:44,687 - root - INFO - step:  5  loss:  8.0054  grad_norm:  0.5108  memory:  1.02GiB(1.07%)  tps: 274,052  tflops: 20.05  mfu: 2.03%
[rank0]:[titan] 2026-02-02 14:42:44,988 - root - INFO - step: 10  loss:  7.8191  grad_norm:  0.7809  memory:  1.02GiB(1.07%)  tps: 266,976  tflops: 19.53  mfu: 1.98%
```

with lora converter + float8
```
[rank0]:[titan] 2026-02-02 14:43:38,608 - root - INFO - step:  1  loss:  8.1689  grad_norm:  0.2351  memory:  1.00GiB(1.05%)  tps: 5,067  tflops: 0.37  mfu: 0.04%
[rank0]:[titan] 2026-02-02 14:43:39,152 - root - INFO - step:  5  loss:  8.0590  grad_norm:  0.3869  memory:  1.04GiB(1.10%)  tps: 134,235  tflops: 9.82  mfu: 0.99%
[rank0]:[titan] 2026-02-02 14:43:39,792 - root - INFO - step: 10  loss:  7.9155  grad_norm:  0.6083  memory:  1.04GiB(1.10%)  tps: 122,904  tflops: 8.99  mfu: 0.91%
```


### Future work
Some TODO items are not included in this PR and are left as future work:
1. Composability with TP: the current plan does not work with TP. Two plans we may try later (need further discussion):
- Keep lora_adapters as Replicate so they won't be sharded by TP
- Generate a layer_plan for the LoRA adapters so they are processed by TP as well
- Generate layer_plan for lora adapters so that they will be processed by TP as well
2. Support supervised finetuning




[ghstack-poisoned]
@mori360 mori360 requested a review from tianyu-l March 13, 2026 00:35
```python
    )

    def init_weights(self, **kwargs) -> None:
        super().init_weights(**kwargs)  # pyrefly: ignore [not-callable]
```
Remove new_model_init_weights after refactoring Linear @fegin

@mori360 mori360 mentioned this pull request Mar 13, 2026

mori360 commented Mar 17, 2026

We plan to hold this PR until the parallelism refactor PR lands next week, and fix the TP composability issue at that time. (EP support will come in a later PR.)

mori360 added a commit that referenced this pull request Mar 20, 2026
…ave and multi-source load"

 ### Summary

  - **`merge_adapter` config**: Replace `save_format: str` with `merge_adapter: bool` on
  `LoRAConverter.Config`. When `True`, adapters are folded into base weights (`base + alpha/rank * B @ A`) at
  end of training. When `False` (default), adapter weights are saved separately — use
  `checkpoint.last_save_in_hf=True` to save in HuggingFace PEFT format.

  - **`finalize()` lifecycle on ModelConverter protocol**: Add an end-of-training hook called before the last
  checkpoint save. `ModelConvertersContainer` runs finalize in reverse converter order (LoRA merge before
  quantization CONVERT). A `converter_finalize_fn` closure is attached to each model part during `convert()`
  so the checkpoint system can invoke it. All existing converters (Float8Linear, Float8GroupedMM, MXFP8) get
  no-op `finalize()` implementations.

  - **Checkpoint integration**: `ModelWrapper` gains converter-aware methods (`state_dict_to_save`,
  `export_state_dict`, `base_state_dict`, `has_converter_keys`, `converter_save_last_fn`,
  `converter_load_additional_fn`) to support adapter-only checkpointing. `dcp_load` accepts `checkpoint_ids:
  list[str]` for loading from multiple sources (base model + adapter weights). A new `additional_load_paths`
  config field enables multi-source loading. The PEFT save path in `_save_last_step` is gated on
  `last_save_in_hf`.

  - **LoRA PEFT save/load**: `_make_peft_save_fn()` writes `adapter_model.safetensors` + `adapter_config.json`
   with HF PEFT key naming. `_make_peft_load_fn()` loads and remaps keys back. `remap_lora_keys_to_hf()` /
  `remap_lora_keys_from_hf()` handle the bidirectional key translation.

  - **Configs**: Add `llama3_8b_lora` (rank=128, alpha=32, `last_save_in_hf=True`). Update
  `llama3_debugmodel_lora` with checkpoint settings for proper resumption.

  ### Test plan

  - [ ] `pytest tests/unit_tests/test_checkpoint.py -x` — new `TestModelWrapperConverterKeys` tests (strict vs
   partial planner)
  - [ ] `pytest tests/unit_tests/test_model_converter.py -x` — new `test_lora_key_remap_roundtrip`
  - [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_debugmodel_lora` —
  verify LoRA training runs end-to-end
  - [ ] `torchrun --nproc_per_node=4 -m torchtitan.train --module llama3 --config llama3_8b_lora` — verify 8B
  LoRA config with PEFT save


* #2484

[ghstack-poisoned]
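
The `merge_adapter` fold described in the commit above, `base + alpha/rank * B @ A`, can be sketched as follows (a hedged illustration; the tensor shapes assume the standard LoRA parameterization, with `lora_a` of shape `(rank, in_features)` and `lora_b` of shape `(out_features, rank)`):

```python
import torch

@torch.no_grad()
def merge_adapter(weight: torch.Tensor, lora_a: torch.Tensor,
                  lora_b: torch.Tensor, alpha: float, rank: int) -> torch.Tensor:
    """Fold LoRA adapters into the base weight: W' = W + (alpha / rank) * B @ A."""
    return weight + (alpha / rank) * (lora_b @ lora_a)
```

After the fold, the merged weight reproduces the adapted forward pass with a plain Linear, so the checkpoint needs no adapter-aware loading.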