@@ -424,14 +424,23 @@ def do_train(args):
     # Loads or initializes a model.
     pretrained_models = list(tokenizer_class.pretrained_init_configuration.keys(
     ))
+
+    def get_opt_config(model_cls, name):
+        config = model_cls.pretrained_init_configuration[name]
+        # Optimize for AMP.
+        if "vocab_size" in config:
+            if config["vocab_size"] % 8 != 0:
+                config["vocab_size"] += 8 - (config["vocab_size"] % 8)
+        return config
+
     if args.model_name_or_path in pretrained_models:
         tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
         generator = ElectraGenerator(
-            ElectraModel(**model_class.pretrained_init_configuration[
-                args.model_name_or_path + "-generator"]))
+            ElectraModel(**get_opt_config(model_class, args.model_name_or_path +
+                                          "-generator")))
         discriminator = ElectraDiscriminator(
-            ElectraModel(**model_class.pretrained_init_configuration[
-                args.model_name_or_path + "-discriminator"]))
+            ElectraModel(**get_opt_config(model_class, args.model_name_or_path +
+                                          "-discriminator")))
         model = model_class(generator, discriminator)
         args.init_from_ckpt = False
     else:
@@ -445,11 +454,11 @@ def do_train(args):
         model_name = config_dict["model_name"]
         if model_name in pretrained_models:
             generator = ElectraGenerator(
-                ElectraModel(**model_class.pretrained_init_configuration[
-                    model_name + "-generator"]))
+                ElectraModel(**get_opt_config(model_class, model_name +
+                                              "-generator")))
             discriminator = ElectraDiscriminator(
-                ElectraModel(**model_class.pretrained_init_configuration[
-                    model_name + "-discriminator"]))
+                ElectraModel(**get_opt_config(model_class, model_name +
+                                              "-discriminator")))
             model = model_class(generator, discriminator)
             model.set_state_dict(
                 paddle.load(
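The get_opt_config helper added above rounds vocab_size up to the next multiple of 8: mixed-precision (AMP) GEMM kernels take their fast Tensor Core paths when tensor dimensions are multiples of 8, and the vocab dimension feeds both the embedding table and the output projection. A minimal standalone sketch of the same rounding rule (the pad_to_multiple name and the example values are illustrative, not part of the patch):

    def pad_to_multiple(n, multiple=8):
        # Round n up to the nearest multiple; no-op if already aligned.
        remainder = n % multiple
        return n if remainder == 0 else n + multiple - remainder

    # A 30522-entry vocab (the BERT/ELECTRA English vocabulary) pads to 30528.
    assert pad_to_multiple(30522) == 30528
    assert pad_to_multiple(30528) == 30528

The padding only enlarges the embedding matrix; the extra rows are never indexed by real token ids, so model outputs are unaffected.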