1515
1616accelerator = Accelerator ()
1717
18+
1819def train (config , inpainting_network , kp_detector , bg_predictor , dense_motion_network , checkpoint , log_dir , dataset ,
1920 optimizer_class = torch .optim .Adam
2021 ):
@@ -44,23 +45,35 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
4445 else :
4546 start_epoch = 0
4647
47-
48-
48+ freeze_kp_detector = train_params .get ('freeze_kp_detector' , False )
49+ freeze_bg_predictor = train_params .get ('freeze_bg_predictor' , False )
50+ if freeze_kp_detector :
51+ print ('freeze kp detector' )
52+ kp_detector .eval ()
53+ for param in kp_detector .parameters ():
54+ param .requires_grad = False
55+ if freeze_bg_predictor :
56+ print ('freeze bg predictor' )
57+ bg_predictor .eval ()
58+ for param in bg_predictor .parameters ():
59+ param .requires_grad = False
4960
5061 if 'num_repeats' in train_params or train_params ['num_repeats' ] != 1 :
5162 dataset = DatasetRepeater (dataset , train_params ['num_repeats' ])
5263 dataloader = DataLoader (dataset , batch_size = train_params ['batch_size' ], shuffle = True ,
5364 num_workers = train_params ['dataloader_workers' ], drop_last = True )
5465
5566 scheduler_optimizer = OneCycleLR (optimizer , max_lr = train_params ['lr_generator' ],
56- total_steps = (len (dataset ) // train_params ['batch_size' ]) * train_params ['num_epochs' ],
57- last_epoch = start_epoch - 1 )
67+ total_steps = (len (dataset ) // train_params ['batch_size' ]) * train_params [
68+ 'num_epochs' ],
69+ last_epoch = start_epoch - 1 )
5870
5971 scheduler_bg_predictor = None
6072 if bg_predictor :
6173 scheduler_bg_predictor = OneCycleLR (optimizer_bg_predictor , max_lr = train_params ['lr_generator' ],
62- total_steps = (len (dataset ) // train_params ['batch_size' ]) * train_params ['num_epochs' ],
63- last_epoch = start_epoch - 1 )
74+ total_steps = (len (dataset ) // train_params ['batch_size' ]) * train_params [
75+ 'num_epochs' ],
76+ last_epoch = start_epoch - 1 )
6477 bg_predictor , optimizer_bg_predictor = accelerator .prepare (bg_predictor , optimizer_bg_predictor )
6578
6679 generator_full = GeneratorFullModel (kp_detector , bg_predictor , dense_motion_network , inpainting_network ,
@@ -75,16 +88,21 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
7588 if train_params .get ('visualize_model' , False ):
7689 # visualize graph
7790 sample = next (iter (dataloader ))
78- draw_graph (generator_full , input_data = [sample , 100 ], save_graph = True , directory = log_dir , graph_name = 'generator_full' )
79- draw_graph (kp_detector , input_data = [sample ['driving' ]], save_graph = True , directory = log_dir , graph_name = 'kp_detector' )
91+ draw_graph (generator_full , input_data = [sample , 100 ], save_graph = True , directory = log_dir ,
92+ graph_name = 'generator_full' )
93+ draw_graph (kp_detector , input_data = [sample ['driving' ]], save_graph = True , directory = log_dir ,
94+ graph_name = 'kp_detector' )
8095 kp_driving = kp_detector (sample ['driving' ])
8196 kp_source = kp_detector (sample ['source' ])
8297 bg_param = bg_predictor (sample ['source' ], sample ['driving' ])
83- dense_motion_param = {'source_image' : sample ['source' ], 'kp_driving' : kp_driving , 'kp_source' : kp_source , 'bg_param' : bg_param ,
84- 'dropout_flag' : False , 'dropout_p' : 0.0 }
98+ dense_motion_param = {'source_image' : sample ['source' ], 'kp_driving' : kp_driving , 'kp_source' : kp_source ,
99+ 'bg_param' : bg_param ,
100+ 'dropout_flag' : False , 'dropout_p' : 0.0 }
85101 dense_motion = dense_motion_network (** dense_motion_param )
86- draw_graph (dense_motion_network , input_data = dense_motion_param , save_graph = True , directory = log_dir , graph_name = 'dense_motion_network' )
87- draw_graph (inpainting_network , input_data = [sample ['source' ], dense_motion ], save_graph = True , directory = log_dir , graph_name = 'inpainting_network' )
102+ draw_graph (dense_motion_network , input_data = dense_motion_param , save_graph = True , directory = log_dir ,
103+ graph_name = 'dense_motion_network' )
104+ draw_graph (inpainting_network , input_data = [sample ['source' ], dense_motion ], save_graph = True , directory = log_dir ,
105+ graph_name = 'inpainting_network' )
88106
89107 with Logger (log_dir = log_dir , visualizer_params = config ['visualizer_params' ],
90108 checkpoint_freq = train_params ['checkpoint_freq' ],
@@ -100,14 +118,18 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
100118
101119 clip_grad_norm_ (kp_detector .parameters (), max_norm = 10 , norm_type = math .inf )
102120 clip_grad_norm_ (dense_motion_network .parameters (), max_norm = 10 , norm_type = math .inf )
103- if bg_predictor and epoch >= bg_start :
121+ if bg_predictor and epoch >= bg_start and not freeze_bg_predictor :
104122 clip_grad_norm_ (bg_predictor .parameters (), max_norm = 10 , norm_type = math .inf )
105123
106124 optimizer .step ()
107- optimizer . zero_grad ()
108- if bg_predictor and epoch >= bg_start :
125+
126+ if bg_predictor and epoch >= bg_start and not freeze_bg_predictor :
109127 optimizer_bg_predictor .step ()
110128 optimizer_bg_predictor .zero_grad ()
129+ scheduler_bg_predictor .step ()
130+
131+ optimizer .zero_grad ()
132+ scheduler_optimizer .step ()
111133
112134 losses = {key : value .mean ().detach ().data .cpu ().numpy () for key , value in losses_generator .items ()}
113135 lrs = {
@@ -116,23 +138,15 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
116138 }
117139 logger .log_iter (losses = losses , others = lrs )
118140
119- scheduler_optimizer .step ()
120- if bg_predictor :
121- scheduler_bg_predictor .step ()
141+
122142
123143 model_save = {
124- 'inpainting_network' : inpainting_network ,
125- 'dense_motion_network' : dense_motion_network ,
126- 'kp_detector' : kp_detector ,
144+ 'inpainting_network' : accelerator . unwrap_model ( inpainting_network ) ,
145+ 'dense_motion_network' : accelerator . unwrap_model ( dense_motion_network ) ,
146+ 'kp_detector' : accelerator . unwrap_model ( kp_detector ) ,
127147 'optimizer' : optimizer ,
148+ 'bg_predictor' : accelerator .unwrap_model (bg_predictor ) if bg_predictor else None ,
149+ 'optimizer_bg_predictor' : optimizer_bg_predictor
128150 }
129- if bg_predictor and epoch >= bg_start :
130- model_save ['bg_predictor' ] = bg_predictor
131- model_save ['optimizer_bg_predictor' ] = optimizer_bg_predictor
132-
133- accelerator .save_state (log_dir )
134-
135151
136152 logger .log_epoch (epoch , model_save , inp = x , out = generated )
137-
138-
0 commit comments