Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions arctic_training/data/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _validate_subclass(cls) -> None:
_validate_class_method(cls, "process", ["self", "dataset"])
_validate_class_method(cls, "split_data", ["self", "training_data"])
_validate_class_method(cls, "create_dataloader", ["self", "dataset"])
_validate_class_method(cls, "create_dataloader_no_shuffle", ["self", "dataset"])

def __init__(self, trainer: "Trainer", config: Optional[DataConfig] = None) -> None:
if config is None:
Expand Down Expand Up @@ -226,14 +227,25 @@ def split_data(self, training_data: DatasetType) -> Tuple[DatasetType, Optional[

return training_data, evaluation_data

@callback_wrapper("create_dataloader")
def create_dataloader(self, dataset: DatasetType) -> DataLoader:
def _create_dataloader(self, dataset: DatasetType, sampler_shuffle: bool = True) -> DataLoader:
"""Create a torch DataLoader from the dataset."""
return DataLoader(
dataset,
batch_size=self.micro_batch_size,
sampler=DistributedSampler(dataset, num_replicas=self.world_size, rank=self.global_rank),
sampler=DistributedSampler(
dataset, num_replicas=self.world_size, rank=self.global_rank, shuffle=sampler_shuffle
),
num_workers=self.config.dl_num_workers,
persistent_workers=True,
drop_last=True,
)

@callback_wrapper("create_dataloader")
def create_dataloader(self, dataset: DatasetType) -> DataLoader:
    """Create a torch DataLoader with shuffled sample order."""
    return self._create_dataloader(dataset, sampler_shuffle=True)

@callback_wrapper("create_dataloader_no_shuffle")
def create_dataloader_no_shuffle(self, dataset: DatasetType) -> DataLoader:
    """Create a torch DataLoader that keeps samples in their original order."""
    dataloader = self._create_dataloader(dataset, sampler_shuffle=False)
    return dataloader
Comment on lines +248 to +251
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to require a new method for this? Why not extend the existing create_dataloader? This would avoid the need to add this new method for each data factory (_validate_class_method(cls, "create_dataloader_no_shuffle", ["self", "dataset"]))

def create_dataloader(self, dataset: DatasetType, shuffle: bool = True):
    return self._create_dataloader(sampler_shuffle=shuffle)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I previously extended the create_dataloader with optional shuffle flag. However, this caused UT failure on

_validate_class_method(cls, "create_dataloader", ["self", "dataset"])

It seems the validation does not support optional args, or at least I don't know how to achieve that.

I also didn't want shuffle to be mandatory for create_dataloader. But if this is preferred, I can make that change.

What do you prefer?

183 changes: 176 additions & 7 deletions arctic_training/data/sft_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
Expand All @@ -38,6 +39,15 @@
from arctic_training.data.utils import DatasetType

IGNORE_INDEX = -100
# Per-sample fields produced/consumed by the packing functions below.
PACKING_KEYS = (
    "input_ids",
    "labels",
    "position_ids",
    "packed_sample_seqlens",
    "attention_mask",
)

# A packed batch maps each field name to a list of per-sample values:
# lists of ints for sequence fields, plain ints for scalar fields
# (e.g. "packed_seqlens_square_sum").
Packed_Data_Type = Dict[str, List[Union[List[int], int]]]


# this function is modified from TRL trl.trainer.utils.py
Expand Down Expand Up @@ -142,9 +152,11 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]:
if "position_ids" in instances[0]:
position_ids = [torch.tensor(example["position_ids"]) for example in instances]
packed_sample_seqlens = [example["packed_sample_seqlens"] for example in instances]
packed_seqlens_square_sum = [example["packed_seqlens_square_sum"] for example in instances]
else:
position_ids = [torch.tensor(list(range(len(example["input_ids"])))) for example in instances]
packed_sample_seqlens = [[len(example["input_ids"])] for example in instances]
packed_seqlens_square_sum = [-1 for example in instances]

fake_unpacked_long_seq = False
# fake_unpacked_long_seq = True
Expand Down Expand Up @@ -178,21 +190,128 @@ def __call__(self, instances: List[Dict]) -> Dict[str, torch.Tensor]:
"labels": labels,
"position_ids": position_ids,
"packed_sample_seqlens": packed_sample_seqlens,
"packed_seqlens_square_sum": packed_seqlens_square_sum,
}


