Skip to content
Open
88 changes: 88 additions & 0 deletions tests/unit_tests/test_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,26 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
import torch.nn as nn

from torchtitan.components.lora import LoRAConverter
from torchtitan.components.quantization.float8 import Float8LinearConverter
from torchtitan.config import ConfigManager
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.model_converter import ModelConvertersContainer


class SimpleModel(nn.Module):
    """Minimal two-layer MLP (64 -> 64 -> 64, ReLU in between) used as a LoRA test fixture."""

    def __init__(self):
        super().__init__()
        # Creation order matters for RNG-reproducible initialization.
        self.fc1 = nn.Linear(64, 64)
        self.fc2 = nn.Linear(64, 64)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)


def build_parallel_dims(trainer_config, world_size):
parallelism_config = trainer_config.parallelism
parallel_dims = ParallelDims(
Expand Down Expand Up @@ -66,3 +79,78 @@ def test_build_model_converters_float8_converter():
assert isinstance(model_converters, ModelConvertersContainer)
assert len(model_converters.converters) == 1
assert isinstance(model_converters.converters[0], Float8LinearConverter)


def test_lora_freeze_and_trainability():
    """After convert: base params frozen, LoRA adapters present and trainable."""
    model = SimpleModel()
    LoRAConverter(LoRAConverter.Config(rank=4, alpha=8.0)).convert(model)

    # Every linear must have gained both LoRA adapter submodules.
    for layer in (model.fc1, model.fc2):
        assert hasattr(layer, "lora_a")
        assert hasattr(layer, "lora_b")

    # Partition parameters and verify trainability per group.
    lora_param_names: list[str] = []
    base_param_names: list[str] = []
    for name, param in model.named_parameters():
        if "lora_a" in name or "lora_b" in name:
            lora_param_names.append(name)
            assert param.requires_grad, f"LoRA param '{name}' should be trainable"
        else:
            base_param_names.append(name)
            assert not param.requires_grad, f"Base param '{name}' should be frozen"

    assert len(lora_param_names) > 0, "No LoRA params found"
    assert len(base_param_names) > 0, "No base params found"


def test_lora_trains_base_frozen():
    """Train for several steps: LoRA params should change, base params should not."""
    torch.manual_seed(42)
    model = SimpleModel()
    LoRAConverter(LoRAConverter.Config(rank=4, alpha=8.0)).convert(model)

    def _is_lora(name: str) -> bool:
        return "lora_a" in name or "lora_b" in name

    # Snapshot all params before training, partitioned by group.
    snapshot = {name: param.data.clone() for name, param in model.named_parameters()}
    base_before = {name: t for name, t in snapshot.items() if not _is_lora(name)}
    lora_before = {name: t for name, t in snapshot.items() if _is_lora(name)}

    # Only LoRA params go to optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable, lr=0.1)

    # Train for 5 steps.
    for _ in range(5):
        batch = torch.randn(4, 64)
        model(batch).sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    # Base params must not change.
    for name, param in model.named_parameters():
        if name in base_before:
            assert torch.equal(
                param.data, base_before[name]
            ), f"Base param '{name}' changed during training"

    # At least some LoRA params must change.
    any_lora_changed = any(
        not torch.equal(param.data, lora_before[name])
        for name, param in model.named_parameters()
        if name in lora_before
    )
    assert any_lora_changed, "No LoRA param changed after 5 training steps"
123 changes: 123 additions & 0 deletions torchtitan/components/lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn

from torchtitan.config import Configurable
from torchtitan.tools.logging import logger

# Cache for dynamically created LoRA classes
_lora_class_cache: dict[type, type] = {}


def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear:
parent_cls = type(linear)
assert issubclass(
parent_cls, nn.Linear
), f"parent_cls must be a subclass of nn.Linear, got {parent_cls}"

if parent_cls not in _lora_class_cache:

class LoRALinear(parent_cls): # type: ignore[valid-type, misc]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("LoRALinear should not be instantiated directly.")

@classmethod
def from_linear(
cls, linear: nn.Linear, rank: int, alpha: float
) -> "LoRALinear":
linear.__class__ = cls
linear._init_lora(rank, alpha) # type: ignore[attr-defined]
return linear # type: ignore[return-value]

def _init_lora(
self,
rank: int,
alpha: float,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
self._lora_scaling = alpha / rank
device = device if device is not None else self.weight.device
dtype = dtype if dtype is not None else self.weight.dtype
self.lora_a = nn.Linear(
self.in_features,
rank,
bias=False,
device=device,
dtype=dtype,
)
self.lora_b = nn.Linear(
rank,
self.out_features,
bias=False,
device=device,
dtype=dtype,
)

def _init_weight(self) -> None:
nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_b.weight)

def forward(self, input: torch.Tensor) -> torch.Tensor:
base_out = super().forward(input)
lora_out = self.lora_b(self.lora_a(input))
return base_out + self._lora_scaling * lora_out

LoRALinear.__name__ = f"LoRA{parent_cls.__name__}"
LoRALinear.__qualname__ = f"LoRA{parent_cls.__name__}"
_lora_class_cache[parent_cls] = LoRALinear

# pyrefly: ignore [missing-attribute]
return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha)


