Improve sample packing #347
@@ -20,6 +20,7 @@
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
@@ -38,6 +39,15 @@
from arctic_training.data.utils import DatasetType

IGNORE_INDEX = -100
PACKING_KEYS = (
    "input_ids",
    "labels",
    "position_ids",
    "packed_sample_seqlens",
    "attention_mask",
)

Packed_Data_Type = Dict[str, List[Union[List[int], int]]]


# this function is modified from TRL trl.trainer.utils.py
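For orientation, the packed layout these keys describe looks roughly like the toy example below (hypothetical values, not taken from this PR): each entry holds one list per packed sample, and `packed_sample_seqlens` records the lengths of the original samples inside each pack.

```python
# Hypothetical Packed_Data_Type instance: one packed sample built from two
# original samples of lengths 4 and 2 (values are made up for illustration).
toy_packed_batch = {
    "input_ids": [[11, 12, 13, 14, 21, 22]],
    "labels": [[-100, 12, 13, 14, -100, 22]],  # -100 == IGNORE_INDEX masks non-target tokens
    "position_ids": [[0, 1, 2, 3, 0, 1]],      # positions restart at each sample boundary
    "packed_sample_seqlens": [[4, 2]],
    "attention_mask": [[1, 1, 1, 1, 1, 1]],
}
```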
@@ -142,9 +152,11 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]:
        if "position_ids" in instances[0]:
            position_ids = [torch.tensor(example["position_ids"]) for example in instances]
            packed_sample_seqlens = [example["packed_sample_seqlens"] for example in instances]
            packed_seqlens_square_sum = [example["packed_seqlens_square_sum"] for example in instances]
        else:
            position_ids = [torch.tensor(list(range(len(example["input_ids"])))) for example in instances]
            packed_sample_seqlens = [[len(example["input_ids"])] for example in instances]
            packed_seqlens_square_sum = [-1 for example in instances]

        fake_unpacked_long_seq = False
        # fake_unpacked_long_seq = True
@@ -178,21 +190,128 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]:
            "labels": labels,
            "position_ids": position_ids,
            "packed_sample_seqlens": packed_sample_seqlens,
            "packed_seqlens_square_sum": packed_seqlens_square_sum,
        }


-def pack_sft_batch(
+def sort_packed_sft_batch(batch: Packed_Data_Type, reverse: bool) -> Packed_Data_Type:
    packed_list = []
    packed_keys = list(batch.keys())

    for idx in range(len(batch["input_ids"])):
        packed_dict = {key: batch[key][idx] for key in packed_keys}
        packed_list.append(packed_dict)

    def sum_square_compare(packed_sample):
        return sum([seqlen**2 for seqlen in packed_sample["packed_sample_seqlens"]])

    packed_list.sort(key=sum_square_compare, reverse=reverse)

    packed_batch: Packed_Data_Type = {k: [] for k in packed_keys}

    for packed_sample in packed_list:
        for key in packed_keys:
            packed_batch[key].append(packed_sample[key])

    return packed_batch


def pack_sft_batch_balance_length(
    batch: Dict[str, List[List[int]]],
    max_length: int,
    always_max_length: bool,
    drop_last: bool,
    fuse_positions_prob: float,
    seed: int,
-) -> Dict[str, List[List[int]]]:
-    keys = ("input_ids", "labels", "position_ids", "packed_sample_seqlens", "attention_mask")
-    packed_batch: Dict[str, List[List[int]]] = {k: [] for k in keys}
+) -> Packed_Data_Type:
+    keys = PACKING_KEYS
+    packed_batch: Packed_Data_Type = {k: [] for k in keys}
+    packed_batch["packed_seqlens_square_sum"] = []

    rng = random.Random(seed)

    # Best-fit-decreasing bin packing to maximize utilization of `max_length`.
    # This packs multiple short samples within the provided batch into larger samples each trying to be as close as possible to max_length.
    samples = list(zip(batch["input_ids"], batch["labels"], batch["attention_mask"]))
    # Sort by length descending; tie-breaker is deterministic to keep runs reproducible.
    sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True)

    bins: List[Dict[str, List[int]]] = []
    bin_lengths: List[int] = []

    def start_new_bin() -> int:
        bins.append({k: [] for k in keys})
        bin_lengths.append(0)
        return len(bins) - 1

    for idx in sorted_indices:
        input_ids, labels, attention_mask = samples[idx]
        sample_len = len(input_ids)

        # Find the bin that leaves the least remaining space after insertion.
        best_bin = None
        best_remaining = None
        for bin_idx, current_len in enumerate(bin_lengths):
            remaining = max_length - current_len
            if remaining <= 0:
                continue
            if not always_max_length and sample_len > remaining:
                continue
            take_len = min(sample_len, remaining)
            remaining_after = remaining - take_len
            if best_remaining is None or remaining_after < best_remaining:
                best_remaining = remaining_after
                best_bin = bin_idx

        if best_bin is None:
            best_bin = start_new_bin()

        target_bin = bins[best_bin]
        remaining = max_length - bin_lengths[best_bin]
        if remaining <= 0:
            continue  # should not happen, but guard against negative remaining
        take_len = min(sample_len, remaining) if always_max_length else sample_len
        take_len = min(take_len, remaining)

        target_bin["input_ids"].extend(input_ids[:take_len])
        target_bin["labels"].extend(labels[:take_len])
        target_bin["attention_mask"].extend(attention_mask[:take_len])
        target_bin["position_ids"].extend(range(take_len))
        target_bin["packed_sample_seqlens"].append(take_len)
        bin_lengths[best_bin] += take_len

    for bin_idx, packed in enumerate(bins):
        total_len = bin_lengths[bin_idx]
        if drop_last and total_len < max_length:
            continue
        if fuse_positions_prob and rng.random() <= fuse_positions_prob:
            packed["position_ids"] = list(range(len(packed["input_ids"])))

        # Add sum(seqlen^2) field
        packed_batch["packed_seqlens_square_sum"].append(
            sum([seqlen**2 for seqlen in packed["packed_sample_seqlens"]])
        )

        for k in keys:
            packed_batch[k].append(packed[k])

    return packed_batch


def pack_sft_batch_naive(
    batch: Dict[str, List[List[int]]],
    max_length: int,
    always_max_length: bool,
    drop_last: bool,
    fuse_positions_prob: float,
    seed: int,
) -> Packed_Data_Type:
    keys = PACKING_KEYS
    packed_batch: Packed_Data_Type = {k: [] for k in keys}
    current_sample: Dict[str, List[int]] = {k: [] for k in keys}

    packed_batch["packed_seqlens_square_sum"] = []

    rng = random.Random(seed)

    def should_flush() -> bool:
@@ -203,6 +322,12 @@ def flush() -> None:
        if len(current_sample["input_ids"]) > 0:
            if fuse_positions_prob and rng.random() <= fuse_positions_prob:
                current_sample["position_ids"] = list(range(len(current_sample["input_ids"])))

            # Add sum(seqlen^2) field
            packed_batch["packed_seqlens_square_sum"].append(
                sum([seqlen**2 for seqlen in current_sample["packed_sample_seqlens"]])
            )

            for k in keys:
                packed_batch[k].append(current_sample[k])
                current_sample[k] = []
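To see how `balance_length` differs from the naive packer, here is a minimal, self-contained sketch of best-fit-decreasing packing over sequence lengths only (it ignores `always_max_length`, truncation, and the actual token lists; the helper name and values are hypothetical). It also computes the `sum(seqlen**2)` quantity the PR stores as `packed_seqlens_square_sum`, which tracks the per-pack attention cost when attention stays restricted to each original sample.

```python
from typing import List


def best_fit_decreasing(lengths: List[int], max_length: int) -> List[List[int]]:
    """Toy best-fit-decreasing packer over sample lengths only."""
    bins: List[List[int]] = []
    for length in sorted(lengths, reverse=True):
        best_idx, best_left = None, None
        for idx, bin_lengths in enumerate(bins):
            left = max_length - sum(bin_lengths)
            # best fit: pick the bin that would have the least space left over
            if length <= left and (best_left is None or left - length < best_left):
                best_idx, best_left = idx, left - length
        if best_idx is None:
            bins.append([length])  # no bin fits -> open a new one
        else:
            bins[best_idx].append(length)
    return bins


bins = best_fit_decreasing([900, 700, 500, 400, 300, 200], max_length=1024)
# -> [[900], [700, 300], [500, 400], [200]]
square_sums = [sum(seqlen**2 for seqlen in b) for b in bins]  # packed_seqlens_square_sum analogue
```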
@@ -248,6 +373,24 @@ class SFTDataConfig(DataConfig):
    pack_samples: bool = False
    """ Whether to pack multiple samples together into larger samples of up to `max_length` tokens. """

    pack_samples_mode: Literal["naive", "balance_length"] = "naive"
    """ Which packing algorithm to use. The default is a greedy (naive) packing algorithm. """

    max_pack_batch_size: int = 10**4
    """ Maximum batch/chunk size used when packing samples. Helps to avoid CPU OOM. """

    dl_shuffle_samples: bool = True
    """ Whether the dataloader should shuffle samples. """

    sort_packed_samples: bool = False
    """ Whether to sort packed samples. """

    sort_packed_samples_order: Literal["ascend", "descend"] = "descend"
    """ Sorting order for packed samples. """

    sort_packed_samples_scope: Literal["local", "global"] = "local"
    """ Sorting scope for packed samples. """

    drop_last: bool = False
    """ Whether to drop the last packed sample, which might be shorter than `max_length`. """
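Put together, the knobs added to `SFTDataConfig` in this hunk are listed below with their defaults (shown as a plain dict rather than a full config, since the rest of the config schema is outside this diff; the variable name is only illustrative).

```python
# New sample-packing options introduced by this PR, with their defaults.
new_packing_options = {
    "pack_samples_mode": "naive",            # or "balance_length" for best-fit-decreasing packing
    "max_pack_batch_size": 10**4,            # cap on the per-process packing chunk size
    "dl_shuffle_samples": True,              # set False to preserve a sorted order in the dataloader
    "sort_packed_samples": False,            # sort packed samples by sum(seqlen**2)
    "sort_packed_samples_order": "descend",  # or "ascend"
    "sort_packed_samples_scope": "local",    # per packing chunk; "global" sorts the whole dataset
}
```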
@@ -332,11 +475,18 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:
        dataset = repeat_dataset(dataset=dataset, max_length=self.config.max_length, num_proc=self.config.num_proc)

        batch_size = len(dataset) // self.config.num_proc + 1

        # for huge datasets keep the bs to a sane size to avoid cpu-oom
-       batch_size = int(min(batch_size, 1e3))
+       batch_size = int(min(batch_size, self.config.max_pack_batch_size))

        dataset = dataset.shuffle(seed=self.config.seed)
        if self.config.pack_samples_mode == "balance_length":
            packing_fn = pack_sft_batch_balance_length
        else:
            packing_fn = pack_sft_batch_naive

        dataset = dataset.map(
-           lambda x: pack_sft_batch(
+           lambda x: packing_fn(
                x,
                max_length=self.config.max_length,
                always_max_length=self.config.always_max_length,
@@ -349,6 +499,21 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:
            num_proc=self.config.num_proc,
            desc="Packing dataset",
        )

        if self.config.sort_packed_samples:
            if self.config.sort_packed_samples_scope == "local":
                dataset = dataset.map(
                    lambda x: sort_packed_sft_batch(x, reverse=(self.config.sort_packed_samples_order == "descend")),
                    batched=True,
                    batch_size=batch_size,
                    num_proc=self.config.num_proc,
                    desc="Local sorting dataset",
                )
            else:
                dataset = dataset.sort(
                    "packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order == "descend")
                )

        if len(dataset) < 1:
            raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}")
        return dataset
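The `local`/`global` distinction above: `local` only reorders within each `batch_size`-sized chunk handed to `map`, while `global` reorders the entire dataset via `datasets.Dataset.sort`. A toy illustration with plain lists standing in for `packed_seqlens_square_sum` values:

```python
costs = [9, 1, 7, 3, 8, 2]  # per-pack sum(seqlen**2), toy values
chunk = 3                   # stands in for the packing batch_size

# "local": each chunk is sorted independently, so ordering across chunks is unchanged
local = [c for i in range(0, len(costs), chunk) for c in sorted(costs[i:i + chunk], reverse=True)]
# -> [9, 7, 1, 8, 3, 2]

# "global": one sort over the whole dataset
global_order = sorted(costs, reverse=True)
# -> [9, 8, 7, 3, 2, 1]
```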
@@ -466,6 +631,10 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu
        return output

    def create_dataloader(self, dataset: DatasetType) -> DataLoader:
-       dataloader = super().create_dataloader(dataset)
+       dataloader = (
+           super().create_dataloader(dataset)
+           if self.config.dl_shuffle_samples
+           else super().create_dataloader_no_shuffle(dataset)
+       )
        dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config)
        return dataloader
Do we need to require a new method for this? Why not extend the existing `create_dataloader`? This would avoid the need to add this new method for each data factory (`_validate_class_method(cls, "create_dataloader_no_shuffle", ["self", "dataset"])`).
I previously extended `create_dataloader` with an optional `shuffle` flag. However, this caused a UT failure on ArcticTraining/arctic_training/data/factory.py, line 75 in f472557. It seems the validation does not support optional args, or at least I don't know how to achieve that.

I also didn't want `shuffle` to be mandatory for `create_dataloader`. But if this is preferred, I can make that change. What do you prefer?
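For reference, the alternative discussed here would look roughly like the sketch below: a single `create_dataloader` with an optional `shuffle` flag instead of a separate `create_dataloader_no_shuffle` method. The class and constructor are hypothetical stand-ins rather than ArcticTraining's actual factory API, and whether `_validate_class_method` can tolerate the extra optional argument is exactly the open question in this thread.

```python
from torch.utils.data import DataLoader, Dataset


class DataFactorySketch:  # hypothetical stand-in for the real data factory base class
    def __init__(self, micro_batch_size: int = 1) -> None:
        self.micro_batch_size = micro_batch_size

    def create_dataloader(self, dataset: Dataset, shuffle: bool = True) -> DataLoader:
        # Existing call sites keep working because `shuffle` defaults to True;
        # the packing factory would pass shuffle=self.config.dl_shuffle_samples instead.
        return DataLoader(dataset, batch_size=self.micro_batch_size, shuffle=shuffle)
```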