Skip to content

Commit 7ba58e4

Browse files
quic-mamta and mamtsing authored
[QEff Finetune] : fix task_type variable in configs (quic#514)
1. fix task_type variable in configs 2. enabled passing peft_config yaml/json file from command line. 3. updated run_ft_model.py --------- Signed-off-by: Mamta Singh <mamtsing@qti.qualcomm.com> Co-authored-by: Mamta Singh <mamtsing@qti.qualcomm.com>
1 parent a78e983 commit 7ba58e4

File tree

7 files changed

+105
-71
lines changed

7 files changed

+105
-71
lines changed

QEfficient/cloud/finetune.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import random
1010
import warnings
11-
from typing import Any, Dict, Optional, Union
11+
from typing import Any, Optional, Union
1212

1313
import numpy as np
1414
import torch
@@ -27,6 +27,7 @@
2727
update_config,
2828
)
2929
from QEfficient.finetune.utils.dataset_utils import get_dataloader
30+
from QEfficient.finetune.utils.helper import Task_Mode
3031
from QEfficient.finetune.utils.logging_utils import logger
3132
from QEfficient.finetune.utils.parser import get_finetune_parser
3233
from QEfficient.finetune.utils.train_utils import (
@@ -90,14 +91,13 @@ def setup_seeds(seed: int) -> None:
9091

9192

9293
def load_model_and_tokenizer(
93-
train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
94+
train_config: TrainConfig, dataset_config: Any, **kwargs
9495
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
9596
"""Load the pre-trained model and tokenizer from Hugging Face.
9697
9798
Args:
9899
config (TrainConfig): Training configuration object containing model and tokenizer names.
99100
dataset_config (Any): A dataclass object representing dataset configuration.
100-
peft_config_file (str): Path to PEFT config file used for PEFT finetuning.
101101
kwargs: Additional arguments to override PEFT config.
102102
103103
Returns:
@@ -113,7 +113,7 @@ def load_model_and_tokenizer(
113113
"""
114114
logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
115115
pretrained_model_path = hf_download(train_config.model_name)
116-
if train_config.task_type == "seq_classification":
116+
if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION:
117117
model = AutoModelForSequenceClassification.from_pretrained(
118118
pretrained_model_path,
119119
num_labels=dataset_config.num_labels,
@@ -166,21 +166,17 @@ def load_model_and_tokenizer(
166166
"Given model doesn't support gradient checkpointing. Please disable it and run it.", RuntimeError
167167
)
168168

169-
model = apply_peft(model, train_config, peft_config_file, **kwargs)
169+
model = apply_peft(model, train_config, **kwargs)
170170

171171
return model, tokenizer
172172

173173

174-
def apply_peft(
175-
model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs
176-
) -> Union[AutoModel, PeftModel]:
174+
def apply_peft(model: AutoModel, train_config: TrainConfig, **kwargs) -> Union[AutoModel, PeftModel]:
177175
"""Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled.
178176
179177
Args:
180178
model (AutoModel): Huggingface model.
181179
train_config (TrainConfig): Training configuration object.
182-
peft_config_file (str, optional): Path to YAML/JSON file containing
183-
PEFT (LoRA) config. Defaults to None.
184180
kwargs: Additional arguments to override PEFT config params.
185181
186182
Returns:
@@ -197,7 +193,7 @@ def apply_peft(
197193
peft_config = model.peft_config
198194
# Generate the peft config and start fine-tuning from original model
199195
else:
200-
peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
196+
peft_config = generate_peft_config(train_config, **kwargs)
201197
model = get_peft_model(model, peft_config)
202198
print_trainable_parameters(model)
203199

@@ -254,12 +250,11 @@ def setup_dataloaders(
254250
return train_dataloader, eval_dataloader, longest_seq_length
255251

256252

257-
def main(peft_config_file: str = None, **kwargs) -> None:
253+
def main(**kwargs) -> None:
258254
"""
259255
Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
260256
261257
Args:
262-
peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
263258
kwargs: Additional arguments to override TrainConfig.
264259
265260
Example:
@@ -286,7 +281,7 @@ def main(peft_config_file: str = None, **kwargs) -> None:
286281

287282
setup_distributed_training(train_config)
288283
setup_seeds(train_config.seed)
289-
model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
284+
model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, **kwargs)
290285

291286
# Create DataLoaders for the training and validation dataset
292287
train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
@@ -295,7 +290,6 @@ def main(peft_config_file: str = None, **kwargs) -> None:
295290
f"passed context length is {train_config.context_length} and overall model's context length is "
296291
f"{model.config.max_position_embeddings}"
297292
)
298-
299293
model.to(train_config.device)
300294
optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
301295
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

QEfficient/finetune/configs/training.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import logging
99
from dataclasses import dataclass
1010

11+
from QEfficient.finetune.utils.helper import Batching_Strategy, Device, Peft_Method, Task_Mode
12+
1113

1214
# Configuration Classes
1315
@dataclass
@@ -35,10 +37,11 @@ class TrainConfig:
3537
gamma (float): Learning rate decay factor (default: 0.85).
3638
seed (int): Random seed for reproducibility (default: 42).
3739
dataset (str): Dataset name for training (default: "alpaca_dataset").
38-
task_type (str): Type of task for which the finetuning is to be done. Options: "generation" and "seq_classification". (default: "generation")
40+
task_mode (str): Mode of task for which the finetuning is to be done. Options: "generation" and "seq_classification". (default: "generation")
3941
use_peft (bool): Whether to use PEFT (default: True).
4042
peft_method (str): Parameter-efficient fine-tuning method (default: "lora").
41-
from_peft_checkpoint (str): Path to PEFT checkpoint (default: "").
43+
peft_config_file (str): Path to YAML/JSON file containing PEFT (LoRA) config. (default: None)
44+
from_peft_checkpoint (str): Path to PEFT checkpoint (default: None).
4245
output_dir (str): Directory to save outputs (default: "training_results").
4346
save_model (bool): Save the trained model (default: True).
4447
save_metrics (bool): Save training metrics (default: True).
@@ -49,8 +52,9 @@ class TrainConfig:
4952
convergence_loss (float): Loss threshold for convergence (default: 1e-4).
5053
use_profiler (bool): Enable profiling (default: False).
5154
enable_ddp (bool): Enable distributed data parallel (default: False).
52-
dump_root_dir (str): Directory for mismatch dumps (default: "mismatches/step_").
5355
opByOpVerifier (bool): Enable operation-by-operation verification (default: False).
56+
dump_logs (bool): Whether to dump logs (default: True).
57+
log_level (str): logging level (default: logging.INFO)
5458
"""
5559

5660
model_name: str = "meta-llama/Llama-3.2-1B"
@@ -66,22 +70,23 @@ class TrainConfig:
6670
num_epochs: int = 1
6771
max_train_step: int = 0
6872
max_eval_step: int = 0
69-
device: str = "qaic"
73+
device: str = Device.QAIC.value
7074
num_workers_dataloader: int = 1
7175
lr: float = 3e-4
7276
weight_decay: float = 0.0
7377
gamma: float = 0.85 # multiplicatively decay the learning rate by gamma after each epoch
7478
seed: int = 42
7579
dataset: str = "alpaca_dataset"
76-
task_type: str = "generation" # "generation" / "seq_classification"
80+
task_mode: str = Task_Mode.GENERATION.value # "generation" / "seq_classification"
7781
use_peft: bool = True # use parameter efficient finetuning
78-
peft_method: str = "lora"
79-
from_peft_checkpoint: str = "" # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
82+
peft_method: str = Peft_Method.LORA.value
83+
peft_config_file: str = None
84+
from_peft_checkpoint: str = None # if not empty and peft_method='lora', will load the peft checkpoint and resume the fine-tuning on that checkpoint
8085
output_dir: str = "training_results"
8186
save_model: bool = True
8287
save_metrics: bool = True # saves training metrics to a json file for later plotting
8388
intermediate_step_save: int = 1000
84-
batching_strategy: str = "packing"
89+
batching_strategy: str = Batching_Strategy.PADDING.value
8590
enable_ddp: bool = False
8691
enable_sorting_for_ddp: bool = True
8792
convergence_counter: int = 5 # its value should be >= 1, stop fine tuning when loss <= convergence_loss (defined below) for #convergence_counter steps

QEfficient/finetune/utils/config_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from QEfficient.finetune.configs.peft_config import LoraConfig
1919
from QEfficient.finetune.configs.training import TrainConfig
2020
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
21+
from QEfficient.finetune.utils.helper import Peft_Method
2122
from QEfficient.finetune.utils.logging_utils import logger
2223

2324

@@ -52,25 +53,24 @@ def update_config(config, **kwargs):
5253
logger.debug(f"Unknown parameter '{k}' for config type '{config_type}'")
5354

5455

55-
def generate_peft_config(train_config: TrainConfig, peft_config_file: str = None, **kwargs) -> Any:
56+
def generate_peft_config(train_config: TrainConfig, **kwargs) -> Any:
5657
"""Generate a PEFT-compatible configuration from a custom config based on peft_method.
5758
5859
Args:
5960
train_config (TrainConfig): Training configuration with peft_method.
60-
custom_config: Custom configuration object (e.g., LoraConfig).
6161
6262
Returns:
6363
Any: A PEFT-specific configuration object (e.g., PeftLoraConfig).
6464
6565
Raises:
6666
RuntimeError: If the peft_method is not supported.
6767
"""
68-
if peft_config_file:
69-
peft_config_data = load_config_file(peft_config_file)
70-
validate_config(peft_config_data, config_type="lora")
68+
if train_config.peft_config_file:
69+
peft_config_data = load_config_file(train_config.peft_config_file)
70+
validate_config(peft_config_data, config_type=Peft_Method.LORA)
7171
peft_config = PeftLoraConfig(**peft_config_data)
7272
else:
73-
config_map = {"lora": (LoraConfig, PeftLoraConfig)}
73+
config_map = {Peft_Method.LORA: (LoraConfig, PeftLoraConfig)}
7474
if train_config.peft_method not in config_map:
7575
logger.raise_error(f"Peft config not found: {train_config.peft_method}", RuntimeError)
7676

@@ -105,7 +105,7 @@ def generate_dataset_config(dataset_name: str) -> Any:
105105
return dataset_config
106106

107107

108-
def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> None:
108+
def validate_config(config_data: Dict[str, Any], config_type: str = Peft_Method.LORA) -> None:
109109
"""Validate the provided YAML/JSON configuration for required fields and types.
110110
111111
Args:
@@ -120,7 +120,7 @@ def validate_config(config_data: Dict[str, Any], config_type: str = "lora") -> N
120120
- Validates required fields for LoraConfig: r, lora_alpha, target_modules.
121121
- Ensures types match expected values (int, float, list, etc.).
122122
"""
123-
if config_type.lower() != "lora":
123+
if config_type.lower() != Peft_Method.LORA:
124124
logger.raise_error(f"Unsupported config_type: {config_type}. Only 'lora' is supported.", ValueError)
125125

126126
required_fields = {

QEfficient/finetune/utils/helper.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# -----------------------------------------------------------------------------
77
import os
88
from contextlib import nullcontext
9+
from enum import Enum
910

1011
import torch
1112

@@ -15,10 +16,28 @@
1516
print(f"Warning: {e}. Moving ahead without these qaic modules.")
1617

1718

18-
TASK_TYPE = ["generation", "seq_classification"]
19-
PEFT_METHOD = ["lora"]
20-
DEVICE = ["qaic", "cpu", "cuda"]
21-
BATCHING_STRATEGY = ["padding", "packing"]
19+
class Batching_Strategy(str, Enum):
20+
PADDING = "padding"
21+
PACKING = "packing"
22+
23+
24+
class Device(str, Enum):
25+
QAIC = "qaic"
26+
CPU = "cpu"
27+
CUDA = "cuda"
28+
29+
30+
class Peft_Method(str, Enum):
31+
LORA = "lora"
32+
33+
34+
class Task_Mode(str, Enum):
35+
GENERATION = "generation"
36+
SEQ_CLASSIFICATION = "seq_classification"
37+
38+
39+
def enum_names(enum_cls):
40+
return [member.value for member in enum_cls]
2241

2342

2443
def is_rank_zero():

QEfficient/finetune/utils/parser.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
# -----------------------------------------------------------------------------
77

88
import argparse
9+
import logging
910

1011
from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
11-
from QEfficient.finetune.utils.helper import BATCHING_STRATEGY, DEVICE, PEFT_METHOD, TASK_TYPE
12+
from QEfficient.finetune.utils.helper import Batching_Strategy, Device, Peft_Method, Task_Mode, enum_names
1213

1314

1415
def str2bool(v):
@@ -110,7 +111,14 @@ def get_finetune_parser():
110111
default=0,
111112
help="Maximum evaluation steps, unlimited if 0",
112113
)
113-
parser.add_argument("--device", required=False, type=str, default="qaic", choices=DEVICE, help="Device to train on")
114+
parser.add_argument(
115+
"--device",
116+
required=False,
117+
type=str,
118+
default=Device.QAIC.value,
119+
choices=enum_names(Device),
120+
help="Device to train on",
121+
)
114122
parser.add_argument(
115123
"--num_workers_dataloader",
116124
"--num-workers-dataloader",
@@ -140,12 +148,12 @@ def get_finetune_parser():
140148
help="Dataset name to be used for finetuning (default: %(default)s)",
141149
)
142150
parser.add_argument(
143-
"--task_type",
144-
"--task-type",
151+
"--task_mode",
152+
"--task-mode",
145153
required=False,
146154
type=str,
147-
default="generation",
148-
choices=TASK_TYPE,
155+
default=Task_Mode.GENERATION.value,
156+
choices=enum_names(Task_Mode),
149157
help="Task used for finetuning. Use 'generation' for decoder based models and 'seq_classification' for encoder based models.",
150158
)
151159
parser.add_argument(
@@ -162,8 +170,8 @@ def get_finetune_parser():
162170
"--peft-method",
163171
required=False,
164172
type=str,
165-
default="lora",
166-
choices=PEFT_METHOD,
173+
default=Peft_Method.LORA.value,
174+
choices=enum_names(Peft_Method),
167175
help="Parameter efficient finetuning technique to be used. Currently only 'lora' is supported.",
168176
)
169177
parser.add_argument(
@@ -213,8 +221,8 @@ def get_finetune_parser():
213221
"--batching-strategy",
214222
required=False,
215223
type=str,
216-
default="padding",
217-
choices=BATCHING_STRATEGY,
224+
default=Batching_Strategy.PADDING.value,
225+
choices=enum_names(Batching_Strategy),
218226
help="Strategy for making batches of data points. Packing groups data points into batches by minimizing unnecessary empty spaces. Padding adds extra values (often zeros) to batch sequences so they align in size. Currently only padding is supported which is by default.",
219227
)
220228
parser.add_argument(
@@ -261,7 +269,22 @@ def get_finetune_parser():
261269
# This is for debugging purpose only.
262270
# Enables operation-by-operation verification w.r.t reference device(cpu).
263271
# It is a context manager interface that captures and verifies each operator against reference device.
264-
# In case results of test & reference do not match under given tolerances, a standalone unittest is generated at dump_root_dir.
272+
# In case results of test & reference do not match under given tolerances, a standalone unittest is generated at output_dir/mismatches.
273+
)
274+
parser.add_argument(
275+
"--log_level",
276+
"--log-level",
277+
required=False,
278+
type=str,
279+
default=logging.INFO,
280+
help="logging level",
281+
)
282+
parser.add_argument(
283+
"--peft_config_file",
284+
"--peft-config-file",
285+
type=str,
286+
default=None,
287+
help="Path to YAML/JSON file containing PEFT (LoRA) config.",
265288
)
266289

267290
return parser

0 commit comments

Comments (0)