class LoRAConverter(Configurable):
    """Apply LoRA adapters to all Linear layers in a model.

    ``convert`` freezes every existing parameter, then converts each
    ``nn.Linear`` in place via ``apply_lora``; the freshly created
    ``lora_a``/``lora_b`` adapters are trainable by default, so only the
    adapters receive gradients afterwards.
    """

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        rank: int = 8
        """Rank of the LoRA matrices (lora_a: in_features x rank, lora_b: rank x out_features)."""

        alpha: float = 16.0
        """Scaling factor. Output is scaled by alpha/rank."""

    def __init__(self, config: Config, **kwargs):
        self.rank = config.rank
        self.alpha = config.alpha
        logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")

    def convert(self, model: nn.Module) -> None:
        """Freeze all existing params, then add trainable LoRA adapters in place."""
        # Freeze first: adapters created below are new modules and therefore
        # keep the default requires_grad=True.
        model.requires_grad_(False)
        self._replace_linears_with_lora(model)

    def _replace_linears_with_lora(self, module: nn.Module) -> None:
        """Convert every nn.Linear under ``module`` and patch init_weights."""
        # Snapshot named_modules() before mutating: apply_lora swaps each
        # Linear's __class__ in place and attaches lora_a/lora_b (themselves
        # nn.Linear), which must not be re-converted by this loop.
        for _, child in list(module.named_modules()):
            if isinstance(child, nn.Linear):
                apply_lora(child, self.rank, self.alpha)

        # Patch init_weights to also reinitialize LoRA adapters
        # TODO: remove new_model_init_weights once torchtitan has its own
        # Linear module with init_weights capability (plain nn.Linear has no
        # init_weights, so model init currently bypasses the adapters).
        original_init_weights = getattr(module, "init_weights", None)

        def new_model_init_weights(*args: Any, **kwargs: Any) -> None:
            # Run the model's own init first, then re-init the LoRA adapters
            # so a full re-init leaves lora_b zeroed again.
            if original_init_weights is not None and callable(original_init_weights):
                original_init_weights(*args, **kwargs)
            for sub_module in module.modules():
                if type(sub_module) in _lora_class_cache.values():
                    _init_weight = getattr(sub_module, "_init_weight", None)
                    assert _init_weight is not None and callable(_init_weight)
                    _init_weight()

        # object.__setattr__ bypasses nn.Module.__setattr__, which would
        # otherwise try to register the function as a submodule/attribute.
        object.__setattr__(module, "init_weights", new_model_init_weights)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        # LoRA needs no per-step work after the optimizer update.
        pass
14 changes: 14 additions & 0 deletions torchtitan/models/llama3/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.lora import LoRAConverter
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.metrics import MetricsProcessor
from torchtitan.components.optimizer import (
Expand Down Expand Up @@ -108,6 +109,19 @@ def llama3_debugmodel_float8_emulate() -> Trainer.Config:
return config


def llama3_debugmodel_lora() -> Trainer.Config:
    """Debug-model trainer config with a single LoRA converter (rank 8, alpha 16)."""
    config = llama3_debugmodel()
    lora_config = LoRAConverter.Config(
        rank=8,
        alpha=16.0,
    )
    config.model_converters = ModelConvertersContainer.Config(
        converters=[lora_config],
    )
    return config


def llama3_8b() -> Trainer.Config:
return Trainer.Config(
hf_assets_path="./assets/hf/Llama-3.1-8B",
Expand Down
Loading