def pack_sft_batch(
def sort_packed_sft_batch(batch: Packed_Data_Type, reverse: bool) -> Packed_Data_Type:
    """Reorder the packed samples in ``batch`` by sum of squared sub-sequence lengths.

    Samples are compared by ``sum(seqlen**2)`` over their
    ``packed_sample_seqlens``; ``reverse=True`` sorts descending. The sort is
    stable, so ties keep their original relative order.
    """
    field_names = list(batch.keys())
    num_samples = len(batch["input_ids"])

    def square_sum(sample_idx: int) -> int:
        return sum(seqlen**2 for seqlen in batch["packed_sample_seqlens"][sample_idx])

    order = sorted(range(num_samples), key=square_sum, reverse=reverse)

    return {name: [batch[name][i] for i in order] for name in field_names}


def pack_sft_batch_balance_length(
    batch: Dict[str, List[List[int]]],
    max_length: int,
    always_max_length: bool,
    drop_last: bool,
    fuse_positions_prob: float,
    seed: int,
) -> Packed_Data_Type:
    """Pack samples into bins of up to ``max_length`` via best-fit-decreasing.

    Samples are sorted by length (descending) and each is placed into the
    existing bin that would be left with the least remaining space; a new bin
    is started when none fits. This maximizes utilization of ``max_length``
    compared to the naive sequential packer.

    Args:
        batch: Mapping with "input_ids", "labels" and "attention_mask" lists.
        max_length: Target packed-sample length.
        always_max_length: If True, samples may be truncated to fill a bin's
            remaining space instead of requiring a full fit.
        drop_last: If True, bins shorter than ``max_length`` are discarded.
        fuse_positions_prob: Probability of replacing a bin's per-sample
            position_ids with a single fused 0..len-1 range.
        seed: Seed for the RNG driving ``fuse_positions_prob``.

    Returns:
        A packed batch with PACKING_KEYS fields plus "packed_seqlens_square_sum"
        (sum of squared sub-sequence lengths per packed sample).
    """
    keys = PACKING_KEYS
    packed_batch: Packed_Data_Type = {k: [] for k in keys}
    packed_batch["packed_seqlens_square_sum"] = []

    rng = random.Random(seed)

    # Best-fit-decreasing bin packing to maximize utilization of `max_length`.
    # This packs multiple short samples within the provided batch into larger
    # samples, each trying to be as close as possible to max_length.
    samples = list(zip(batch["input_ids"], batch["labels"], batch["attention_mask"]))
    # Sort by length descending; tie-breaker is deterministic to keep runs reproducible.
    sorted_indices = sorted(range(len(samples)), key=lambda i: len(samples[i][0]), reverse=True)

    bins: List[Dict[str, List[int]]] = []
    bin_lengths: List[int] = []

    def start_new_bin() -> int:
        # Open a fresh empty bin and return its index.
        bins.append({k: [] for k in keys})
        bin_lengths.append(0)
        return len(bins) - 1

    for idx in sorted_indices:
        input_ids, labels, attention_mask = samples[idx]
        sample_len = len(input_ids)

        # Find the bin that leaves the least remaining space after insertion.
        best_bin = None
        best_remaining = None
        for bin_idx, current_len in enumerate(bin_lengths):
            remaining = max_length - current_len
            if remaining <= 0:
                continue
            if not always_max_length and sample_len > remaining:
                # Without truncation the sample must fit entirely.
                continue
            take_len = min(sample_len, remaining)
            remaining_after = remaining - take_len
            if best_remaining is None or remaining_after < best_remaining:
                best_remaining = remaining_after
                best_bin = bin_idx

        if best_bin is None:
            best_bin = start_new_bin()

        target_bin = bins[best_bin]
        remaining = max_length - bin_lengths[best_bin]
        if remaining <= 0:
            continue  # should not happen, but guard against negative remaining
        # With always_max_length the sample may be truncated to fill the bin.
        take_len = min(sample_len, remaining) if always_max_length else sample_len
        take_len = min(take_len, remaining)

        target_bin["input_ids"].extend(input_ids[:take_len])
        target_bin["labels"].extend(labels[:take_len])
        target_bin["attention_mask"].extend(attention_mask[:take_len])
        target_bin["position_ids"].extend(range(take_len))
        target_bin["packed_sample_seqlens"].append(take_len)
        bin_lengths[best_bin] += take_len

    for bin_idx, packed in enumerate(bins):
        total_len = bin_lengths[bin_idx]
        if drop_last and total_len < max_length:
            continue
        if fuse_positions_prob and rng.random() <= fuse_positions_prob:
            packed["position_ids"] = list(range(len(packed["input_ids"])))

        # Add sum(seqlen^2) field.
        packed_batch["packed_seqlens_square_sum"].append(
            sum(seqlen**2 for seqlen in packed["packed_sample_seqlens"])
        )

        for k in keys:
            packed_batch[k].append(packed[k])

    return packed_batch


def pack_sft_batch_naive(
batch: Dict[str, List[List[int]]],
max_length: int,
always_max_length: bool,
drop_last: bool,
fuse_positions_prob: float,
seed: int,
) -> Packed_Data_Type:
keys = PACKING_KEYS
packed_batch: Packed_Data_Type = {k: [] for k in keys}
current_sample: Dict[str, List[int]] = {k: [] for k in keys}

packed_batch["packed_seqlens_square_sum"] = []

rng = random.Random(seed)

def should_flush() -> bool:
Expand All @@ -203,6 +322,12 @@ def flush() -> None:
if len(current_sample["input_ids"]) > 0:
if fuse_positions_prob and rng.random() <= fuse_positions_prob:
current_sample["position_ids"] = list(range(len(current_sample["input_ids"])))

# Add sum(seqlen^2) field
packed_batch["packed_seqlens_square_sum"].append(
sum([seqlen**2 for seqlen in current_sample["packed_sample_seqlens"]])
)

for k in keys:
packed_batch[k].append(current_sample[k])
current_sample[k] = []
Expand Down Expand Up @@ -248,6 +373,24 @@ class SFTDataConfig(DataConfig):
pack_samples: bool = False
""" Whether to pack multiple samples into samples up to size `max_length`. """

pack_samples_mode: Literal["naive", "balance_length"] = "naive"
""" What packing algorithm to use. The default is a greedy packing algorithm. """

max_pack_batch_size: int = 10**4
""" Maximum batch/chunk size for packing samples. Helps to avoid CPU OOM"""

dl_shuffle_samples: bool = True
""" Whether the dataloader should shuffle samples. """

sort_packed_samples: bool = False
""" Whether to sort packed samples. """

sort_packed_samples_order: Literal["ascend", "descend"] = "descend"
""" Sorting order for packed samples. """

sort_packed_samples_scope: Literal["local", "global"] = "local"
Copy link
Copy Markdown
Collaborator

@sfc-gh-sbekman sfc-gh-sbekman Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure local vs global is an intuitive self-documenting mnemonic in this context - requires doc reading to understand what each implies.

Perhaps batched vs all? that is sort each batch separately, vs sort all?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

local to me implies gpu/rank-local or some such.

""" Sorting scope for packed samples: "local" sorts within each packing batch, "global" sorts across the whole dataset. """

drop_last: bool = False
""" Whether to drop the last packed sample, which might be shorter than `max_length`. """

Expand Down Expand Up @@ -332,11 +475,18 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:
dataset = repeat_dataset(dataset=dataset, max_length=self.config.max_length, num_proc=self.config.num_proc)

batch_size = len(dataset) // self.config.num_proc + 1

# for huge datasets keep the bs to a sane size to avoid cpu-oom
batch_size = int(min(batch_size, 1e3))
batch_size = int(min(batch_size, self.config.max_pack_batch_size))

dataset = dataset.shuffle(seed=self.config.seed)
if self.config.pack_samples_mode == "balance_length":
packing_fn = pack_sft_batch_balance_length
else:
packing_fn = pack_sft_batch_naive

dataset = dataset.map(
lambda x: pack_sft_batch(
lambda x: packing_fn(
x,
max_length=self.config.max_length,
always_max_length=self.config.always_max_length,
Expand All @@ -349,6 +499,21 @@ def pack_dataset(self, dataset: DatasetType) -> DatasetType:
num_proc=self.config.num_proc,
desc="Packing dataset",
)

if self.config.sort_packed_samples:
if self.config.sort_packed_samples_scope == "local":
dataset = dataset.map(
lambda x: sort_packed_sft_batch(x, reverse=(self.config.sort_packed_samples_order == "descend")),
batched=True,
batch_size=batch_size,
num_proc=self.config.num_proc,
desc="Local sorting dataset",
)
else:
dataset = dataset.sort(
"packed_seqlens_square_sum", reverse=(self.config.sort_packed_samples_order == "descend")
)

if len(dataset) < 1:
raise ValueError(f"No data left after packing dataset samples in {self.__class__.__name__}")
return dataset
Expand Down Expand Up @@ -466,6 +631,10 @@ def get_masked_labels(conversation_ids: BatchEncoding, assistant_ranges: List[Tu
return output

def create_dataloader(self, dataset: DatasetType) -> DataLoader:
    """Create the SFT DataLoader and attach the causal-LM collator.

    Honors ``config.dl_shuffle_samples`` by delegating to either the shuffled
    or the order-preserving parent factory method.
    """
    if self.config.dl_shuffle_samples:
        dataloader = super().create_dataloader(dataset)
    else:
        dataloader = super().create_dataloader_no_shuffle(dataset)
    dataloader.collate_fn = DataCollatorForCausalLM(tokenizer=self.tokenizer, config=self.config)
    return dataloader
10 changes: 10 additions & 0 deletions arctic_training/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def print_summary(self, prefix: str = "train") -> None:
)
self.values["seqlen_total"] = seqlen_subtotal

# self.seqlens is list[list[int]]
self.values["seqlen_square_sum"] = sum([sum([len * len for len in seqlens]) for seqlens in self.seqlens])

if "loss" in self.values:
loss = sum(gather_object(self.values["loss"], self.trainer.world_size)) / self.trainer.world_size
self.summary_dict["loss"] = loss
Expand Down Expand Up @@ -162,6 +165,9 @@ def print_summary(self, prefix: str = "train") -> None:
seq_len_total = sum(gather_object(self.values["seqlen_total"], self.trainer.world_size))
self.summary_dict["seqlen"] = seq_len_total / self.trainer.world_size

seqlen_square_sum_total = sum(gather_object(self.values["seqlen_square_sum"], self.trainer.world_size))
self.summary_dict["seqlen_square_sum"] = seqlen_square_sum_total / self.trainer.world_size

if "step_time" in self.values:
step_time_total = sum(gather_object(self.values["step_time"], self.trainer.world_size))
self.summary_dict["step_time"] = step_time_total / self.trainer.world_size
Expand All @@ -186,6 +192,10 @@ def print_summary(self, prefix: str = "train") -> None:
summary_str += f" | lr: {self.summary_dict['lr']:.3E}"
if "seqlen" in self.summary_dict:
summary_str += f" | seqlen: {human_format_base10_number(self.summary_dict['seqlen'])}"
if "seqlen_square_sum" in self.summary_dict:
summary_str += (
f" | seqlen_square_sum: {human_format_base10_number(self.summary_dict['seqlen_square_sum'])}"
)
if "step_time" in self.summary_dict:
summary_str += f" | step time: {human_format_secs(self.summary_dict['step_time'])}"
if "step_tflops" in self.summary_dict:
Expand Down
1 change: 1 addition & 0 deletions arctic_training/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def epoch(self) -> None:
if "packed_sample_seqlens" in batch and "flash_attention" in self.config.model.attn_implementation:
# deal correctly with packed samples under FA2/FA3, by calculating each seqlen tflos separately
sample_seqlens = batch.pop("packed_sample_seqlens")
batch.pop("packed_seqlens_square_sum")
else:
sample_seqlens = [
[len(batch["input_ids"][idx]) * self.config.sequence_parallel_size]
Expand Down