Skip to content

Commit a827235

Browse files
ZHUI and ZeyuChen authored
opt electra training for amp. (PaddlePaddle#1018)
Co-authored-by: Zeyu Chen <[email protected]>
1 parent 45967dc commit a827235

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

examples/language_model/electra/run_pretrain.py

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -424,14 +424,23 @@ def do_train(args):
424424
# Loads or initializes a model.
425425
pretrained_models = list(tokenizer_class.pretrained_init_configuration.keys(
426426
))
427+
428+
def get_opt_config(model_cls, name):
    """Return a copy of a pretrained init configuration tuned for AMP.

    Looks up ``model_cls.pretrained_init_configuration[name]`` and, when the
    configuration carries a ``vocab_size`` that is not a multiple of 8,
    rounds it up to the next multiple of 8 (the AMP optimization this helper
    exists for).

    The lookup result is copied before it is modified, so the class-level
    ``pretrained_init_configuration`` table is never changed by this call.

    Args:
        model_cls: Object exposing a ``pretrained_init_configuration``
            mapping from configuration name to a dict of init arguments.
        name: Key of the configuration to fetch.

    Returns:
        A new dict with the same entries as the stored configuration, with
        ``vocab_size`` padded to a multiple of 8 when present.

    Raises:
        KeyError: If ``name`` is not in ``pretrained_init_configuration``.
    """
    # Shallow copy: padding in place would rewrite the shared class-level
    # table and leak the padded value to every later reader of it.
    config = dict(model_cls.pretrained_init_configuration[name])

    # Optimize for AMP: round the vocabulary size up to a multiple of 8.
    vocab_size = config.get("vocab_size")
    if vocab_size is not None and vocab_size % 8 != 0:
        config["vocab_size"] = vocab_size + 8 - vocab_size % 8
    return config
435+
427436
if args.model_name_or_path in pretrained_models:
428437
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
429438
generator = ElectraGenerator(
430-
ElectraModel(**model_class.pretrained_init_configuration[
431-
args.model_name_or_path + "-generator"]))
439+
ElectraModel(**get_opt_config(model_class, args.model_name_or_path +
440+
"-generator")))
432441
discriminator = ElectraDiscriminator(
433-
ElectraModel(**model_class.pretrained_init_configuration[
434-
args.model_name_or_path + "-discriminator"]))
442+
ElectraModel(**get_opt_config(model_class, args.model_name_or_path +
443+
"-discriminator")))
435444
model = model_class(generator, discriminator)
436445
args.init_from_ckpt = False
437446
else:
@@ -445,11 +454,11 @@ def do_train(args):
445454
model_name = config_dict["model_name"]
446455
if model_name in pretrained_models:
447456
generator = ElectraGenerator(
448-
ElectraModel(**model_class.pretrained_init_configuration[
449-
model_name + "-generator"]))
457+
ElectraModel(**get_opt_config(model_class, model_name +
458+
"-generator")))
450459
discriminator = ElectraDiscriminator(
451-
ElectraModel(**model_class.pretrained_init_configuration[
452-
model_name + "-discriminator"]))
460+
ElectraModel(**get_opt_config(model_class, model_name +
461+
"-discriminator")))
453462
model = model_class(generator, discriminator)
454463
model.set_state_dict(
455464
paddle.load(

0 commit comments

Comments
 (0)