|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import math |
| 8 | +from dataclasses import dataclass |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +import torch |
| 12 | +import torch.nn as nn |
| 13 | + |
| 14 | +from torchtitan.config import Configurable |
| 15 | +from torchtitan.tools.logging import logger |
| 16 | + |
| 17 | +# Cache for dynamically created LoRA classes |
| 18 | +_lora_class_cache: dict[type, type] = {} |
| 19 | + |
| 20 | + |
| 21 | +def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear: |
| 22 | + parent_cls = type(linear) |
| 23 | + assert issubclass( |
| 24 | + parent_cls, nn.Linear |
| 25 | + ), f"parent_cls must be a subclass of nn.Linear, got {parent_cls}" |
| 26 | + |
| 27 | + if parent_cls not in _lora_class_cache: |
| 28 | + |
| 29 | + class LoRALinear(parent_cls): # type: ignore[valid-type, misc] |
| 30 | + def __init__(self, *args: Any, **kwargs: Any) -> None: |
| 31 | + raise RuntimeError("LoRALinear should not be instantiated directly.") |
| 32 | + |
| 33 | + @classmethod |
| 34 | + def from_linear( |
| 35 | + cls, linear: nn.Linear, rank: int, alpha: float |
| 36 | + ) -> "LoRALinear": |
| 37 | + linear.__class__ = cls |
| 38 | + linear._init_lora(rank, alpha) # type: ignore[attr-defined] |
| 39 | + return linear # type: ignore[return-value] |
| 40 | + |
| 41 | + def _init_lora( |
| 42 | + self, |
| 43 | + rank: int, |
| 44 | + alpha: float, |
| 45 | + device: torch.device | None = None, |
| 46 | + dtype: torch.dtype | None = None, |
| 47 | + ) -> None: |
| 48 | + self._lora_scaling = alpha / rank |
| 49 | + device = device if device is not None else self.weight.device |
| 50 | + dtype = dtype if dtype is not None else self.weight.dtype |
| 51 | + self.lora_a = nn.Linear( |
| 52 | + self.in_features, |
| 53 | + rank, |
| 54 | + bias=False, |
| 55 | + device=device, |
| 56 | + dtype=dtype, |
| 57 | + ) |
| 58 | + self.lora_b = nn.Linear( |
| 59 | + rank, |
| 60 | + self.out_features, |
| 61 | + bias=False, |
| 62 | + device=device, |
| 63 | + dtype=dtype, |
| 64 | + ) |
| 65 | + self._init_weight() |
| 66 | + |
| 67 | + def _init_weight(self) -> None: |
| 68 | + nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) |
| 69 | + nn.init.zeros_(self.lora_b.weight) |
| 70 | + |
| 71 | + def forward(self, input: torch.Tensor) -> torch.Tensor: |
| 72 | + base_out = super().forward(input) |
| 73 | + lora_out = self.lora_b(self.lora_a(input)) |
| 74 | + return base_out + self._lora_scaling * lora_out |
| 75 | + |
| 76 | + LoRALinear.__name__ = f"LoRA{parent_cls.__name__}" |
| 77 | + LoRALinear.__qualname__ = f"LoRA{parent_cls.__name__}" |
| 78 | + _lora_class_cache[parent_cls] = LoRALinear |
| 79 | + |
| 80 | + return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha) |
| 81 | + |
| 82 | +class LoRAConverter(Configurable): |
| 83 | + """Apply LoRA adapters to all Linear layers in a model.""" |
| 84 | + |
| 85 | + @dataclass(kw_only=True, slots=True) |
| 86 | + class Config(Configurable.Config): |
| 87 | + rank: int = 8 |
| 88 | + """Rank of the LoRA matrices (lora_a: in_features x rank, lora_b: rank x out_features).""" |
| 89 | + |
| 90 | + alpha: float = 16.0 |
| 91 | + """Scaling factor. Output is scaled by alpha/rank.""" |
| 92 | + |
| 93 | + def __init__(self, config: Config, **kwargs): |
| 94 | + self.rank = config.rank |
| 95 | + self.alpha = config.alpha |
| 96 | + logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}") |
| 97 | + |
| 98 | + def convert(self, model: nn.Module) -> None: |
| 99 | + model.requires_grad_(False) |
| 100 | + self._replace_linears_with_lora(model) |
| 101 | + |
| 102 | + def _replace_linears_with_lora(self, module: nn.Module) -> None: |
| 103 | + for _, child in list(module.named_modules()): |
| 104 | + if isinstance(child, nn.Linear): |
| 105 | + apply_lora(child, self.rank, self.alpha) |
| 106 | + |
| 107 | + # Patch init_weights to also reinitialize LoRA adapters |
| 108 | + original_init_weights = getattr(module, "init_weights", None) |
| 109 | + |
| 110 | + def new_model_init_weights(*args: Any, **kwargs: Any) -> None: |
| 111 | + if original_init_weights is not None and callable(original_init_weights): |
| 112 | + original_init_weights(*args, **kwargs) |
| 113 | + for sub_module in module.modules(): |
| 114 | + if type(sub_module) in _lora_class_cache.values(): |
| 115 | + _init_weight = getattr(sub_module, "_init_weight", None) |
| 116 | + assert _init_weight is not None and callable(_init_weight) |
| 117 | + _init_weight() |
| 118 | + |
| 119 | + object.__setattr__(module, "init_weights", new_model_init_weights) |
| 120 | + |
| 121 | + def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None: |
| 122 | + pass |
| 123 | + |
| 124 | + |
| 125 | +def find_lora_config(converters: list) -> "LoRAConverter.Config | None": |
| 126 | + return next( |
| 127 | + (c for c in converters if isinstance(c, LoRAConverter.Config)), |
| 128 | + None, |
| 129 | + ) |
0 commit comments