@@ -1,27 +1,33 @@
 import torch
 
 from network import AvatarNet
-from utils import imload, imsave
+from utils import imload, imsave, maskload
 
 
 def network_test(args):
     # set device
-    device = torch.device("cuda" if args.cuda_device_no >= 0 else "cpu")
+    device = torch.device('cuda' if args.gpu_no >= 0 else 'cpu')
+
+    # load check point
+    check_point = torch.load(args.check_point)
 
     # load network
     network = AvatarNet(args.layers)
-    network.load_state_dict(torch.load(args.model_load_path))
+    network.load_state_dict(check_point['state_dict'])
     network = network.to(device)
 
     # load target images
-    content_image = imload(args.test_content_image_path, args.imsize, args.cropsize)
-    style_image = imload(args.test_style_image_path, args.imsize, args.cropsize)
-    content_image, style_image = content_image.to(device), style_image.to(device)
-
+    content_img = imload(args.content, args.imsize, args.cropsize).to(device)
+    style_imgs = [imload(style, args.imsize, args.cropsize, args.cencrop).to(device) for style in args.style]
+    masks = None
+    if args.mask:
+        masks = [maskload(mask).to(device) for mask in args.mask]
+
     # stylize image
     with torch.no_grad():
-        output_image = network(content_image, style_image, args.train_flag, args.style_strength, args.patch_size, args.patch_stride)
+        stylized_img = network(content_img, style_imgs, args.style_strength, args.patch_size, args.patch_stride,
+                               masks, args.interpolation_weights, args.preserve_color, False)
 
-    imsave(output_image.data, args.output_image_path)
-
-    return output_image
+    imsave(stylized_img, 'stylized_image.jpg')
+
+    return None
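For context, the sketch below shows one way the updated network_test could be driven from the command line. The argument names (gpu_no, check_point, content, style, mask, imsize, cropsize, cencrop, layers, style_strength, patch_size, patch_stride, interpolation_weights, preserve_color) are the ones the diff above reads from args; the argparse wiring, the default values, and the module name are assumptions for illustration, not part of this commit.

# Hypothetical driver for network_test. Argument names mirror those used in the
# diff above; defaults, flag spellings, and the module name are assumptions.
import argparse

from network_test import network_test  # assumed module/file name for this script


def build_parser():
    parser = argparse.ArgumentParser(description='AvatarNet stylization test (sketch)')
    parser.add_argument('--gpu-no', type=int, default=0, help='GPU index; a negative value falls back to CPU')
    parser.add_argument('--check-point', type=str, required=True, help='path to a trained checkpoint')
    parser.add_argument('--content', type=str, required=True, help='content image path')
    parser.add_argument('--style', type=str, nargs='+', required=True, help='one or more style image paths')
    parser.add_argument('--mask', type=str, nargs='+', default=None, help='optional masks, one per style image')
    parser.add_argument('--imsize', type=int, default=512)
    parser.add_argument('--cropsize', type=int, default=None)
    parser.add_argument('--cencrop', action='store_true', help='center-crop instead of random crop')
    parser.add_argument('--layers', type=int, nargs='+', default=[1, 6, 11, 20], help='encoder layers used by AvatarNet')
    parser.add_argument('--style-strength', type=float, default=1.0)
    parser.add_argument('--patch-size', type=int, default=3)
    parser.add_argument('--patch-stride', type=int, default=1)
    parser.add_argument('--interpolation-weights', type=float, nargs='+', default=None,
                        help='per-style weights when blending multiple styles')
    parser.add_argument('--preserve-color', action='store_true')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    network_test(args)

When several --style paths are given, --interpolation-weights and --mask correspond to the new style_imgs, interpolation_weights, and masks values consumed inside network_test, so styles can be blended globally or applied to separate regions of the content image.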