-
Notifications
You must be signed in to change notification settings - Fork 757
[1/N] Add lora adapters for finetune #2484
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/mori360/1/base
Are you sure you want to change the base?
Changes from 4 commits
f55fb81
c829153
5e595c2
dcc6962
13c6133
e1199ad
be974e0
c55b15f
5a362c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import math | ||
| from dataclasses import dataclass | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from torchtitan.config import Configurable | ||
| from torchtitan.tools.logging import logger | ||
|
|
||
| # Cache for dynamically created LoRA classes | ||
| _lora_class_cache: dict[type, type] = {} | ||
|
|
||
|
|
||
| def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear: | ||
| parent_cls = type(linear) | ||
| assert issubclass( | ||
| parent_cls, nn.Linear | ||
| ), f"parent_cls must be a subclass of nn.Linear, got {parent_cls}" | ||
|
|
||
| if parent_cls not in _lora_class_cache: | ||
|
|
||
| class LoRALinear(parent_cls): # type: ignore[valid-type, misc] | ||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | ||
| raise RuntimeError("LoRALinear should not be instantiated directly.") | ||
|
|
||
| @classmethod | ||
| def from_linear( | ||
| cls, linear: nn.Linear, rank: int, alpha: float | ||
| ) -> "LoRALinear": | ||
| linear.__class__ = cls | ||
| linear._init_lora(rank, alpha) # type: ignore[attr-defined] | ||
| return linear # type: ignore[return-value] | ||
|
|
||
| def _init_lora( | ||
| self, | ||
| rank: int, | ||
| alpha: float, | ||
| device: torch.device | None = None, | ||
| dtype: torch.dtype | None = None, | ||
| ) -> None: | ||
| self._lora_scaling = alpha / rank | ||
| device = device if device is not None else self.weight.device | ||
| dtype = dtype if dtype is not None else self.weight.dtype | ||
| self.lora_a = nn.Linear( | ||
| self.in_features, | ||
| rank, | ||
| bias=False, | ||
| device=device, | ||
| dtype=dtype, | ||
| ) | ||
| self.lora_b = nn.Linear( | ||
| rank, | ||
| self.out_features, | ||
| bias=False, | ||
| device=device, | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
| def _init_weight(self) -> None: | ||
| nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) | ||
| nn.init.zeros_(self.lora_b.weight) | ||
|
|
||
| def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
| base_out = super().forward(input) | ||
| lora_out = self.lora_b(self.lora_a(input)) | ||
| return base_out + self._lora_scaling * lora_out | ||
|
|
||
| LoRALinear.__name__ = f"LoRA{parent_cls.__name__}" | ||
| LoRALinear.__qualname__ = f"LoRA{parent_cls.__name__}" | ||
| _lora_class_cache[parent_cls] = LoRALinear | ||
|
|
||
| # pyrefly: ignore [missing-attribute] | ||
| return _lora_class_cache[parent_cls].from_linear(linear, rank, alpha) | ||
|
|
||
|
|
||
class LoRAConverter(Configurable):
    """Apply LoRA adapters to all Linear layers in a model."""

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        rank: int = 8
        """Rank of the LoRA matrices (lora_a: in_features x rank, lora_b: rank x out_features)."""

        alpha: float = 16.0
        """Scaling factor. Output is scaled by alpha/rank."""

    def __init__(self, config: Config, **kwargs):
        self.rank = config.rank
        self.alpha = config.alpha
        logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")

    def convert(self, model: nn.Module) -> None:
        """Freeze every base parameter, then graft LoRA adapters onto all Linears."""
        model.requires_grad_(False)
        self._replace_linears_with_lora(model)

    def _replace_linears_with_lora(self, module: nn.Module) -> None:
        """Wrap each existing nn.Linear with LoRA and patch init_weights.

        The module list is snapshotted before any mutation so that the
        lora_a/lora_b layers created by apply_lora (themselves nn.Linear)
        are never wrapped in turn.
        """
        targets = [m for _, m in list(module.named_modules()) if isinstance(m, nn.Linear)]
        for target in targets:
            apply_lora(target, self.rank, self.alpha)

        # Patch init_weights to also reinitialize LoRA adapters
        original_init_weights = getattr(module, "init_weights", None)

        def new_model_init_weights(*args: Any, **kwargs: Any) -> None:
            # Run the model's own init first, then re-init every LoRA adapter
            # so lora_b is zeroed after the base weights are (re)initialized.
            if callable(original_init_weights):
                original_init_weights(*args, **kwargs)
            lora_classes = tuple(_lora_class_cache.values())
            for sub_module in module.modules():
                if type(sub_module) in lora_classes:
                    init_fn = getattr(sub_module, "_init_weight", None)
                    assert init_fn is not None and callable(init_fn)
                    init_fn()

        # Bypass nn.Module.__setattr__, which would try to register the
        # plain function as a submodule/parameter.
        object.__setattr__(module, "init_weights", new_model_init_weights)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        """No-op: LoRA needs no per-step maintenance after the optimizer."""
        pass
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from torchtitan.components.checkpoint import CheckpointManager | ||
| from torchtitan.components.lora import LoRAConverter | ||
| from torchtitan.components.lr_scheduler import LRSchedulersContainer | ||
| from torchtitan.components.metrics import MetricsProcessor | ||
| from torchtitan.components.optimizer import ( | ||
|
|
@@ -108,6 +109,19 @@ def llama3_debugmodel_float8_emulate() -> Trainer.Config: | |
| return config | ||
|
|
||
|
|
||
def llama3_debugmodel_lora() -> Trainer.Config:
    """Debug-model trainer config with a single LoRA converter enabled."""
    config = llama3_debugmodel()
    # NOTE(review): among all converters (quantization, lora, etc.) order
    # matters, and config validation currently can't catch a wrong order
    # when the main config is built first and a field is modified after.
    lora_config = LoRAConverter.Config(rank=8, alpha=16.0)
    config.model_converters = ModelConvertersContainer.Config(
        converters=[lora_config],
    )
    return config
|
|
||
|
|
||
| def llama3_8b() -> Trainer.Config: | ||
| return Trainer.Config( | ||
| hf_assets_path="./assets/hf/Llama-3.1-8B", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems we need this because nn.Linear doesn't have
`init_weights`, and instead in torchtitan we directly operate on the weight and bias of any nn.Linear submodule. I believe @fegin will refactor this to create our own Linear module with `init_weights` capability. So let's leave a TODO here to remove this `new_model_init_weights`.
new_model_init_weights.