|
18 | 18 | from train_avd import train_avd |
19 | 19 | from reconstruction import reconstruction |
20 | 20 | import os |
| 21 | +from torchinfo import summary |
21 | 22 | import bitsandbytes as bnb |
22 | 23 |
|
23 | 24 | optimizer_choices = { |
|
37 | 38 | parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"]) |
38 | 39 | parser.add_argument("--log_dir", default='log', help="path to log into") |
39 | 40 | parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") |
40 | | - parser.add_argument("--optimizer_class", default="adam", choices=optimizer_choices.keys()) |
| 41 | + parser.add_argument("--detect_anomaly", action="store_true", help="detect anomaly in autograd") |
41 | 42 |
|
42 | 43 |
|
43 | 44 | opt = parser.parse_args() |
|
50 | 51 | log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) |
51 | 52 | log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime()) |
52 | 53 |
|
| 54 | + if opt.detect_anomaly: |
| 55 | + torch.autograd.set_detect_anomaly(True) |
| 56 | + |
53 | 57 | inpainting = InpaintingNetwork(**config['model_params']['generator_params'], |
54 | 58 | **config['model_params']['common_params']) |
55 | 59 |
|
|
76 | 80 | if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): |
77 | 81 | copy(opt.config, log_dir) |
78 | 82 |
|
79 | | - optimizer_class = optimizer_choices[opt.optimizer_class] |
| 83 | + optimizer_class = optimizer_choices[config['train_params']['optimizer']] |
| 84 | + |
| 85 | + print("Inpainting Network:") |
| 86 | + summary(inpainting) |
| 87 | + print("Keypoint Detector:") |
| 88 | + summary(kp_detector) |
| 89 | + print("Dense Motion Network:") |
| 90 | + summary(dense_motion_network) |
| 91 | + if bg_predictor is not None: |
| 92 | + print("Background Predictor:") |
| 93 | + summary(bg_predictor) |
80 | 94 |
|
81 | 95 | if opt.mode == 'train': |
82 | 96 | print("Training...") |
|
90 | 104 | print("Reconstruction...") |
91 | 105 | #TODO: update to accelerate |
92 | 106 | reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset) |
| 107 | + |
0 commit comments