
Commit 29f93d3

WIP
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 999d660 commit 29f93d3

File tree

5 files changed: +60 -10 lines changed


src/llmcompressor/transformers/finetune/data/data_args.py

Lines changed: 6 additions & 5 deletions
@@ -1,8 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from transformers import DefaultDataCollator
-
 
 @dataclass
 class DVCDatasetTrainingArguments:
@@ -60,9 +58,12 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
         },
     )
 
-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
+    data_collator: Optional[Callable[[Any], Any]] = field(
+        default=None,
+        metadata={
+            "help": "The function used to form a batch from the dataset. Defaults "
+            "to `DataCollatorWithPadding` with the model tokenizer if None is provided"
+        },
     )
 

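With `data_collator` now defaulting to `None`, the calibration pipeline can fall back to a padding collator built from the model tokenizer, while callers can still supply their own. A minimal sketch of both cases follows; it assumes the remaining `DataTrainingArguments` fields carry defaults (true of everything visible in this diff), and is illustrative rather than part of the commit:

# Minimal sketch, not part of this commit.
from transformers import DefaultDataCollator

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments

# Explicitly supplying a collator keeps the previous behavior
explicit_args = DataTrainingArguments(data_collator=DefaultDataCollator())

# The new default is None, which signals the calibration pipeline to build a
# DataCollatorWithPadding from the model tokenizer instead
default_args = DataTrainingArguments()
assert default_args.data_collator is None
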
src/llmcompressor/transformers/finetune/data/data_helpers.py

Lines changed: 27 additions & 5 deletions
@@ -1,11 +1,17 @@
 import logging
 import os
+import warnings
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from datasets import Dataset, load_dataset
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data.data_collator import (
+    DataCollatorWithPadding,
+    default_data_collator,
+)
+
+from llmcompressor.typing import Processor
 
 LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100
@@ -21,7 +27,9 @@
 def format_calibration_data(
     tokenized_dataset: Dataset,
     num_calibration_samples: Optional[int] = None,
+    batch_size: int = 1,
     do_shuffle: bool = True,
+    processor: Optional[Processor] = None,
     collate_fn: Callable = default_data_collator,
     accelerator: Optional[Any] = None,
 ) -> List[torch.Tensor]:
@@ -37,6 +45,11 @@ def format_calibration_data(
     :param accelerator: optional accelerator for if preparing in FSDP mode
     :return: list of trimmed calibration data tensors
     """
+    # shuffle
+    if do_shuffle:
+        tokenized_dataset = tokenized_dataset.shuffle()
+
+    # truncate samples
     safe_calibration_samples = len(tokenized_dataset)
     if num_calibration_samples is not None:
         safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
@@ -45,13 +58,22 @@
             f"Requested {num_calibration_samples} calibration samples but "
             f"the provided dataset only has {safe_calibration_samples}. "
         )
-
-    if do_shuffle:
-        tokenized_dataset = tokenized_dataset.shuffle()
     tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))
 
+    # collate data
+    if collate_fn is None:
+        tokenizer = getattr(processor, "tokenizer", processor)
+        if tokenizer is None:
+            warnings.warn(
+                "Could not find processor, attempting to collate without padding "
+                "(may fail for batch_size > 1)"
+            )
+            collate_fn = default_data_collator
+        else:
+            collate_fn = DataCollatorWithPadding(tokenizer)
+
     dataloader_params = {
-        "batch_size": 1,
+        "batch_size": batch_size,
         "sampler": RandomSampler(tokenized_calibration)
         if do_shuffle
         else SequentialSampler(tokenized_calibration),

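The new `processor` and `batch_size` arguments let `format_calibration_data` pad variable-length samples into multi-sample batches when no collator is supplied. A rough usage sketch; the tokenizer checkpoint and sample text are illustrative and not taken from this commit:

# Sketch only; assumes a tokenizer with (or given) a pad token.
from datasets import Dataset
from transformers import AutoTokenizer

from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

encodings = tokenizer(["a short sample", "a noticeably longer calibration sample sentence"])
tokenized = Dataset.from_dict(dict(encodings))

loader = format_calibration_data(
    tokenized_dataset=tokenized,
    num_calibration_samples=2,
    batch_size=2,          # new argument, wired through from the training args
    processor=tokenizer,   # getattr(processor, "tokenizer", processor) accepts a bare tokenizer
    collate_fn=None,       # triggers the DataCollatorWithPadding fallback
)

batch = next(iter(loader))
print(batch["input_ids"].shape)  # (2, length of the longest sample after padding)
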
src/llmcompressor/transformers/finetune/runner.py

Lines changed: 2 additions & 0 deletions
@@ -144,8 +144,10 @@ def one_shot(self, stage: Optional[str] = None):
         calib_data = format_calibration_data(
             tokenized_dataset=self.get_dataset_split("calibration"),
             num_calibration_samples=self._data_args.num_calibration_samples,
+            batch_size=self._training_args.per_device_oneshot_batch_size,
             do_shuffle=self._data_args.shuffle_calibration_samples,
             collate_fn=self._data_args.data_collator,
+            processor=self.processor,
             accelerator=self.trainer.accelerator,
         )

src/llmcompressor/transformers/finetune/training_args.py

Lines changed: 6 additions & 0 deletions
@@ -32,6 +32,12 @@ class TrainingArguments(HFTrainingArgs):
            )
        },
    )
+    per_device_oneshot_batch_size: int = field(
+        default=1,
+        metadata={
+            "help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for oneshot"
+        },
+    )
    save_compressed: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to compress sparse models during save"},

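As with the other HF-style arguments, the new field can be set when constructing `TrainingArguments`; the stage runner forwards it to `format_calibration_data` as `batch_size`. A small illustrative sketch, with an arbitrary placeholder for `output_dir`:

# Sketch only; output_dir is a placeholder value.
from llmcompressor.transformers.finetune.training_args import TrainingArguments

training_args = TrainingArguments(
    output_dir="./oneshot_output",
    per_device_oneshot_batch_size=4,  # calibration batches of 4 samples per device
)
assert training_args.per_device_oneshot_batch_size == 4
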
tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py

Lines changed: 19 additions & 0 deletions
@@ -1,7 +1,11 @@
+# TODO: rename to `test_data_helpers.py`
 import pytest
+import torch
+from datasets import Dataset
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
+    format_calibration_data,
     get_raw_dataset,
     make_dataset_splits,
 )
@@ -53,3 +57,18 @@ def test_separate_datasets():
     split_datasets = make_dataset_splits(
         datasets, do_train=True, do_eval=True, do_predict=True
     )
+
+
+@pytest.mark.unit
+def test_format_calibration_data():
+    tokenized_dataset = Dataset.from_dict(
+        {"input_ids": torch.randint(0, 512, (8, 2048))}
+    )
+
+    calibration_dataloader = format_calibration_data(
+        tokenized_dataset, num_calibration_samples=4, batch_size=2
+    )
+
+    batch = next(iter(calibration_dataloader))
+
+    assert batch["input_ids"].size(0) == 2
