
Lora fine-tuning using PEFT Library#204

Open
sidhantls wants to merge 2 commits into huggingface:main from sidhantls:peft_v2

Conversation


@sidhantls sidhantls commented Feb 14, 2025

Adds a LoRA implementation for parameter-efficient fine-tuning of Parler-TTS.

Addresses #183, #158, and other requests.

Feature

This PR adds PEFT support with low-rank adapters (LoRA) for fine-tuning Parler-TTS on new datasets.

LoRA is applied to the Parler-TTS decoder Transformer, targeting the linear projection layers. Fine-tuning with LoRA trains only 0.5% of the parameters of Parler Mini.
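The 0.5% figure can be sanity-checked with back-of-the-envelope arithmetic: a rank-`r` adapter on a `d_in × d_out` linear layer adds `r · (d_in + d_out)` trainable parameters next to `d_in · d_out` frozen ones. A minimal sketch, assuming a 1024×1024 projection (an illustrative size, not Parler-TTS's actual config):

```python
# Back-of-the-envelope LoRA parameter count for one linear projection.
# The dimensions below are illustrative assumptions, not Parler-TTS's config.
d_in, d_out = 1024, 1024   # assumed projection size
r = 8                      # --lora_r

full = d_in * d_out         # frozen base weight parameters
lora = r * (d_in + d_out)   # trainable A (r x d_in) + B (d_out x r)

print(full, lora, f"{100 * lora / full:.2f}%")  # 1048576 16384 1.56%
```

Per adapted layer this is roughly 1.6%; the overall figure comes out lower because the adapters cover only the decoder's projection layers while embeddings and the rest of the model stay frozen.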

Benefits

  • PEFT enables fine-tuning of Parler Mini on an 8 GB GPU (with limited offloading?).
    • On my Windows machine, fine-tuning takes 29.52 s/it with PEFT versus 117.50 s/it without.
  • PEFT enables fine-tuning of Parler-TTS Large (2.3B) on Google Colab. Without it, I get an OOM error.

This is an alternative implementation of PR #159 that enables training with LoRA, loading checkpoints, and saving the final LoRA model. Unlike #159, which was a custom implementation, this PR uses the peft library. It also allows loading saved checkpoints, which was not possible in #159.

How to use:

Fine-Tuning:

When running `accelerate launch ./training/run_parler_tts_training.py` for fine-tuning, use `--use_lora true --lora_r 8 --lora_alpha 16 --lora_dropout 0.05`.
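Put together, a fine-tuning invocation might look like the following. The LoRA flags are the ones this PR adds; the model and output arguments are placeholders for whatever your usual training command passes, not values taken from the PR.

```shell
# LoRA flags from this PR appended to a regular fine-tuning run.
# Model/output arguments below are placeholders, not from the PR.
accelerate launch ./training/run_parler_tts_training.py \
    --model_name_or_path "parler-tts/parler_tts_mini_v0.1" \
    --output_dir "./output_dir_training" \
    --use_lora true \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05
```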

Loading Checkpoints:

```python
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from peft import PeftModel
import torch

device = "cuda"

# Base model (pretrained)
base_model_name = "parler-tts/parler_tts_mini_v0.1"
peft_model_path = "output_dir_training/"  # PEFT model path/checkpoint path

torch_dtype = torch.float16

# Load base model
base_model = ParlerTTSForConditionalGeneration.from_pretrained(base_model_name).to(device, dtype=torch_dtype)

# Load PEFT model on top of the base model
peft_model = PeftModel.from_pretrained(base_model, peft_model_path)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Merge LoRA adapters into the model for ~50% faster inference
model = peft_model.merge_and_unload()

# Regular inference code for the merged model
from IPython.display import Audio

prompt = "It was a bright cold day in April."
description = "Jenny speaks with a monotone voice, in a very close-sounding environment with almost no noise. She speaks slightly fast."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Generate audio
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids).to(torch.float32)
audio_arr = generation.cpu().numpy().squeeze()

Audio(data=audio_arr, rate=44100)  # Parler-TTS mini v0.1 outputs 44.1 kHz audio
```
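The merge step is plain weight arithmetic, which is where the inference speedup comes from: folding the adapter into the base weight removes the two extra matmuls per adapted layer. A small numpy sketch under assumed shapes (not Parler-TTS's actual dimensions):

```python
# Why merging LoRA speeds up inference: the adapter path
#   y = x @ W.T + (alpha / r) * (x @ A.T @ B.T)
# collapses into a single matmul with W_merged = W + (alpha / r) * (B @ A).
# Shapes here are illustrative assumptions, not Parler-TTS's config.
import numpy as np

rng = np.random.default_rng(0)
d, r, alpha = 16, 8, 16

W = rng.standard_normal((d, d))   # frozen base weight
A = rng.standard_normal((r, d))   # LoRA down-projection
B = rng.standard_normal((d, r))   # LoRA up-projection
x = rng.standard_normal((4, d))   # a batch of inputs

y_adapter = x @ W.T + (alpha / r) * (x @ A.T @ B.T)  # two extra matmuls per call
W_merged = W + (alpha / r) * (B @ A)                 # computed once, offline
y_merged = x @ W_merged.T                            # single matmul at inference

print(np.allclose(y_adapter, y_merged))  # True
```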
