@@ -15,11 +15,77 @@

 """Dataloaders."""

+from functools import partial
+
+import numpy as np
 import torch
-import random
-from megatron import get_args
+
+from megatron import get_args, get_tokenizer
 from megatron import mpu
+from megatron.data.mtf_dataset import MTFDataset
+
+
+def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
+    """
+    Greedily packs samples.
+
+    Items:
+        [
+            {
+                'input_tokens': array([6, 7]),
+                'target_tokens': array([8])
+            },
+            {
+                'input_tokens': array([3, 4]),
+                'target_tokens': array([5])
+            }
+        ]
+
+    Output:
+        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: the samples' tokens concatenated, followed by padding tokens.
+        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: segment ids mark which original document each token came from (`0` is reserved for padding).
+        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `1` marks input tokens, `0` marks target and padding tokens.
+    """
+
+    decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
+    decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
+    decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))
+
+    batch_num = 0
+    # `0` is reserved for padding
+    item_num = 1
+    cur_len = 0
+    for token_dict in items:
+        input_token_len = len(token_dict["input_tokens"])
+        target_token_len = len(token_dict["target_tokens"])
+        total_len = input_token_len + target_token_len
+        if cur_len + total_len > max_seq_len:
+            # The sample does not fit in the current row: pad out the rest of
+            # the row and move on to the next one.
+            len_diff = max_seq_len - cur_len
+            if len_diff > 0:
+                decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
+                decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
+                decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
+            batch_num += 1
+            assert batch_num < micro_batch_size
+            item_num = 1
+            cur_len = 0
+
+        decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
+        decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
+        decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
+        decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1  # input
+        decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0  # target
+
+        item_num += 1
+        cur_len += total_len
+        assert cur_len < max_seq_len
+
+    return {
+        "decoder_target_tokens": decoder_target_tokens,
+        "decoder_segment_ids": decoder_segment_ids,
+        "decoder_causal_attention": decoder_causal_attention,
+    }


 def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
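To make the packed layout concrete, here is a minimal sketch that runs the docstring's two samples through `pack_samples`. The `pad_token=0` value is an illustrative assumption; in `build_pretraining_data_loader` below, the real id comes from `get_tokenizer().pad`.

```python
import numpy as np

# The two samples from the pack_samples docstring.
items = [
    {"input_tokens": np.array([6, 7]), "target_tokens": np.array([8])},
    {"input_tokens": np.array([3, 4]), "target_tokens": np.array([5])},
]
out = pack_samples(items, max_seq_len=7, micro_batch_size=1, pad_token=0)
# out["decoder_target_tokens"]    -> [[6, 7, 8, 3, 4, 5, 0]]   (0 is the pad id here)
# out["decoder_segment_ids"]      -> [[1, 1, 1, 2, 2, 2, 0]]
# out["decoder_causal_attention"] -> [[1, 1, 0, 1, 1, 0, 0]]
```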
@@ -44,18 +110,39 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
             data_parallel_size=mpu.get_data_parallel_world_size())
+    elif args.dataloader_type == 'decoder_packed':
+        assert isinstance(dataset, MTFDataset)
+        batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
+            sequence_length=args.seq_length + 1,
+            dataset=dataset,
+            total_samples=len(dataset),
+            consumed_samples=consumed_samples,
+            micro_batch_size=args.micro_batch_size,
+            data_parallel_rank=mpu.get_data_parallel_rank(),
+            data_parallel_size=mpu.get_data_parallel_world_size())
     else:
         raise Exception('{} dataloader type is not supported.'.format(
-                 args.dataloader_type))
+            args.dataloader_type))

     if num_workers is None:
         num_workers = args.num_workers

+    collate_fn = None
+    if args.dataloader_type == 'decoder_packed':
+        assert isinstance(dataset, MTFDataset)
+        pad_token = get_tokenizer().pad
+        collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
+                             pad_token=pad_token)
+
     # Torch dataloader.
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        collate_fn=collate_fn,
+        pin_memory=True
+    )
+

 class MegatronPretrainingSampler:

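Both the sampler and the collate function receive `args.seq_length + 1`; presumably the packed row is later split into inputs and labels shifted by one position, which the extra token slot accommodates. A hypothetical sketch of what the `decoder_packed` loader yields (`mtf_dataset` is a stand-in for an `MTFDataset` instance):

```python
loader = build_pretraining_data_loader(mtf_dataset, consumed_samples=0)
batch = next(iter(loader))
# pack_samples returns a dict of numpy arrays, each of shape
# (micro_batch_size, seq_length + 1):
print(batch["decoder_target_tokens"].shape)
print(batch["decoder_segment_ids"].shape)
print(batch["decoder_causal_attention"].shape)
```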
@@ -141,7 +228,7 @@ def __iter__(self):

         # data sharding and random sampling
         bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                        * self.micro_batch_size
+                      * self.micro_batch_size
         bucket_offset = current_epoch_samples // self.data_parallel_size
         start_idx = self.data_parallel_rank * bucket_size

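A worked example of the sharding arithmetic above, with illustrative numbers:

```python
total_samples, micro_batch_size, data_parallel_size = 100, 4, 2   # toy values
micro_batch_times_dp = micro_batch_size * data_parallel_size      # 8
bucket_size = (total_samples // micro_batch_times_dp) * micro_batch_size
# bucket_size == (100 // 8) * 4 == 48 samples shuffled per data-parallel rank
start_idx_rank0 = 0 * bucket_size   # rank 0 draws from indices [0, 48)
start_idx_rank1 = 1 * bucket_size   # rank 1 draws from indices [48, 96)
# The trailing 100 - 2 * 48 = 4 samples are not visited in this epoch.
```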
@@ -158,3 +245,76 @@ def __iter__(self):
                 self.consumed_samples += self.micro_batch_times_data_parallel_size
                 yield batch
                 batch = []
+
+
+class MegatronDecoderPackedText2TextRandomSampler(object):
+    """
+    Samples indices from a two-stream dataset that yields `input_tokens` and
+    `target_tokens`, emitting micro-batches of indices whose samples can be
+    greedily packed into one decoder input sequence.
+
+    To be used with `pack_samples` as collate_fn.
+    """
+
+    def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size):
+        # Keep a copy of input params for later use.
+        self.dataset = dataset
+        self.sequence_length = sequence_length
+        self.total_samples = total_samples
+        self.consumed_samples = consumed_samples
+        self.micro_batch_size = micro_batch_size
+        self.data_parallel_rank = data_parallel_rank
+        self.data_parallel_size = data_parallel_size
+        self.micro_batch_times_data_parallel_size = \
+            self.micro_batch_size * data_parallel_size
+        self.last_batch_size = \
+            self.total_samples % self.micro_batch_times_data_parallel_size
+
+        # Sanity checks.
+        assert self.total_samples > 0, \
+            'no sample to consume: {}'.format(self.total_samples)
+        assert self.micro_batch_size > 0
+        assert data_parallel_size > 0
+        assert self.data_parallel_rank < data_parallel_size, \
+            'data_parallel_rank should be smaller than data parallel size: {}, ' \
+            '{}'.format(self.data_parallel_rank, data_parallel_size)
+    def __len__(self):
+        return self.total_samples
+
+    def __iter__(self):
+        active_total_samples = self.total_samples - self.last_batch_size
+        self.epoch = self.consumed_samples // active_total_samples
+        current_epoch_samples = self.consumed_samples % active_total_samples
+        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
+
+        # data sharding and random sampling
+        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
+                      * self.micro_batch_size
+        bucket_offset = current_epoch_samples // self.data_parallel_size
+        start_idx = self.data_parallel_rank * bucket_size
+
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+
+        random_idx = torch.randperm(bucket_size, generator=g).tolist()
+        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
+
+        batch = []
+        batch_count = 0
+        token_lens = 0
+        # Last batch if not complete will be dropped.
+        for idx in idx_range:
+            tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
+            if token_lens + tok_len > self.sequence_length:
+                # Current sample would overflow the row: start a new row.
+                batch_count += 1
+                token_lens = 0
+
+            if batch_count == self.micro_batch_size:
+                self.consumed_samples += self.micro_batch_times_data_parallel_size
+                yield batch
+                # Note: the sample that triggered the yield is itself dropped,
+                # not carried over into the next micro-batch.
+                batch_count = 0
+                batch = []
+            else:
+                token_lens += tok_len
+                batch.append(idx)
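A hedged sketch of driving the sampler on its own (`mtf_dataset` and `seq_length` are placeholders; in practice `build_pretraining_data_loader` wires the sampler and `pack_samples` together):

```python
sampler = MegatronDecoderPackedText2TextRandomSampler(
    sequence_length=seq_length + 1,   # must match pack_samples' max_seq_len
    dataset=mtf_dataset,              # items expose 'input_tokens'/'target_tokens'
    total_samples=len(mtf_dataset),
    consumed_samples=0,
    micro_batch_size=4,
    data_parallel_rank=0,
    data_parallel_size=1,
)
for indices in sampler:
    # One micro-batch of dataset indices; the DataLoader maps them through
    # dataset[idx] and hands the resulting list to pack_samples via collate_fn.
    ...
```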