-
Notifications
You must be signed in to change notification settings - Fork 757
[1/N] Add lora adapters for finetune #2484
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/mori360/1/base
Are you sure you want to change the base?
Changes from all commits
f55fb81
c829153
5e595c2
dcc6962
13c6133
e1199ad
be974e0
c55b15f
5a362c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import math | ||
| from dataclasses import dataclass | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from torchtitan.config import Configurable | ||
| from torchtitan.models.common.linear import Linear | ||
| from torchtitan.tools.logging import logger | ||
|
|
||
# Cache of dynamically created LoRA classes, keyed by the wrapped nn.Linear
# subclass, so each parent class is subclassed at most once per process.
_lora_class_cache: dict[type, type] = {}
|
|
||
|
|
||
def apply_lora(linear: nn.Linear, rank: int, alpha: float) -> nn.Linear:
    """Mutate ``linear`` in place into a LoRA-augmented variant of its class.

    A subclass of ``type(linear)`` is created on first use (and cached in
    ``_lora_class_cache``); the given instance is then re-classed to it and
    two low-rank adapter layers (``lora_a``, ``lora_b``) are attached. The
    original weight/bias tensors are untouched, so existing references to
    the module remain valid.

    Args:
        linear: the module to augment; must be an ``nn.Linear`` (or subclass).
        rank: inner dimension of the low-rank decomposition.
        alpha: scaling factor; adapter output is scaled by ``alpha / rank``.

    Returns:
        The same instance, now an instance of the dynamic LoRA subclass.
    """
    base_cls = type(linear)
    assert issubclass(
        base_cls, nn.Linear
    ), f"parent_cls must be a subclass of nn.Linear, got {base_cls}"

    lora_cls = _lora_class_cache.get(base_cls)
    if lora_cls is None:

        class LoRALinear(base_cls):  # type: ignore[valid-type, misc]
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                # Instances are only produced by re-classing an existing
                # linear via ``from_linear``; direct construction is a bug.
                raise RuntimeError("LoRALinear should not be instantiated directly.")

            @classmethod
            def from_linear(
                cls, linear: nn.Linear, rank: int, alpha: float
            ) -> "LoRALinear":
                # Swap the instance's class in place, then attach adapters.
                linear.__class__ = cls
                linear._init_lora(rank, alpha)  # type: ignore[attr-defined]
                return linear  # type: ignore[return-value]

            def _init_lora(
                self,
                rank: int,
                alpha: float,
                device: torch.device | None = None,
                dtype: torch.dtype | None = None,
            ) -> None:
                # Scale factor applied to the adapter branch in forward().
                self._lora_scaling = alpha / rank
                # Default to the base weight's placement so the adapters
                # live alongside the frozen parameters.
                if device is None:
                    device = self.weight.device
                if dtype is None:
                    dtype = self.weight.dtype
                lora_a = Linear.Config(bias=False).build(
                    in_features=self.in_features, out_features=rank
                )
                lora_b = Linear.Config(bias=False).build(
                    in_features=rank, out_features=self.out_features
                )
                self.lora_a = lora_a.to(device=device, dtype=dtype)
                self.lora_b = lora_b.to(device=device, dtype=dtype)

            def init_weights(self, **kwargs) -> None:
                super().init_weights(**kwargs)  # pyrefly: ignore [not-callable]
                # Standard LoRA init: lora_b starts at zero so the adapter
                # branch is an identity perturbation at step 0.
                nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
                nn.init.zeros_(self.lora_b.weight)

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                base_out = super().forward(input)
                lora_out = self.lora_b(self.lora_a(input))
                return base_out + self._lora_scaling * lora_out

        # Give the dynamic class a readable name for reprs/debugging.
        LoRALinear.__name__ = f"LoRA{base_cls.__name__}"
        LoRALinear.__qualname__ = f"LoRA{base_cls.__name__}"
        _lora_class_cache[base_cls] = LoRALinear
        lora_cls = LoRALinear

    # pyrefly: ignore [missing-attribute]
    return lora_cls.from_linear(linear, rank, alpha)
|
|
||
|
|
||
class LoRAConverter(Configurable):
    """Apply LoRA adapters to all Linear layers in a model."""

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        rank: int = 8
        """Rank of the LoRA matrices (lora_a: in_features x rank, lora_b: rank x out_features)."""

        alpha: float = 16.0
        """Scaling factor. Output is scaled by alpha/rank."""

    def __init__(self, config: Config, **kwargs):
        # NOTE(review): super().__init__ is not called here — confirm
        # Configurable does not require it.
        self.rank = config.rank
        self.alpha = config.alpha
        logger.info(f"LoRA training active with rank={self.rank}, alpha={self.alpha}")

    def convert(self, model: nn.Module) -> None:
        """Freeze every base parameter, then attach trainable LoRA adapters."""
        model.requires_grad_(False)
        self._replace_linears_with_lora(model)

    def _replace_linears_with_lora(self, module: nn.Module) -> None:
        # Snapshot the Linear modules before mutating: apply_lora adds
        # lora_a/lora_b children (themselves nn.Linear) that must not be
        # wrapped in turn.
        targets = [m for m in module.modules() if isinstance(m, nn.Linear)]
        for target in targets:
            apply_lora(target, self.rank, self.alpha)

    def post_optimizer_hook(self, model: nn.Module | list[nn.Module]) -> None:
        # No per-step work is needed for LoRA; hook exists to satisfy the
        # converter interface.
        pass
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from torchtitan.components.checkpoint import CheckpointManager | ||
| from torchtitan.components.lora import LoRAConverter | ||
| from torchtitan.components.lr_scheduler import LRSchedulersContainer | ||
| from torchtitan.components.metrics import MetricsProcessor | ||
| from torchtitan.components.optimizer import ( | ||
|
|
@@ -108,6 +109,19 @@ def llama3_debugmodel_float8_emulate() -> Trainer.Config: | |
| return config | ||
|
|
||
|
|
||
def llama3_debugmodel_lora() -> Trainer.Config:
    """Debug-model trainer config with a LoRA converter attached.

    Starts from ``llama3_debugmodel()`` and replaces ``model_converters``
    with a container holding a single ``LoRAConverter`` (rank=8, alpha=16.0).

    NOTE(review): among the converters (quantization, lora, etc.) order
    matters, and validation cannot catch a wrong order when the main config
    is built first and a field is modified afterwards — confirm ordering.
    """
    config = llama3_debugmodel()
    config.model_converters = ModelConvertersContainer.Config(
        converters=[
            LoRAConverter.Config(
                rank=8,
                alpha=16.0,
            ),
        ],
    )
    return config
|
|
||
|
|
||
| def llama3_8b() -> Trainer.Config: | ||
| return Trainer.Config( | ||
| hf_assets_path="./assets/hf/Llama-3.1-8B", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove new_model_init_weights after refactoring Linear @fegin