@@ -402,3 +402,201 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
402
402
])
403
403
404
404
return config
405
+
406
+
407
@exp_factory.register_config_factory('deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer.

  Builds a DeiT-style pretraining experiment for a ViT-B/16 backbone:
  RandAugment + Mixup/CutMix data augmentation, AdamW with cosine decay
  and linear warmup, soft labels with label smoothing.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  # Originally 1024 in the recipe; 4096 keeps a TPU v3-32 better utilized.
  train_batch_size = 4096
  eval_batch_size = 4096
  label_smoothing = 0.1
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  total_steps = 300 * steps_per_epoch  # 300 epochs of training.

  backbone = backbones.Backbone(
      type='vit',
      vit=backbones.VisionTransformer(
          model_name='vit-b16',
          representation_size=768,
          init_stochastic_depth_rate=0.1,
          original_init=False,
          transformer=backbones.Transformer(
              dropout_rate=0.0, attention_dropout_rate=0.0)))

  task = ImageClassificationTask(
      model=ImageClassificationModel(
          num_classes=1001,
          input_size=[224, 224, 3],
          kernel_initializer='zeros',
          backbone=backbone),
      losses=Losses(
          l2_weight_decay=0.0,
          label_smoothing=label_smoothing,
          one_hot=False,
          soft_labels=True),
      train_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size,
          aug_type=common.Augmentation(
              type='randaug',
              randaug=common.RandAugment(
                  magnitude=9, exclude_ops=['Cutout'])),
          mixup_and_cutmix=common.MixupAndCutmix(
              label_smoothing=label_smoothing)),
      validation_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size))

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_steps,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': {
              'type': 'adamw',
              'adamw': {
                  'weight_decay_rate': 0.05,
                  # Decouple decay from biases/norm params; match only
                  # kernel/weight variables.
                  'include_in_weight_decay': r'.*(kernel|weight):0$',
                  'gradient_clip_norm': 0.0
              }
          },
          'learning_rate': {
              'type': 'cosine',
              'cosine': {
                  # Linear LR scaling: 5e-4 per 512 examples of batch.
                  'initial_learning_rate': 0.0005 * train_batch_size / 512,
                  'decay_steps': total_steps,
              }
          },
          'warmup': {
              'type': 'linear',
              'linear': {
                  'warmup_steps': 5 * steps_per_epoch,
                  'warmup_learning_rate': 0
              }
          }
      }))

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
485
+
486
+
487
@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer.

  Builds the plain ViT-B/16 pretraining experiment: AdamW with cosine
  decay and a fixed 10k-step linear warmup, no extra augmentation.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  total_steps = 300 * steps_per_epoch  # 300 epochs of training.

  backbone = backbones.Backbone(
      type='vit',
      vit=backbones.VisionTransformer(
          model_name='vit-b16', representation_size=768))

  task = ImageClassificationTask(
      model=ImageClassificationModel(
          num_classes=1001,
          input_size=[224, 224, 3],
          kernel_initializer='zeros',
          backbone=backbone),
      losses=Losses(l2_weight_decay=0.0),
      train_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size))

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_steps,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': {
              'type': 'adamw',
              'adamw': {
                  'weight_decay_rate': 0.3,
                  # Decouple decay from biases/norm params; match only
                  # kernel/weight variables.
                  'include_in_weight_decay': r'.*(kernel|weight):0$',
                  'gradient_clip_norm': 0.0
              }
          },
          'learning_rate': {
              'type': 'cosine',
              'cosine': {
                  # Linear LR scaling: 3e-3 at the reference batch of 4096.
                  'initial_learning_rate': 0.003 * train_batch_size / 4096,
                  'decay_steps': total_steps,
              }
          },
          'warmup': {
              'type': 'linear',
              'linear': {
                  'warmup_steps': 10000,
                  'warmup_learning_rate': 0
              }
          }
      }))

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
549
+
550
+
551
@exp_factory.register_config_factory('vit_imagenet_finetune')
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer.

  Builds the ViT-B/16 fine-tuning experiment at 384x384 resolution:
  SGD with momentum, global gradient-norm clipping, cosine decay over a
  fixed 20k-step schedule.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  train_batch_size = 512
  eval_batch_size = 512
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  finetune_steps = 20000  # Fixed-length schedule, not epoch-derived.

  backbone = backbones.Backbone(
      type='vit',
      vit=backbones.VisionTransformer(model_name='vit-b16'))

  task = ImageClassificationTask(
      model=ImageClassificationModel(
          num_classes=1001,
          # Fine-tuning uses the higher 384x384 resolution.
          input_size=[384, 384, 3],
          backbone=backbone),
      losses=Losses(l2_weight_decay=0.0),
      train_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size))

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=finetune_steps,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': {
              'type': 'sgd',
              'sgd': {
                  'momentum': 0.9,
                  'global_clipnorm': 1.0,
              }
          },
          'learning_rate': {
              'type': 'cosine',
              'cosine': {
                  'initial_learning_rate': 0.003,
                  'decay_steps': finetune_steps,
              }
          }
      }))

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
0 commit comments