diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py
index fa8e434d4..c589658c5 100644
--- a/src/llmcompressor/transformers/finetune/data/base.py
+++ b/src/llmcompressor/transformers/finetune/data/base.py
@@ -53,11 +53,7 @@ def __init__(
         self.tokenizer = getattr(self.processor, "tokenizer", self.processor)
 
         if self.tokenizer is not None:
-            # fill in pad token
-            if not self.tokenizer.pad_token:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            # configure sequence length
+            # resolve sequence length
             max_seq_length = data_args.max_seq_length
             if data_args.max_seq_length > self.tokenizer.model_max_length:
                 logger.warning(
@@ -69,7 +65,7 @@ def __init__(
                 data_args.max_seq_length, self.tokenizer.model_max_length
             )
 
-            # configure padding
+            # resolve padding
             self.padding = (
                 False
                 if self.data_args.concatenate_data
diff --git a/src/llmcompressor/transformers/finetune/data/data_args.py b/src/llmcompressor/transformers/finetune/data/data_args.py
index 7d0bc14ce..6154be4c2 100644
--- a/src/llmcompressor/transformers/finetune/data/data_args.py
+++ b/src/llmcompressor/transformers/finetune/data/data_args.py
@@ -1,8 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from transformers import DefaultDataCollator
-
 
 @dataclass
 class DVCDatasetTrainingArguments:
@@ -60,9 +58,12 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
         },
     )
 
-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
+    data_collator: Optional[Callable[[Any], Any]] = field(
+        default=None,
+        metadata={
+            "help": "The function used to form a batch from the dataset. Defaults "
+            "to `DataCollatorWithPadding` with the model tokenizer if None is provided"
+        },
     )
diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py
index 23c70e561..065ffc512 100644
--- a/src/llmcompressor/transformers/finetune/data/data_helpers.py
+++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,13 +1,17 @@
-import logging
 import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from datasets import Dataset, load_dataset
+from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data.data_collator import (
+    DataCollatorWithPadding,
+    default_data_collator,
+)
+
+from llmcompressor.typing import Processor
 
-LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100
 
 __all__ = [
@@ -21,8 +25,10 @@
 def format_calibration_data(
     tokenized_dataset: Dataset,
     num_calibration_samples: Optional[int] = None,
+    batch_size: int = 1,
     do_shuffle: bool = True,
-    collate_fn: Callable = default_data_collator,
+    processor: Optional[Processor] = None,
+    collate_fn: Optional[Callable] = None,
     accelerator: Optional[Any] = None,
 ) -> List[torch.Tensor]:
     """
@@ -33,30 +39,47 @@
     :param num_calibration_samples: number of data samples to convert
     :param do_shuffle: whether to shuffle the dataset before selecting calibration
         samples, true by default
-    :param collate_fn: optional custom collate function, or use default
+    :param collate_fn: optional custom collate function, defaults to
+        `DataCollatorWithPadding` if None is provided. If the tokenizer fails to
+        resolve, then `default_data_collator` is used
     :param accelerator: optional accelerator for if preparing in FSDP mode
     :return: list of trimmed calibration data tensors
     """
+    # shuffle
+    if do_shuffle:
+        tokenized_dataset = tokenized_dataset.shuffle()
+
+    # truncate samples
     safe_calibration_samples = len(tokenized_dataset)
     if num_calibration_samples is not None:
         safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
         if safe_calibration_samples != num_calibration_samples:
-            LOGGER.warn(
+            logger.warning(
                 f"Requested {num_calibration_samples} calibration samples but "
                 f"the provided dataset only has {safe_calibration_samples}. "
             )
-
-    if do_shuffle:
-        tokenized_dataset = tokenized_dataset.shuffle()
     tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))
 
+    # collate data
+    if collate_fn is None:
+        tokenizer = getattr(processor, "tokenizer", processor)
+        if hasattr(tokenizer, "pad"):
+            collate_fn = DataCollatorWithPadding(tokenizer)
+        else:
+            logger.warning(
+                "Could not find processor, attempting to collate without padding "
+                "(may fail for batch_size > 1)"
+            )
+            collate_fn = default_data_collator
+
     dataloader_params = {
-        "batch_size": 1,
+        "batch_size": batch_size,
         "sampler": RandomSampler(tokenized_calibration)
         if do_shuffle
         else SequentialSampler(tokenized_calibration),
         "collate_fn": collate_fn,
         "pin_memory": True,
+        "drop_last": False,
     }
 
     calib_dataloader = DataLoader(tokenized_calibration, **dataloader_params)
diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py
index 0a07c45eb..ed315e76b 100644
--- a/src/llmcompressor/transformers/finetune/runner.py
+++ b/src/llmcompressor/transformers/finetune/runner.py
@@ -68,8 +68,8 @@ def populate_datasets(self, processor: Processor, add_labels: bool = True):
         :param processor: processor or tokenizer to use for dataset tokenization
         :param add_labels: if True, add labels column to dataset splits
         """
+        self.processor = processor  # TODO: pass processor into init instead of this fn
         if self._data_args.dataset is None:
-            self.processor = self._model_args.processor
             logger.info(
                 "Running oneshot without calibration data. This is expected for "
                 "weight-only and dynamic quantization"
@@ -144,8 +144,10 @@ def one_shot(self, stage: Optional[str] = None):
         calib_data = format_calibration_data(
             tokenized_dataset=self.get_dataset_split("calibration"),
             num_calibration_samples=self._data_args.num_calibration_samples,
+            batch_size=self._training_args.calibration_batch_size,
             do_shuffle=self._data_args.shuffle_calibration_samples,
             collate_fn=self._data_args.data_collator,
+            processor=self.processor,
             accelerator=self.trainer.accelerator,
         )
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index 61e6441bb..99c960e06 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -20,6 +20,7 @@
 import os
 import warnings
 from pathlib import PosixPath
+from types import NoneType
 
 from loguru import logger
 from transformers import (
@@ -286,6 +287,27 @@ def initialize_processor_from_path(
     return processor
 
 
+def configure_processor(processor: Processor):
+    # configure tokenizer pad_token, required for padding and data collation
+    tokenizer = getattr(processor, "tokenizer", processor)
+    if getattr(tokenizer, "pad_token", None) is None:
+        if getattr(tokenizer, "eos_token", None) is not None:
+            logger.debug("Tokenizer is missing pad_token, using eos_token instead")
+            tokenizer.pad_token = tokenizer.eos_token
+        else:
+            logger.debug(
+                "Tokenizer is missing pad_token and eos_token, this may lead to issues "
+                "when padding"
+            )
+
+    # the chat template attribute is required for saving, patch some processors which
+    # do not include this attribute (phi3_v)
+    processor_ct = getattr(processor, "chat_template", None)
+    tokenizer_ct = getattr(tokenizer, "chat_template", None)
+    if processor_ct is None and tokenizer_ct is not None:
+        processor.chat_template = tokenizer.chat_template
+
+
 def main(
     model_args: ModelArguments,
     data_args: DataTrainingArguments,
@@ -361,8 +383,9 @@ def main(
         teacher.eval()
 
     processor = model_args.processor
-    if isinstance(processor, str) or processor is None:
+    if isinstance(processor, (str, NoneType)):
         processor = initialize_processor_from_path(model_args, model, teacher)
+    configure_processor(processor)
 
     pre_initialize_structure(model=model)
diff --git a/src/llmcompressor/transformers/finetune/training_args.py b/src/llmcompressor/transformers/finetune/training_args.py
index c04fa2807..75239aff3 100644
--- a/src/llmcompressor/transformers/finetune/training_args.py
+++ b/src/llmcompressor/transformers/finetune/training_args.py
@@ -32,6 +32,12 @@ class TrainingArguments(HFTrainingArgs):
             )
         },
     )
+    calibration_batch_size: int = field(
+        default=1,
+        metadata={
+            "help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for oneshot"
+        },
+    )
     save_compressed: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether to compress sparse models during save"},
diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
index 812b26a56..98e98c00f 100644
--- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
+++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py
@@ -1,10 +1,16 @@
+# TODO: rename to `test_data_helpers.py`
 import pytest
+import torch
+from datasets import Dataset
+from transformers import AutoTokenizer
 
 from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
 from llmcompressor.transformers.finetune.data.data_helpers import (
+    format_calibration_data,
     get_raw_dataset,
     make_dataset_splits,
 )
+from llmcompressor.transformers.finetune.text_generation import configure_processor
 
 
 @pytest.mark.unit
@@ -53,3 +59,45 @@ def test_separate_datasets():
     split_datasets = make_dataset_splits(
         datasets, do_train=True, do_eval=True, do_predict=True
     )
+
+
+@pytest.mark.unit
+def test_format_calibration_data_padded_tokenized():
+    vocab_size = 512
+    seq_len = 2048
+    ds_size = 16
+    padded_tokenized_dataset = Dataset.from_dict(
+        {"input_ids": torch.randint(0, vocab_size, (ds_size, seq_len))}
+    )
+
+    calibration_dataloader = format_calibration_data(
+        padded_tokenized_dataset, num_calibration_samples=8, batch_size=4
+    )
+
+    batch = next(iter(calibration_dataloader))
+    assert batch["input_ids"].size(0) == 4
+
+
+@pytest.mark.unit
+def test_format_calibration_data_unpadded_tokenized():
+    vocab_size = 512
+    ds_size = 16
+    unpadded_tokenized_dataset = Dataset.from_dict(
+        {
+            "input_ids": [
+                torch.randint(0, vocab_size, (seq_len,)) for seq_len in range(ds_size)
+            ]
+        }
+    )
+    processor = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+    configure_processor(processor)
+
+    calibration_dataloader = format_calibration_data(
+        unpadded_tokenized_dataset,
+        num_calibration_samples=8,
+        batch_size=4,
+        processor=processor,
+    )
+
+    batch = next(iter(calibration_dataloader))
+    assert batch["input_ids"].size(0) == 4
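
Usage sketch (illustrative only, not part of the diff): calling the updated `format_calibration_data` helper directly with the new `batch_size` and `processor` arguments. The toy dataset below is a placeholder; the tokenizer checkpoint is the same one used in the test above, and `configure_processor` (added in this diff) fills in the pad token that `DataCollatorWithPadding` needs.

    from datasets import Dataset
    from transformers import AutoTokenizer

    from llmcompressor.transformers.finetune.data.data_helpers import (
        format_calibration_data,
    )
    from llmcompressor.transformers.finetune.text_generation import configure_processor

    # placeholder tokenizer; any checkpoint whose tokenizer defines an eos_token works
    processor = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    configure_processor(processor)  # sets pad_token from eos_token if it is missing

    # placeholder pre-tokenized samples of unequal length
    tokenized = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5], [6]]})

    dataloader = format_calibration_data(
        tokenized_dataset=tokenized,
        num_calibration_samples=2,
        batch_size=2,           # new argument: samples per calibration batch
        processor=processor,    # new argument: used to build DataCollatorWithPadding
    )
    for batch in dataloader:
        # input_ids are padded to the longest sample in each batch
        print(batch["input_ids"].shape)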