Commit 546927d

Add Orthogonal Subspace Fine-Tuning (OSF) Tuner for Parameter-Efficient Continual Learning (huggingface#2685)
This adds a new parameter-efficient fine-tuning method called **Orthogonal Subspace Fine-Tuning (OSF)** to the PEFT library. OSF enables continual learning in LLMs by freezing the high-rank subspace of weight matrices and fine-tuning only the low-rank directions. This approach constrains updates to be orthogonal to previously important directions, thereby mitigating catastrophic forgetting without increasing parameter count.

Tracked in [PEFT Issue huggingface#2648](huggingface#2648)

**Notes**

* The current implementation does not include layerwise importance-based rank estimation (e.g., cosine similarity of inputs and activations), but this can be added in future iterations
* Unmerging is not supported, as the original weights are decomposed and modified in-place
* Compared to LoRA, OSF performs a constrained update over the original weight matrix without introducing new trainable parameters, maintaining exact model architecture post-training

**Background**

This implementation is based on the method described in our paper: Sculpting Subspaces: Constrained Full Fine-Tuning in LLMs for Continual Learning
1 parent a82ca6d commit 546927d

File tree

20 files changed

+1183
-49
lines changed

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,8 @@
   title: LoKr
 - local: package_reference/lora
   title: LoRA
+- local: package_reference/osf
+  title: OSF
 - local: package_reference/xlora
   title: X-LoRA
 - local: package_reference/adapter_utils
Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# OSF (Orthogonal Subspace Fine-tuning)
+
+Orthogonal Subspace Fine-tuning ([OSF](https://huggingface.co/papers/2504.07097)) is a PEFT method designed for continual learning that constrains parameter updates to be orthogonal to previously important directions. This approach enables full fine-tuning while preventing catastrophic forgetting without requiring additional parameters or storing previous gradients.
+
+The abstract from the paper is:
+
+*Continual learning in large language models (LLMs) is prone to catastrophic forgetting, where adapting to new tasks significantly degrades performance on previously learned ones. Existing methods typically rely on low-rank, parameter-efficient updates that limit the model's expressivity and introduce additional parameters per task, leading to scalability issues. To address these limitations, we propose a novel continual full fine-tuning approach leveraging adaptive singular value decomposition (SVD). Our method dynamically identifies task-specific low-rank parameter subspaces and constrains updates to be orthogonal to critical directions associated with prior tasks, thus effectively minimizing interference without additional parameter overhead or storing previous task gradients. We evaluate our approach extensively on standard continual learning benchmarks using both encoder-decoder (T5-Large) and decoder-only (LLaMA-2 7B) models, spanning diverse tasks including classification, generation, and reasoning. Empirically, our method achieves state-of-the-art results, up to 7% higher average accuracy than recent baselines like O-LoRA, and notably maintains the model's general linguistic capabilities, instruction-following accuracy, and safety throughout the continual learning process by reducing forgetting to near-negligible levels. Our adaptive SVD framework effectively balances model plasticity and knowledge retention, providing a practical, theoretically grounded, and computationally scalable solution for continual learning scenarios in large language models.*
+
+## How OSF Works
+
+OSF decomposes each weight matrix into high-rank (frozen) and low-rank (trainable) components using SVD:
+
+```
+W = U_high * S_high * V_high^T + U_low * S_low * V_low^T
+```
+
+Where:
+- `U_high, S_high, V_high`: Preserve important directions from previous tasks (frozen)
+- `U_low, S_low, V_low`: Allow adaptation to new tasks (trainable)
+
+During training, gradients are projected to be orthogonal to the high-rank subspace, ensuring updates don't interfere with previously learned knowledge.
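
The decomposition and projection described above can be illustrated with a small NumPy sketch. This is a toy illustration under stated assumptions, not the code added by this commit: the helper names `split_by_svd` and `project_out` are hypothetical, and the sketch does not call or reproduce PEFT's `project_gradient_to_orthogonal_space`. It only shows that the two SVD parts sum back to `W`, and that a gradient for `U_low` can be stripped of any component along `U_high`.

```python
import numpy as np


def split_by_svd(weight, preserved_rank):
    """Hypothetical helper: split a matrix into top-k (frozen) and remaining (trainable) SVD factors."""
    u, s, vt = np.linalg.svd(weight, full_matrices=False)
    k = preserved_rank
    high = (u[:, :k], s[:k], vt[:k, :])
    low = (u[:, k:], s[k:], vt[k:, :])
    return high, low


def project_out(grad, basis):
    """Hypothetical helper: remove the part of `grad` lying in the span of the orthonormal columns of `basis`."""
    return grad - basis @ (basis.T @ grad)


rng = np.random.default_rng(0)
W = rng.standard_normal((6, 4))
(U_high, S_high, Vt_high), (U_low, S_low, Vt_low) = split_by_svd(W, preserved_rank=2)

# The frozen and trainable parts add back up to the original matrix.
W_rebuilt = (U_high * S_high) @ Vt_high + (U_low * S_low) @ Vt_low
recon_error = float(np.abs(W - W_rebuilt).max())

# A stand-in gradient for U_low, projected so it has no component along U_high.
grad_U_low = rng.standard_normal(U_low.shape)
grad_projected = project_out(grad_U_low, U_high)
overlap = float(np.abs(U_high.T @ grad_projected).max())

print(f"reconstruction error: {recon_error:.2e}, overlap with U_high: {overlap:.2e}")
```

A gradient for `V_low` would be handled the same way against `V_high`; how the real implementation hooks this into the backward pass is not shown here.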
+
+## Basic Usage
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import OSFConfig, get_peft_model
+
+# Load base model
+model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+# Configure OSF
+config = OSFConfig(
+    target_modules=["c_attn", "c_proj"],  # Target attention layers
+    effective_rank=8,                     # Default rank for decomposition
+    rank_pattern={"c_attn": 16}           # Override rank for specific modules
+)
+
+# Apply OSF
+model = get_peft_model(model, config)
+
+# Train as usual
+optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
+
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+tokenizer.pad_token = tokenizer.eos_token
+
+inputs = tokenizer("Hello world", return_tensors="pt", padding=True)
+loss = model(**inputs, labels=inputs.input_ids).loss
+loss.backward()
+optimizer.step()
+optimizer.zero_grad()
+```
+
+## Configuration Options
+
+### Target Modules
+
+You can specify target modules in several ways:
+
+```python
+# Specific module names
+config = OSFConfig(target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])
+
+# All linear layers
+config = OSFConfig(target_modules="all-linear")
+
+# Model-specific defaults (automatically detected)
+config = OSFConfig()  # Uses model-appropriate defaults
+```
+
+### Effective Rank Configuration
+
+Control the preserved/trainable subspaces:
+
+```python
+# Global preserved rank (applies to all target modules)
+config = OSFConfig(effective_rank=16)  # preserves top-16 singular directions; trains the rest
+
+# Automatic preserved rank (50% of the smaller matrix dimension per target)
+config = OSFConfig(effective_rank=None)
+
+# Per-module preserved-rank overrides
+config = OSFConfig(
+    effective_rank=8,
+    rank_pattern={
+        "q_proj": 16,     # Higher rank for query projection
+        "gate_proj": 4    # Lower rank for gate projection
+    }
+)
+
+# Fractional preserved rank is supported (interpreted per-target as fraction * min_dim)
+config = OSFConfig(effective_rank=0.8)  # preserve 80% of min_dim; train remaining 20%
+config = OSFConfig(rank_pattern={"q_proj": 0.5})  # preserve 50% on q_proj, others use global/default
+```
+
+Note: OSF's `effective_rank` is the preserved (frozen) rank, not the trainable rank. The trainable rank equals `min(weight.shape) - effective_rank`. This differs from LoRA's `r`, which directly specifies the trainable rank.
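
To make the arithmetic in the note concrete, here is a minimal sketch of how the three forms of `effective_rank` (integer, `None`, fraction) could map to preserved and trainable ranks for a single weight matrix. The helper `resolve_preserved_rank` is hypothetical and is not part of PEFT; rounding fractions down and clamping to `min_dim` are assumptions made for illustration, so the library's actual behaviour may differ.

```python
def resolve_preserved_rank(effective_rank, weight_shape):
    """Hypothetical helper: map an `effective_rank` setting to (preserved, trainable) ranks."""
    min_dim = min(weight_shape)
    if effective_rank is None:
        # Automatic setting: 50% of the smaller matrix dimension.
        preserved = min_dim // 2
    elif isinstance(effective_rank, float):
        # Fractional setting: fraction * min_dim (rounding down is an assumption).
        preserved = int(effective_rank * min_dim)
    else:
        # Integer setting: used as the preserved rank directly.
        preserved = int(effective_rank)
    # Clamping to the valid range is also an assumption of this sketch.
    preserved = max(0, min(preserved, min_dim))
    return preserved, min_dim - preserved


# Example with a 768 x 2304 weight, so min_dim = 768.
for setting in (16, None, 0.8):
    preserved, trainable = resolve_preserved_rank(setting, (768, 2304))
    print(f"effective_rank={setting!r}: preserved={preserved}, trainable={trainable}")
```

The point of the sketch is the direction of the relationship: raising `effective_rank` shrinks the trainable subspace, which is the opposite of raising `r` in LoRA.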
+
+
+## Training Advice for Continual Learning
+
+### Sequential Task Learning
+
+OSF is specifically designed for learning tasks sequentially. Between tasks, recompute the SVD so the preserved subspace reflects the latest weights. One simple way is to re-wrap the updated base model with OSF again:
+
+```python
+# Task 1: train on domain A with initial preserved subspace
+r = 8  # initial effective rank to preserve
+model = get_peft_model(base_model, OSFConfig(effective_rank=r))
+train_task(model, task_1_data)
+
+# Task 2: recompute SVD on updated weights and increase preserved subspace
+base_model = model.unload()  # unwrap base model without assuming internals
+r += 4  # grow preserved subspace to include Task 1 knowledge
+model = get_peft_model(base_model, OSFConfig(effective_rank=r))
+train_task(model, task_2_data)
+
+# Task 3: recompute again and expand preserved subspace further
+base_model = model.unload()
+r += 4
+model = get_peft_model(base_model, OSFConfig(effective_rank=r))
+train_task(model, task_3_data)
+```
+
+### Budget Allocation for Task Sequences
+
+When training on a known sequence of n tasks, one effective strategy is to progressively allocate model capacity to balance learning new tasks while preserving previous knowledge:
+
+- **Task 1**: Use full capacity (train everything)
+- **Task 2**: Freeze 1/n of model capacity, train remaining (n-1)/n capacity
+- **Task 3**: Freeze 2/n of model capacity, train remaining (n-2)/n capacity
+- **Task n**: Freeze (n-1)/n of model capacity, use 1/n capacity for final task
+
+This approach ensures each task gets adequate learning capacity while progressively preserving more knowledge from previous tasks.
+
+```python
+# Example: 4-task sequence with progressive budget allocation
+n_tasks = 4
+max_preserved_rank = 512  # Upper bound for preserved rank per target (heuristic)
+
+for task_id in range(n_tasks):
+    # Freeze increases over time; trainable capacity shrinks
+    preserved_fraction = (task_id + 1) / n_tasks
+    preserved_rank = int(max_preserved_rank * preserved_fraction)
+
+    config = OSFConfig(
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+        effective_rank=preserved_rank,
+    )
+
+    print(
+        f"Task {task_id + 1}: Preserving rank {preserved_rank} "
+        f"({preserved_fraction:.1%} of max_preserved_rank - {max_preserved_rank} frozen); trainable rank = min_dim - preserved_rank"
+    )
+
+    model = get_peft_model(base_model, config)
+    train_task(model, task_data[task_id])
+```
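
For comparison, the bulleted schedule in this section (nothing frozen for Task 1, then 1/n, 2/n, up to (n-1)/n) can be written out directly. This is a sketch with hypothetical helper names (`freeze_schedule`, `preserved_ranks`); it measures the frozen fraction against the full `min_dim` of one weight matrix instead of a `max_preserved_rank` bound. Whether `OSFConfig` accepts a preserved rank of 0 for the first task is not confirmed here, so the sketch only computes the numbers and does not build configs.

```python
def freeze_schedule(n_tasks):
    """Fraction of capacity frozen before each task: 0, 1/n, ..., (n-1)/n."""
    return [task_index / n_tasks for task_index in range(n_tasks)]


def preserved_ranks(n_tasks, min_dim):
    """Convert the frozen fractions into integer preserved ranks for one weight matrix."""
    return [int(fraction * min_dim) for fraction in freeze_schedule(n_tasks)]


min_dim = 1024  # illustrative size of the smaller weight dimension
fractions = freeze_schedule(4)
ranks = preserved_ranks(4, min_dim)

for task_number, (fraction, rank) in enumerate(zip(fractions, ranks), start=1):
    print(f"Task {task_number}: freeze {fraction:.0%} (rank {rank}), trainable rank {min_dim - rank}")
```

Under this reading the last task still trains 1/n of the directions, which matches the final bullet of the list.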
+
+### Best Practices
+
+1. **Effective Rank Selection**: Start with `effective_rank=None` (auto sets rank to 50% of the smaller weight dimension per target module) and adjust based on task complexity
+2. **Learning Rate**: Use smaller learning rates (1e-5 to 1e-4) compared to standard fine-tuning
+3. **Task Importance**: Use `rank_pattern` to allocate more capacity to critical modules
+4. **Model Architecture**: OSF works best with transformer architectures having clear attention and MLP separations
+5. **Capacity Planning**: For known task sequences, use progressive budget allocation (1/n, 2/n, ..., (n-1)/n freezing) to balance plasticity and stability
+
+### Memory Considerations
+
+OSF modifies weights in-place and doesn't add parameters, making it memory-efficient:
+
+```python
+# Memory usage remains close to base model
+print(f"Base model parameters: {base_model.num_parameters():,}")
+print(f"OSF model parameters: {osf_model.num_parameters():,}")  # Similar count
+```
+
+## Advanced Usage
+
+### Custom Target Modules
+
+For models with non-standard architectures:
+
+```python
+config = OSFConfig(
+    target_modules=["dense", "intermediate.dense"],  # Custom layer names
+    effective_rank=12,
+    rank_pattern={"dense": 8, "intermediate.dense": 16}
+)
+```
+
+### Integration with Other Methods
+
+OSF can be combined with other techniques:
+
+```python
+# Use with gradient checkpointing for memory efficiency
+model.gradient_checkpointing_enable()
+
+# Apply weight decay selectively (regularizes low-rank factors to limit drift/overfitting in continual updates; keep small)
+optimizer = torch.optim.AdamW([
+    {"params": [p for n, p in model.named_parameters() if "U_low" in n], "weight_decay": 0.01},
+    {"params": [p for n, p in model.named_parameters() if "S_low" in n], "weight_decay": 0.001},
+    {"params": [p for n, p in model.named_parameters() if "V_low" in n], "weight_decay": 0.01},
+], lr=1e-4)
+```
+
+## OSFConfig
+
+[[autodoc]] tuners.osf.config.OSFConfig
+
+## OSFModel
+
+[[autodoc]] tuners.osf.model.OSFModel
+
+## Utility Functions
+
+### Weight Decomposition
+
+[[autodoc]] tuners.osf.utils.decompose_weight_matrix
+
+[[autodoc]] tuners.osf.utils.reconstruct_weight_matrix
+
+### Gradient Projection
+
+[[autodoc]] tuners.osf.utils.project_gradient_to_orthogonal_space
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# Orthogonal Subspace Learning with Adaptive OSF
+
+## TODO: Runnable Example Needed
+
+This folder is a placeholder for a comprehensive OSF example. As suggested in the review feedback:
+
+> "If you can, provide a runnable example in this folder instead, you can take a look at the EVA example for inspiration. A runnable example can be a good place to showcase the different features. Jupyter notebooks are fine as well."
+
+### Planned Example Features:
+- Complete continual learning scenario with multiple tasks
+- Demonstration of OSF's catastrophic forgetting prevention
+- Configuration examples (target_modules, effective_rank, rank_pattern)
+- Performance comparison with baseline methods
+- Memory usage analysis
+
+### Current Basic Usage:
+For basic usage examples and API documentation, see the [OSF documentation](../../docs/source/package_reference/osf.md).
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import OSFConfig, get_peft_model
+
+model = AutoModelForCausalLM.from_pretrained("gpt2")
+config = OSFConfig(target_modules=["c_attn", "c_proj"], effective_rank=8)
+model = get_peft_model(model, config)
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
+
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+tokenizer.pad_token = tokenizer.eos_token
+inputs = tokenizer("Hello world", return_tensors="pt", padding=True)
+loss = model(**inputs, labels=inputs.input_ids).loss
+loss.backward()
+optimizer.step()
+optimizer.zero_grad()
+```
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+{
+  "task_type": null,
+  "peft_type": "OSF",
+  "auto_mapping": null,
+  "base_model_name_or_path": "meta-llama/Llama-3.2-3B",
+  "revision": null,
+  "inference_mode": false,
+  "effective_rank": null,
+  "target_modules": [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "down_proj",
+    "up_proj"
+  ],
+  "rank_pattern": {
+    "q_proj": 2944,
+    "o_proj": 2944,
+    "k_proj": 896,
+    "v_proj": 896,
+    "gate_proj": 2944,
+    "down_proj": 2944,
+    "up_proj": 2944
+  }
+}
+
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+{
+  "optimizer_kwargs": {
+    "lr": 5e-5
+  }
+}
+

src/peft/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,8 @@
     MultitaskPromptTuningInit,
     OFTConfig,
     OFTModel,
+    OSFConfig,
+    OSFModel,
     PolyConfig,
     PolyModel,
     PrefixEncoder,
@@ -181,6 +183,8 @@
     "MultitaskPromptTuningInit",
     "OFTConfig",
     "OFTModel",
+    "OSFConfig",
+    "OSFModel",
     "PeftConfig",
     "PeftMixedModel",
     "PeftModel",

src/peft/tuners/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,7 @@
 from .mixed import MixedModel
 from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
 from .oft import OFTConfig, OFTModel
+from .osf import OSFConfig, OSFModel
 from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
 from .poly import PolyConfig, PolyModel
 from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
@@ -95,6 +96,8 @@
     "MultitaskPromptTuningInit",
     "OFTConfig",
     "OFTModel",
+    "OSFConfig",
+    "OSFModel",
     "PolyConfig",
     "PolyModel",
     "PrefixEncoder",

src/peft/tuners/osf/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+from peft.utils import register_peft_method
+
+from .config import OSFConfig
+from .layer import Linear, OSFLayer
+from .model import OSFModel
+
+
+__all__ = ["Linear", "OSFConfig", "OSFLayer", "OSFModel"]
+
+register_peft_method(
+    name="osf",
+    config_cls=OSFConfig,
+    model_cls=OSFModel,
+    is_mixed_compatible=False,
+)
