
Commit 3301341

thomasw21, Lintang Sutawika (lintangsutawika), and Muennighoff authored and committed
MTF dataset and packing (bigscience-workshop#293)
Co-authored-by: Lintang Sutawika <[email protected]>
Co-authored-by: lintangsutawika <[email protected]>
Co-authored-by: Muennighoff <[email protected]>
1 parent 66b3212 commit 3301341

10 files changed (+693, -15 lines)

megatron/arguments.py

Lines changed: 1 addition & 1 deletion
@@ -559,7 +559,7 @@ def _add_training_args(parser):
                             'please refer https://github.com/facebookresearch/bitsandbytes.',
                        dest='use_bnb_optimizer')
     group.add_argument('--dataloader-type', type=str, default=None,
-                       choices=['single', 'cyclic'],
+                       choices=['single', 'cyclic', 'decoder_packed'],
                        help='Single pass vs multiple pass data loader')
     group.add_argument('--cpu-optimizer', action='store_true',
                        help='Run optimizer on CPU')
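
For illustration, a minimal argparse sketch of what the new choice allows; the stand-alone parser below is an assumption made for demonstration and is not Megatron's real argument parser:

import argparse

# Stand-alone mock of the option extended above, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument('--dataloader-type', type=str, default=None,
                    choices=['single', 'cyclic', 'decoder_packed'],
                    help='Single pass vs multiple pass data loader')
args = parser.parse_args(['--dataloader-type', 'decoder_packed'])
assert args.dataloader_type == 'decoder_packed'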

megatron/data/data_samplers.py

Lines changed: 168 additions & 8 deletions
@@ -15,11 +15,77 @@
 
 """Dataloaders."""
 
+from functools import partial
 
+import numpy as np
 import torch
-import random
-from megatron import get_args
+
+from megatron import get_args, get_tokenizer
 from megatron import mpu
+from megatron.data.mtf_dataset import MTFDataset
+
+
+def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
+    """
+    Greedily packs samples.
+
+    Items:
+        [
+            {
+                'input_tokens': array([6, 7]),
+                'target_tokens': array([8])
+            },
+            {
+                'input_tokens': array([3, 4]),
+                'target_tokens': array([5])
+            }
+        ]
+
+    Output:
+        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed by padding tokens.
+        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids identify the original documents.
+        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `1` marks input positions, `0` marks target and padding positions.
+    """
+
+    decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
+    decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
+    decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))
+
+    batch_num = 0
+    # `0` is reserved for padding
+    item_num = 1
+    cur_len = 0
+    for token_dict in items:
+        input_token_len = len(token_dict["input_tokens"])
+        target_token_len = len(token_dict["target_tokens"])
+        total_len = input_token_len + target_token_len
+        if cur_len + total_len > max_seq_len:
+            len_diff = max_seq_len - cur_len
+            # Padding
+            if len_diff > 0:
+                decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
+                decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
+                decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
+            batch_num += 1
+            assert batch_num < micro_batch_size
+            item_num = 1
+            cur_len = 0
+
+        decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
+        decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
+        decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
+        decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1  # input
+        decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0  # target
+
+        item_num += 1
+        cur_len += total_len
+        assert cur_len < max_seq_len
+
+    return {
+        "decoder_target_tokens": decoder_target_tokens,
+        "decoder_segment_ids": decoder_segment_ids,
+        "decoder_causal_attention": decoder_causal_attention,
+    }
 
 
 def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
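
To make the docstring concrete, here is a small usage sketch of pack_samples on the example above; max_seq_len, micro_batch_size and the pad token id are values assumed purely for illustration:

import numpy as np

# Assumed illustrative values: one packed row of length 7, pad token id 0.
items = [
    {'input_tokens': np.array([6, 7]), 'target_tokens': np.array([8])},
    {'input_tokens': np.array([3, 4]), 'target_tokens': np.array([5])},
]
packed = pack_samples(items, max_seq_len=7, micro_batch_size=1, pad_token=0)
# packed['decoder_target_tokens']    -> [[6, 7, 8, 3, 4, 5, 0]]
# packed['decoder_segment_ids']      -> [[1, 1, 1, 2, 2, 2, 0]]   (float array)
# packed['decoder_causal_attention'] -> [[1, 1, 0, 1, 1, 0, 0]]   (float array)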
@@ -44,18 +110,39 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size())
+    elif args.dataloader_type == 'decoder_packed':
+        assert isinstance(dataset, MTFDataset)
+        batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
+            sequence_length=args.seq_length + 1,
+            dataset=dataset,
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size())
     else:
         raise Exception('{} dataloader type is not supported.'.format(
-                args.dataloader_type))
+            args.dataloader_type))
 
     if num_workers is None:
         num_workers = args.num_workers
 
+    collate_fn = None
+    if args.dataloader_type == 'decoder_packed':
+        assert isinstance(dataset, MTFDataset)
+        pad_token = get_tokenizer().pad
+        collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
+                             pad_token=pad_token)
+
     # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        collate_fn=collate_fn,
+        pin_memory=True
+    )
+
 
 class MegatronPretrainingSampler:
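
Presumably max_seq_len is set to args.seq_length + 1 because, as with regular Megatron GPT batches, a packed row is later split into inputs and labels shifted by one position; that downstream step is not part of this file. A minimal sketch of the collate_fn binding, with placeholder values standing in for args.seq_length, args.micro_batch_size and get_tokenizer().pad:

from functools import partial

# Placeholder values; in the diff these come from args and the tokenizer.
seq_length, micro_batch_size, pad_token = 2048, 4, 0
collate_fn = partial(pack_samples, max_seq_len=seq_length + 1,
                     micro_batch_size=micro_batch_size, pad_token=pad_token)
# collate_fn(list_of_items) returns the three (micro_batch_size, seq_length + 1) arrays.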

@@ -141,7 +228,7 @@ def __iter__(self):
 
         # data sharding and random sampling
         bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                      * self.micro_batch_size
+                       * self.micro_batch_size
         bucket_offset = current_epoch_samples // self.data_parallel_size
         start_idx = self.data_parallel_rank * bucket_size
 
@@ -158,3 +245,76 @@ def __iter__(self):
             self.consumed_samples += self.micro_batch_times_data_parallel_size
             yield batch
             batch = []
+
+
+class MegatronDecoderPackedText2TextRandomSampler(object):
+    """
+    Takes a two-stream dataset with `input_tokens` and `target_tokens` and builds batches that are greedily
+    packed before being passed on to the decoder model.
+
+    To be used with `pack_samples` as the collate_fn.
+    """
+
+    def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size):
+        # Keep a copy of input params for later use.
+        self.dataset = dataset
+        self.sequence_length = sequence_length
+        self.total_samples = total_samples
+        self.consumed_samples = consumed_samples
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.data_parallel_size = data_parallel_size
+        self.micro_batch_times_data_parallel_size = \
+            self.micro_batch_size * data_parallel_size
+        self.last_batch_size = \
+            self.total_samples % self.micro_batch_times_data_parallel_size
+
+        # Sanity checks.
+        assert self.total_samples > 0, \
+            'no sample to consume: {}'.format(self.total_samples)
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)
+
+    def __len__(self):
+        return self.total_samples
+
+    def __iter__(self):
+        active_total_samples = self.total_samples - self.last_batch_size
+        self.epoch = self.consumed_samples // active_total_samples
+        current_epoch_samples = self.consumed_samples % active_total_samples
+        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
+
+        # data sharding and random sampling
+        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
+                       * self.micro_batch_size
+        bucket_offset = current_epoch_samples // self.data_parallel_size
+        start_idx = self.data_parallel_rank * bucket_size
+
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+
+        random_idx = torch.randperm(bucket_size, generator=g).tolist()
+        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
+
+        batch = []
+        batch_count = 0
+        token_lens = 0
+        # Last batch if not complete will be dropped.
+        for idx in idx_range:
+            tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
+            if token_lens + tok_len > self.sequence_length:
+                batch_count += 1
+                token_lens = 0
+
+                if batch_count == self.micro_batch_size:
+                    self.consumed_samples += self.micro_batch_times_data_parallel_size
+                    yield batch
+                    batch_count = 0
+                    batch = []
+            else:
+                token_lens += tok_len
+                batch.append(idx)
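
A self-contained toy run of the new sampler together with pack_samples; the dataset class, sizes and token values below are invented solely to exercise the code path and are not part of the commit:

from functools import partial

import numpy as np
import torch

class ToyMTFDataset(torch.utils.data.Dataset):
    """Stand-in for MTFDataset: each item is a dict of input and target token arrays."""
    def __init__(self, n=8):
        self.items = [{'input_tokens': np.array([10 * i, 10 * i + 1]),
                       'target_tokens': np.array([10 * i + 2])} for i in range(n)]
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

dataset = ToyMTFDataset()
sampler = MegatronDecoderPackedText2TextRandomSampler(
    sequence_length=8, dataset=dataset, total_samples=len(dataset),
    consumed_samples=0, micro_batch_size=2,
    data_parallel_rank=0, data_parallel_size=1)
loader = torch.utils.data.DataLoader(
    dataset, batch_sampler=sampler,
    collate_fn=partial(pack_samples, max_seq_len=8, micro_batch_size=2, pad_token=0))
batch = next(iter(loader))
# Each value is an array of shape (2, 8); segment ids restart from 1 on every packed row.
print(batch['decoder_target_tokens'].shape)  # (2, 8)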
