Skip to content

Commit 8fc6012

Browse files
author
George
authored
[Training] Datasets - update Module (#1209)
Order of reviews: #1206 #1207 #1209 <-- Here #1212 #1214 SUMMARY: * Move dataset logic out of transformers module `src/llmcompressor/transformers/finetune/data/data_helpers.py`, add it to `src/llmcompressor/datasets/utils.py` TEST PLAN: Pass tests
1 parent 391b202 commit 8fc6012

File tree

8 files changed

+209
-191
lines changed

8 files changed

+209
-191
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# flake8: noqa
2+
3+
from .utils import (
4+
format_calibration_data,
5+
get_calibration_dataloader,
6+
get_processed_dataset,
7+
make_dataset_splits,
8+
)
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import re
2+
from typing import Any, Callable, Dict, List, Optional
3+
4+
import torch
5+
from datasets import Dataset
6+
from loguru import logger
7+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
8+
from transformers.data import default_data_collator
9+
10+
from llmcompressor.args import DatasetArguments
11+
from llmcompressor.transformers.finetune.data import TextGenerationDataset
12+
from llmcompressor.typing import Processor
13+
14+
15+
def get_processed_dataset(
16+
dataset_args: DatasetArguments,
17+
processor: Processor,
18+
do_oneshot: bool = False,
19+
do_train: bool = True,
20+
) -> Optional[Dict[str, Dataset]]:
21+
"""
22+
Loads datasets for each flow based on dataset_args, stores a Dataset for each
23+
enabled flow in datasets
24+
:param dataset_args: DatasetArguments that contain dataset loading and
25+
processing params
26+
:param processor: processor or tokenizer to use for dataset tokenization
27+
:param do_oneshot: True for oneshot pathway
28+
:param do_train: True for train pathway
29+
:return: A dataset corresponding to either train or calibration (oneshot)
30+
"""
31+
if dataset_args.dataset is None:
32+
logger.warning(
33+
"Running oneshot without calibration data. This is expected for "
34+
"weight-only and dynamic quantization"
35+
)
36+
return
37+
38+
splits = dataset_args.splits
39+
tokenized_datasets = {}
40+
41+
def _get_split_name(inp_str):
42+
# strip out split name, for ex train[60%:] -> train
43+
match = re.match(r"(\w*)\[.*\]", inp_str)
44+
if match is not None:
45+
return match.group(1)
46+
return inp_str
47+
48+
if splits is None:
49+
splits = {"all": None}
50+
elif isinstance(splits, str):
51+
splits = {_get_split_name(splits): splits}
52+
elif isinstance(splits, List):
53+
splits = {_get_split_name(s): s for s in splits}
54+
55+
# default to custom dataset if dataset provided isn't a string
56+
registry_id = (
57+
dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom"
58+
)
59+
for split_name, split_str in splits.items():
60+
dataset = dataset_args.dataset
61+
if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
62+
# dataset is already tokenized
63+
tokenized_datasets[split_name] = dataset
64+
else:
65+
# dataset needs to be tokenized
66+
dataset_manager = TextGenerationDataset.load_from_registry(
67+
registry_id,
68+
dataset_args=dataset_args,
69+
split=split_str,
70+
processor=processor,
71+
)
72+
tokenized_datasets[split_name] = dataset_manager(add_labels=do_train)
73+
74+
return make_dataset_splits(
75+
tokenized_datasets,
76+
do_oneshot=do_oneshot,
77+
do_train=do_train,
78+
)
79+
80+
81+
def get_calibration_dataloader(
82+
dataset_args: DatasetArguments,
83+
processor: Processor,
84+
) -> torch.utils.data.DataLoader:
85+
"""
86+
Get the dataloader used for oneshot calibration.
87+
:param dataset_args: DatasetArguments that contains the dataset parameters.
88+
:param processor: Processor or the tokenizer of the model.
89+
:return: PyTorch dataloader object that contains the calibration dataset.
90+
"""
91+
if dataset_args.dataset is None:
92+
# weight-only quantization or dynamic quantization
93+
return
94+
95+
datasets = get_processed_dataset(
96+
dataset_args=dataset_args,
97+
processor=processor,
98+
do_oneshot=True,
99+
do_train=False,
100+
)
101+
102+
calibration_dataset = datasets.get("calibration")
103+
104+
return format_calibration_data(
105+
tokenized_dataset=calibration_dataset,
106+
num_calibration_samples=dataset_args.num_calibration_samples,
107+
do_shuffle=dataset_args.shuffle_calibration_samples,
108+
collate_fn=dataset_args.data_collator,
109+
)
110+
111+
112+
def format_calibration_data(
113+
tokenized_dataset: Dataset,
114+
num_calibration_samples: Optional[int] = None,
115+
do_shuffle: bool = True,
116+
collate_fn: Callable = default_data_collator,
117+
) -> List[torch.Tensor]:
118+
"""
119+
Creates a dataloader out of the calibration dataset split, trimming it to
120+
the desired number of calibration samples
121+
:param tokenized_dataset: dataset to convert to dataloader
122+
:param num_calibration_samples: number of data samples to convert
123+
:param do_shuffle: whether to shuffle the dataset before selecting calibration
124+
samples, true by default
125+
:param collate_fn: optional custom collate function, or use default
126+
:return: list of trimmed calibration data tensors
127+
"""
128+
safe_calibration_samples = len(tokenized_dataset)
129+
if num_calibration_samples is not None:
130+
safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
131+
if safe_calibration_samples != num_calibration_samples:
132+
logger.warn(
133+
f"Requested {num_calibration_samples} calibration samples but "
134+
f"the provided dataset only has {safe_calibration_samples}. "
135+
)
136+
137+
if do_shuffle:
138+
tokenized_dataset = tokenized_dataset.shuffle()
139+
tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))
140+
141+
dataloader_params = {
142+
"batch_size": 1,
143+
"sampler": RandomSampler(tokenized_calibration)
144+
if do_shuffle
145+
else SequentialSampler(tokenized_calibration),
146+
"collate_fn": collate_fn,
147+
"pin_memory": True,
148+
}
149+
150+
calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params)
151+
152+
return calibration_dataloader
153+
154+
155+
def make_dataset_splits(
156+
tokenized_datasets: Dict[str, Any],
157+
do_oneshot: bool = True,
158+
do_train: bool = False,
159+
) -> Dict[str, Dataset]:
160+
"""
161+
Restructures the datasets dictionary based on what tasks will be run
162+
train
163+
:param tokenized_datasets: dictionary of processed datasets
164+
:param do_oneshot: Whether to store the calibration dataset
165+
:return: A dataset corresponding to either train or calibration (oneshot)
166+
"""
167+
168+
# handles case where all splits are contained in a single dataset
169+
if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
170+
tokenized_datasets = tokenized_datasets.get("all")
171+
if isinstance(tokenized_datasets, Dataset):
172+
tokenized_datasets = {"train": tokenized_datasets}
173+
174+
train_split = calib_split = None
175+
176+
if do_train:
177+
if "train" not in tokenized_datasets:
178+
raise ValueError("--do_train requires a train dataset")
179+
train_split = tokenized_datasets["train"]
180+
if do_oneshot:
181+
calib_split = tokenized_datasets.get("calibration")
182+
if calib_split is None:
183+
if "train" not in tokenized_datasets:
184+
raise ValueError("--do_oneshot requires a calibration dataset")
185+
calib_split = tokenized_datasets["train"]
186+
187+
split_datasets = {
188+
"train": train_split,
189+
"calibration": calib_split,
190+
}
191+
return split_datasets

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
from llmcompressor.args import parse_args
99
from llmcompressor.core.session_functions import active_session
10-
from llmcompressor.transformers.finetune.data.data_helpers import (
11-
get_calibration_dataloader,
12-
)
10+
from llmcompressor.datasets import get_calibration_dataloader
1311
from llmcompressor.transformers.finetune.text_generation import (
1412
initialize_model_from_path,
1513
initialize_processor_from_path,

0 commit comments

Comments
 (0)