SamsungLabs/GAWS

GAWS: Grouped Adaptive Weight Sharing

Description

Adapter-based parameter-efficient fine-tuning enables multitask learning by attaching lightweight, task-specific adapters to a shared base model. However, efficiently serving multiple adapters poses deployment challenges. While merging adapters into the base model eliminates runtime overhead, it hinders model sharing across tasks, introduces potential numerical instability on quantized models, and complicates deployment in environments with static computational graphs. Conversely, serving unmerged adapters avoids these issues but comes at the cost of increased inference latency. Through analysis of LoRA adapters on GPUs, we attribute this latency primarily to segmented function calls. To address this, we propose Grouped Adaptive Weight Sharing (GAWS), a novel adapter design based on structured Kronecker product decomposition. Experiments on T5, GPT-2 Large, and LLaMA-3B show that GAWS reduces latency to about 42% of the gap between LoRA and the base model, while maintaining parameter efficiency and comparable accuracy. This positions GAWS as an effective solution for efficient multitask deployment.
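The structured Kronecker idea can be illustrated with a toy numpy sketch. This is an assumption-laden illustration, not the GAWS implementation: it only shows how two small factors expand into a full-size weight update while keeping the trainable parameter count small. The layer width (2560, as in phi-2's q_proj) and the factor shapes mirror the quickstart configuration below.

```python
import numpy as np

# Toy illustration (NOT the GAWS implementation): a Kronecker-factored update
# delta_W = kron(A, B) reconstructs a full d x d matrix from two small
# factors, so the trainable parameter count stays tiny.
d = 2560                      # e.g. the width of a phi-2 q_proj layer
splits, split_dim = 10, 256   # splits * split_dim == d

A = np.random.randn(splits, splits)        # small grouping factor (10 x 10)
B = np.random.randn(split_dim, split_dim)  # shared block factor (256 x 256)
delta_W = np.kron(A, B)                    # structured full-size update

kron_params = A.size + B.size              # 100 + 65,536 trainable values
print(delta_W.shape, kron_params)          # (2560, 2560) 65636
```

A rank-16 LoRA adapter on the same layer would train 2 * 2560 * 16 = 81,920 values, so a Kronecker factorization of this shape is in the same parameter budget while producing a full-rank update.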

Quickstart

Clone the repository and install the peft library:

git clone <insert_github_repository>
cd GAWS/peft
pip install -r requirements.txt
pip install .

Prepare a model for training with GAWS by wrapping the base model and GAWS configuration with get_peft_model.

from transformers import AutoModelForCausalLM
from peft import GAWSConfig, get_peft_model  # for GAWS-D use: from peft import GAWSDConfig
from peft.tuners.gaws import get_valid_split_dim_values  # for GAWS-D use: from peft.tuners.gaws_d import get_valid_splits_values

# load the base model
model_name_or_path = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

# list the valid (split_dim, splits) combinations for the q_proj layers
print(get_valid_split_dim_values(model, "q_proj"))

"""
    input_split_dim  output_split_dim  input_splits = output_splits
0              1280              1280                             2
1               640               640                             4
2               512               512                             5
3               320               320                             8
4               256               256                            10
5               160               160                            16
6               128               128                            20
7                80                80                            32
8                64                64                            40
9                40                40                            64
10               32                32                            80
11               20                20                           128
12               16                16                           160
13               10                10                           256
14                8                 8                           320
15                5                 5                           512
16                4                 4                           640
17                2                 2                          1280
"""


# construct the adapter config
config = GAWSConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj"],
    input_splits=10,
    output_splits=10,
    input_split_dim=256,
    output_split_dim=256,
    init_weights="zero",  # "zero": zero init, "kaiming": Kaiming init, "none": random init
    diag=False,  # whether to add a diagonal matrix to the adapter
)
# construct GAWS model
model = get_peft_model(model, config)
model.print_trainable_parameters()
"trainable params: 6,291,456 || all params: 2,785,975,296 || trainable%: 0.2258"

To save the GAWS model:

model.save_pretrained('phi2_gaws-s256')

To load the GAWS model for inference:

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
import torch

model = AutoPeftModelForCausalLM.from_pretrained("phi2_gaws-s256").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

model.eval()
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")

outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=50)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

"Preheat the oven to 350 degrees and place the cookie dough on a baking sheet. Bake for 10-12 minutes or until golden brown."

Contact

If you have any questions, please open an issue on this repository or contact us at []
