
Commit 15693c7

Speed up for hybrid parallel (PaddlePaddle#1056)
1 parent 64b6f0f commit 15693c7

File tree

3 files changed: +24, -22 lines

examples/language_model/gpt-3/dygraph/dataset.py

Lines changed: 4 additions & 4 deletions
@@ -286,10 +286,10 @@ def build_dataset(index, name, num_samples):
         places=places,
         feed_list=data_holders,
         batch_sampler=batch_sampler,
-        num_workers=0,
+        num_workers=1,
         worker_init_fn=worker_init,
         # collate_fn=Tuple(Stack(), Stack(), Stack(), Stack(), Stack()),
-        collate_fn=Tuple(Stack(), Stack(), Stack()),
+        collate_fn=Tuple(Stack(), Stack(), Stack(), Stack()),
         return_list=False)
     return data_loader

@@ -349,12 +349,12 @@ def _construct_sample(self, tokens):
         # The pad and eos tokens do not contribute the loss
         loss_mask = np.ones(seq_length, dtype="float32")
         loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0
-        # position_ids = np.arange(0, seq_length, dtype="int64")
+        position_ids = np.arange(0, seq_length, dtype="int64")

         # attention_mask = (attention_mask - 1.0) * 1e9
         # attention_mask = attention_mask.astype("float32")
         # return [tokens, loss_mask, attention_mask, position_ids, labels]
-        return [tokens, loss_mask, labels]
+        return [tokens, loss_mask, position_ids, labels]

     def _get_single_sample_from_idx(self, doc_index_f, doc_index_l, offset_f,
                                     offset_l):
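Taken together, the two hunks make each sample carry an explicit position_ids array and give the collate_fn a matching fourth Stack(). A minimal standalone sketch of the updated sample construction, with the function name and free-standing arguments as stand-ins for the method's self.* attributes:

import numpy as np

def construct_sample(tokens, labels, seq_length, eos_id):
    # Pad/eos tokens do not contribute to the loss, as in the original code.
    loss_mask = np.ones(seq_length, dtype="float32")
    loss_mask[np.where(np.array(tokens) == eos_id)] = 0.0
    # Newly enabled: explicit position ids, stacked by the extra Stack() in collate_fn.
    position_ids = np.arange(0, seq_length, dtype="int64")
    return [tokens, loss_mask, position_ids, labels]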

examples/language_model/gpt-3/dygraph/modeling.py

Lines changed: 13 additions & 12 deletions
@@ -497,8 +497,8 @@ def forward(self, input_ids, position_ids=None):
         position_embeddings = self.position_embeddings(position_ids)
         embeddings = input_embedings + position_embeddings

-        with get_rng_state_tracker().rng_state('global_seed'):
-            embeddings = self.dropout(embeddings)
+        #with get_rng_state_tracker().rng_state('global_seed'):
+        embeddings = self.dropout(embeddings)

         return embeddings

@@ -754,23 +754,24 @@ def forward(self,
             input_ids=input_ids, position_ids=position_ids)

         # TODO, use registered buffer
-        causal_mask = paddle.tensor.triu(
-            paddle.ones((paddle.shape(input_ids)[-1],
-                         paddle.shape(input_ids)[-1])) * -1e9,
-            diagonal=1)
+        # causal_mask = paddle.tensor.triu(
+        #     paddle.ones((paddle.shape(input_ids)[-1],
+        #                  paddle.shape(input_ids)[-1])) * -1e9,
+        #     diagonal=1)

-        if attention_mask is not None:
-            attention_mask = attention_mask + causal_mask
-        else:
-            attention_mask = causal_mask
+        # if attention_mask is not None:
+        #     attention_mask = attention_mask + causal_mask
+        # else:
+        #     attention_mask = causal_mask

         # The tensor returned by triu not in static graph.
-        attention_mask.stop_gradient = True
+        # attention_mask.stop_gradient = True

         encoder_outputs = self.decoder(
             embedding_output,
             memory=None,
-            tgt_mask=attention_mask,
+            # tgt_mask=attention_mask,
+            tgt_mask=None,
             use_cache=use_cache,
             cache=cache)
         self.checkpoints.extend(self.decoder.checkpoints)
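The decoder call now receives tgt_mask=None, so the upper-triangular mask is no longer rebuilt in Python on every forward pass. For reference, a standalone sketch of what the removed lines computed, with seq_len standing in for paddle.shape(input_ids)[-1]; how causal masking is handled after this change is outside the scope of this diff:

import paddle

def build_causal_mask(seq_len):
    # -1e9 above the diagonal, 0 on and below it: added to attention scores,
    # this suppresses attention to future positions.
    return paddle.tensor.triu(
        paddle.ones((seq_len, seq_len)) * -1e9, diagonal=1)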

examples/language_model/gpt-3/dygraph/run_pretrain.py

Lines changed: 7 additions & 6 deletions
@@ -66,12 +66,12 @@ def run_evaluate(args,
     all_loss = []
     local_time = time.time()
     for eval_step, batch in enumerate(data_loader):
-        tokens, loss_mask, labels = batch
+        tokens, loss_mask, position_ids, labels = batch
         if args.pp_degree < 2:
-            preds = model(tokens)
+            preds = model(tokens, position_ids)
             loss = criterion(preds, labels, loss_mask)
         else:
-            data = [tokens, (labels, loss_mask)]
+            data = [(tokens, position_ids), (labels, loss_mask)]
             loss = model.eval_batch(data, compute_loss=True)

         all_loss.append(float(loss))

@@ -237,10 +237,11 @@ def do_train(args):

         for step, batch in enumerate(train_data_loader()):
             global_step += 1
-            tokens, loss_mask, labels = batch
+            tokens, loss_mask, position_ids, labels = batch

             loss_mask.stop_gradient = True
             labels.stop_gradient = True
+            position_ids.stop_gradient = True

             if args.pp_degree == 1:
                 with paddle.amp.auto_cast(

@@ -252,7 +253,7 @@ def do_train(args):
                         "reduce_sum", "c_softmax_with_cross_entropy",
                         "c_embedding"
                     ]):
-                    preds = model(tokens)
+                    preds = model(tokens, position_ids)
                     loss = criterion(preds, labels, loss_mask)

             if args.use_amp:

@@ -267,7 +268,7 @@ def do_train(args):
                 optimizer.clear_grad()

             else:
-                data = [tokens, (labels, loss_mask)]
+                data = [(tokens, position_ids), (labels, loss_mask)]
                 with paddle.amp.auto_cast(
                     args.use_amp,
                     custom_white_list=[
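Both the evaluation and training loops now unpack four tensors per batch and route position_ids alongside tokens. A minimal sketch of the shared pattern, mirroring the evaluation path above; the helper name is hypothetical and model, criterion, and pp_degree are assumed to behave as in the surrounding code:

def run_step(model, criterion, batch, pp_degree):
    # Batches now carry four tensors; position_ids rides along with tokens.
    tokens, loss_mask, position_ids, labels = batch
    if pp_degree < 2:
        # Non-pipeline path: positions are passed explicitly to the model.
        preds = model(tokens, position_ids)
        return criterion(preds, labels, loss_mask)
    # Pipeline path: inputs and targets are grouped into two tuples.
    data = [(tokens, position_ids), (labels, loss_mask)]
    return model.eval_batch(data, compute_loss=True)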
