11 changes: 6 additions & 5 deletions src/llmcompressor/transformers/finetune/data/data_args.py
@@ -1,8 +1,6 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Union

from transformers import DefaultDataCollator


@dataclass
class DVCDatasetTrainingArguments:
@@ -60,9 +58,12 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
},
)

data_collator: Callable[[Any], Any] = field(
default_factory=lambda: DefaultDataCollator(),
metadata={"help": "The function to used to form a batch from the dataset"},
data_collator: Optional[Callable[[Any], Any]] = field(
default=None,
metadata={
"help": "The function to used to form a batch from the dataset. Defaults "
"to `DataCollatorWithPadding` with model tokenizer if None is provided"
},
)
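For context, a minimal sketch (not from the PR; the stacking collator and the call site are assumptions) of the two ways this field can now be used: pass an explicit collator, or leave it as None and let the tokenizer-backed padding fallback in data_helpers.py take over.

from typing import Any, Dict, List

import torch


def stack_collator(features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    # hypothetical custom collator: assumes every sample already has the same
    # sequence length, so tensors can be stacked without padding
    return {key: torch.tensor([f[key] for f in features]) for key in features[0]}


# explicit collator:      data_collator=stack_collator
# new default behavior:   data_collator=None  -> DataCollatorWithPadding(tokenizer)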


32 changes: 27 additions & 5 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,11 +1,17 @@
import logging
import os
import warnings
from typing import Any, Callable, Dict, List, Optional

import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.data import default_data_collator
from transformers.data.data_collator import (
DataCollatorWithPadding,
default_data_collator,
)

from llmcompressor.typing import Processor

LOGGER = logging.getLogger(__name__)
LABELS_MASK_VALUE = -100
@@ -21,7 +27,9 @@
def format_calibration_data(
tokenized_dataset: Dataset,
num_calibration_samples: Optional[int] = None,
batch_size: int = 1,
do_shuffle: bool = True,
processor: Optional[Processor] = None,
collate_fn: Callable = default_data_collator,
accelerator: Optional[Any] = None,
) -> List[torch.Tensor]:
@@ -37,6 +45,11 @@ def format_calibration_data(
:param accelerator: optional accelerator for if preparing in FSDP mode
:return: list of trimmed calibration data tensors
"""
# shuffle
if do_shuffle:
tokenized_dataset = tokenized_dataset.shuffle()

# truncate samples
safe_calibration_samples = len(tokenized_dataset)
if num_calibration_samples is not None:
safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
@@ -45,13 +58,22 @@ def format_calibration_data(
f"Requested {num_calibration_samples} calibration samples but "
f"the provided dataset only has {safe_calibration_samples}. "
)

if do_shuffle:
tokenized_dataset = tokenized_dataset.shuffle()
tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

# collate data
if collate_fn is None:
tokenizer = getattr(processor, "tokenizer", processor)
if tokenizer is None:
warnings.warn(
"Could not find processor, attempting to collate with without padding "
"(may fail for batch_size > 1)"
)
return default_data_collator()

collate_fn = DataCollatorWithPadding(tokenizer)

dataloader_params = {
"batch_size": 1,
"batch_size": batch_size,
"sampler": RandomSampler(tokenized_calibration)
if do_shuffle
else SequentialSampler(tokenized_calibration),
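To illustrate why the padding fallback matters for batch_size > 1, a hedged sketch (not part of the PR; the GPT-2 tokenizer is an assumption) comparing the two collators on ragged samples:

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 defines no pad token

ragged = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

# default_data_collator would try to stack these ragged lists and fail;
# DataCollatorWithPadding pads both samples to the longest length in the batch
batch = DataCollatorWithPadding(tokenizer)(ragged)
print(batch["input_ids"].shape)  # torch.Size([2, 3])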
2 changes: 2 additions & 0 deletions src/llmcompressor/transformers/finetune/runner.py
@@ -144,8 +144,10 @@ def one_shot(self, stage: Optional[str] = None):
calib_data = format_calibration_data(
tokenized_dataset=self.get_dataset_split("calibration"),
num_calibration_samples=self._data_args.num_calibration_samples,
batch_size=self._training_args.per_device_oneshot_batch_size,
do_shuffle=self._data_args.shuffle_calibration_samples,
collate_fn=self._data_args.data_collator,
processor=self.processor,
accelerator=self.trainer.accelerator,
)

6 changes: 6 additions & 0 deletions src/llmcompressor/transformers/finetune/training_args.py
@@ -32,6 +32,12 @@ class TrainingArguments(HFTrainingArgs):
)
},
)
per_device_oneshot_batch_size: int = field(
Review thread on this field:

Collaborator: Why per device? Considering gptq's sequential nature/one active execution device.

Collaborator (Author): This name is just to match the existing per_device_train_batch_size argument name. We can alias this or resolve per_device_train_batch_size = per_device_oneshot_batch_size = batch_size in the future.

Collaborator (Author): Given that oneshot is unlikely to support device-parallel computation in the future, I'm fine using a more concise name now.

Collaborator (Author): Renamed to oneshot_batch_size to allow users to have separate batch sizes for oneshot and train. We can add a batch_size later if we think that's a better interface.

Collaborator (Author): Renamed to calibration_batch_size.

default=1,
metadata={
"help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for oneshot"
},
)
save_compressed: Optional[bool] = field(
default=True,
metadata={"help": "Whether to compress sparse models during save"},
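A hedged usage sketch (not from the PR; output_dir is only a placeholder required by the inherited HF TrainingArguments) showing the oneshot batch size being set independently of the training batch size:

from llmcompressor.transformers.finetune.training_args import TrainingArguments

training_args = TrainingArguments(
    output_dir="./oneshot_output",     # assumed placeholder
    per_device_train_batch_size=8,     # fine-tuning batches
    per_device_oneshot_batch_size=4,   # calibration batches, consumed by format_calibration_data
)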
@@ -1,7 +1,11 @@
# TODO: rename to `test_data_helpers.py`
import pytest
import torch
from datasets import Dataset

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
format_calibration_data,
get_raw_dataset,
make_dataset_splits,
)
@@ -53,3 +57,18 @@ def test_separate_datasets():
split_datasets = make_dataset_splits(
datasets, do_train=True, do_eval=True, do_predict=True
)


@pytest.mark.unit
def test_format_calibration_data():
tokenized_dataset = Dataset.from_dict(
{"input_ids": torch.randint(0, 512, (8, 2048))}
)

calibration_dataloader = format_calibration_data(
tokenized_dataset, num_calibration_samples=4, batch_size=2
)

batch = next(iter(calibration_dataloader))

assert batch["input_ids"].size(0) == 2
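A possible follow-up test, shown here only as a sketch (the GPT-2 tokenizer is an assumption), exercising the padding fallback by passing a tokenizer as processor and collate_fn=None:

from datasets import Dataset
from transformers import AutoTokenizer

from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer
tokenizer.pad_token = tokenizer.eos_token

# ragged samples force DataCollatorWithPadding to pad within each batch
ragged_dataset = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5], [6], [7, 8, 9, 10]]})

dataloader = format_calibration_data(
    ragged_dataset,
    num_calibration_samples=4,
    batch_size=2,
    processor=tokenizer,
    collate_fn=None,  # triggers the tokenizer-backed padding fallback
)

batch = next(iter(dataloader))
assert batch["input_ids"].size(0) == 2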