Support user-defined batch size for one shot #1117
Changes from all commits: 29f93d3, de38a64, afabe5a, a3a9f17, 447b5ad, f87a78f, 7b3f434, fe67d7e
```diff
@@ -1,13 +1,17 @@
 import logging
 import os
 from typing import Any, Callable, Dict, List, Optional

 import torch
 from datasets import Dataset, load_dataset
+from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data.data_collator import (
+    DataCollatorWithPadding,
+    default_data_collator,
+)

 from llmcompressor.typing import Processor

 LOGGER = logging.getLogger(__name__)
 LABELS_MASK_VALUE = -100

 __all__ = [
```
```diff
@@ -21,8 +25,10 @@
 def format_calibration_data(
     tokenized_dataset: Dataset,
     num_calibration_samples: Optional[int] = None,
+    batch_size: int = 1,
     do_shuffle: bool = True,
-    collate_fn: Callable = default_data_collator,
+    processor: Optional[Processor] = None,
+    collate_fn: Optional[Callable] = None,
     accelerator: Optional[Any] = None,
 ) -> List[torch.Tensor]:
     """
```
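For orientation, here is a hedged sketch of calling the updated helper directly with the new parameters. Only the signature above comes from the diff; the import path, the tokenized dataset `calib_ds`, and the `tokenizer` object are assumptions for illustration.

```python
# Sketch only: exercising the new batch_size/processor parameters.
# Assumes `calib_ds` is a tokenized `datasets.Dataset` and `tokenizer` is a HF tokenizer;
# the module path is a guess at where this helper lives in the repo.
from llmcompressor.transformers.finetune.data.data_helpers import format_calibration_data

calib_loader = format_calibration_data(
    tokenized_dataset=calib_ds,
    num_calibration_samples=512,
    batch_size=4,          # new: calibrate 4 samples per forward pass
    processor=tokenizer,   # new: used to build a padding collator when collate_fn is None
)
for batch in calib_loader:
    print(batch["input_ids"].shape)  # (4, longest_sequence_in_batch)
    break
```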
```diff
@@ -33,30 +39,47 @@ def format_calibration_data(
     :param num_calibration_samples: number of data samples to convert
     :param do_shuffle: whether to shuffle the dataset before selecting calibration
         samples, true by default
-    :param collate_fn: optional custom collate function, or use default
+    :param collate_fn: optional custom collate function, defaults to
+        `DataCollatorWithPadding` if None is provided. If the tokenizer fails to
+        resolve, then `default_data_collator` is used
     :param accelerator: optional accelerator for if preparing in FSDP mode
     :return: list of trimmed calibration data tensors
     """
+    # shuffle
+    if do_shuffle:
+        tokenized_dataset = tokenized_dataset.shuffle()
+
+    # truncate samples
     safe_calibration_samples = len(tokenized_dataset)
     if num_calibration_samples is not None:
         safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
         if safe_calibration_samples != num_calibration_samples:
-            LOGGER.warn(
+            logger.warning(
                 f"Requested {num_calibration_samples} calibration samples but "
                 f"the provided dataset only has {safe_calibration_samples}. "
             )

-    if do_shuffle:
-        tokenized_dataset = tokenized_dataset.shuffle()
     tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

+    # collate data
+    if collate_fn is None:
+        tokenizer = getattr(processor, "tokenizer", processor)
+        if hasattr(tokenizer, "pad"):
+            collate_fn = DataCollatorWithPadding(tokenizer)
+        else:
+            logger.warning(
+                "Could not find processor, attempting to collate without padding "
+                "(may fail for batch_size > 1)"
+            )
+            collate_fn = default_data_collator
+
     dataloader_params = {
-        "batch_size": 1,
+        "batch_size": batch_size,
         "sampler": RandomSampler(tokenized_calibration)
         if do_shuffle
         else SequentialSampler(tokenized_calibration),
         "collate_fn": collate_fn,
         "pin_memory": True,
         "drop_last": False,
     }

     calib_dataloader = DataLoader(tokenized_calibration, **dataloader_params)
```

Review comment on `"drop_last": False,`: why not drop if not divisible by batch size?

Reply: "drop_last" is not relevant if the number of samples is divisible by the batch size, because then there is no remainder. I'm confused about what you're referring to here.
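To make the `drop_last` exchange concrete, a small standalone illustration (my own example, not from the PR): the flag only matters when the number of calibration samples is not divisible by the batch size.

```python
from torch.utils.data import DataLoader

samples = list(range(10))                                  # 10 calibration samples
keep = DataLoader(samples, batch_size=4, drop_last=False)  # batches of 4, 4, 2
drop = DataLoader(samples, batch_size=4, drop_last=True)   # batches of 4, 4 (last 2 samples discarded)
print(len(keep), len(drop))                                # 3 2
```

Keeping `drop_last=False` means every requested calibration sample is actually seen, at the cost of one smaller final batch.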
```diff
@@ -68,8 +68,8 @@ def populate_datasets(self, processor: Processor, add_labels: bool = True):
         :param processor: processor or tokenizer to use for dataset tokenization
         :param add_labels: if True, add labels column to dataset splits
         """
+        self.processor = processor  # TODO: pass processor into init instead of this fn
         if self._data_args.dataset is None:
-            self.processor = self._model_args.processor
             logger.info(
                 "Running oneshot without calibration data. This is expected for "
                 "weight-only and dynamic quantization"
```

Review comment on `self.processor = processor`: if this processor is the same as the model processor, it will be accessible by `self.model_args.processor`.

Reply: At this location, […] We could change this flow in the future, but for now we separate […]
```diff
@@ -144,8 +144,10 @@ def one_shot(self, stage: Optional[str] = None):
             calib_data = format_calibration_data(
                 tokenized_dataset=self.get_dataset_split("calibration"),
                 num_calibration_samples=self._data_args.num_calibration_samples,
+                batch_size=self._training_args.calibration_batch_size,
                 do_shuffle=self._data_args.shuffle_calibration_samples,
                 collate_fn=self._data_args.data_collator,
+                processor=self.processor,
                 accelerator=self.trainer.accelerator,
             )

```
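This is why the runner now threads `processor` through: once `batch_size > 1`, variable-length samples have to be padded before they can be stacked into a batch. A minimal sketch of the padding fallback that `format_calibration_data` builds when no collator is supplied (the model name is illustrative only):

```python
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any tokenizer with a pad token
features = [tokenizer(t) for t in ["short prompt", "a noticeably longer calibration prompt"]]

# default_data_collator would fail here: the two input_ids lists have different lengths
collate_fn = DataCollatorWithPadding(tokenizer)
batch = collate_fn(features)     # pads both samples to the longest sequence in the batch
print(batch["input_ids"].shape)  # e.g. torch.Size([2, 9])
```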
```diff
@@ -32,6 +32,12 @@ class TrainingArguments(HFTrainingArgs):
             )
         },
     )
+    calibration_batch_size: int = field(
+        default=1,
+        metadata={
+            "help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for oneshot"
+        },
+    )
     save_compressed: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether to compress sparse models during save"},
```

Review comment on `calibration_batch_size`: Oneshot will not depend on training_args in the follow-up PR, so this will be moved once that lands.

Reply: Which argument set should this exist on, @horheynm?
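Because `calibration_batch_size` is a regular `TrainingArguments` field, it should surface through the usual argument-parsing flow. The snippet below is an assumption about that flow (the oneshot entry point and its kwarg routing are not part of this diff), not a confirmed API:

```python
from llmcompressor.transformers import oneshot

# Assumed kwarg routing: dataclass fields on the argument sets are settable as
# oneshot(...) keyword arguments; calibration_batch_size is the field added above.
# Model, dataset, and recipe names are illustrative placeholders.
oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    dataset="open_platypus",
    recipe="recipe.yaml",
    num_calibration_samples=512,
    calibration_batch_size=4,
)
```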
Review comment: Do we not use this anymore?

Follow-up: Ok, you moved it to `if hasattr(tokenizer, "pad"): collate_fn = DataCollatorWithPadding(tokenizer)`.

Reply: No, I moved this logic to `configure_processor`, as indicated in the PR description.
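For readers following that thread: the quoted fallback works because multimodal processors wrap a tokenizer, while plain tokenizers expose `pad` themselves. A standalone sketch of the resolution pattern used in the diff (my own restatement, not the repo's `configure_processor`):

```python
from transformers.data.data_collator import DataCollatorWithPadding, default_data_collator

def resolve_padding_collator(processor):
    """Mirror the diff's fallback: prefer a pad-capable tokenizer, else collate without padding."""
    tokenizer = getattr(processor, "tokenizer", processor)  # processors expose a .tokenizer attribute
    if hasattr(tokenizer, "pad"):
        return DataCollatorWithPadding(tokenizer)
    return default_data_collator  # may fail for batch_size > 1 on ragged inputs
```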