 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 import torchvision
 import torchvision.models as models

+from wct import batch_wct
 from style_decorator import StyleDecorator

 class AvatarNet(nn.Module):
-    def __init__(self, layers):
+    def __init__(self, layers=[1, 6, 11, 20]):
         super(AvatarNet, self).__init__()
-        vgg = models.vgg19(pretrained=True).features
-
-        # get network layers
-        self.encoders = get_encoder(vgg, layers)
-        self.decoders = get_decoder(vgg, layers)
+        self.encoder = Encoder(layers)
+        self.decoder = Decoder(layers)

         self.adain = AdaIN()
         self.decorator = StyleDecorator()

-    def forward(self, c, s, train_flag=False, style_strength=1.0, patch_size=3, patch_stride=1):
+    def forward(self, content, styles, style_strength=1.0, patch_size=3, patch_stride=1, masks=None, interpolation_weights=None, preserve_color=False, train=False):
+        if interpolation_weights is None:
+            interpolation_weights = [1/len(styles)] * len(styles)
+        if masks is None:
+            masks = [1] * len(styles)

         # encode content image
-        for encoder in self.encoders:
-            c = encoder(c)
-
-        # encode style image and extract multi-scale feature for AdaIN in decoder network
-        features = []
-        for encoder in self.encoders:
-            s = encoder(s)
-            features.append(s)
-
-        # delete last style feature
-        del features[-1]
-
-        if not train_flag:
-            c = self.decorator(c, s, style_strength, patch_size, patch_stride)
-
-        for decoder in self.decoders:
-            c = decoder(c)
-            if features:
-                c = self.adain(c, features.pop())
+        content_feature = self.encoder(content)
+        style_features = []
+        for style in styles:
+            style_features.append(self.encoder(style))
+
+        if not train:
+            transformed_feature = []
+            for style_feature, interpolation_weight, mask in zip(style_features, interpolation_weights, masks):
+                if isinstance(mask, torch.Tensor):
+                    b, c, h, w = content_feature[-1].size()
+                    mask = F.interpolate(mask, size=(h, w))
+                transformed_feature.append(self.decorator(content_feature[-1], style_feature[-1], style_strength, patch_size, patch_stride) * interpolation_weight * mask)
+            transformed_feature = sum(transformed_feature)
+
+        else:
+            transformed_feature = content_feature[-1]
+
+        # re-order the style features (deepest feature dropped, order reversed) to match the decoder stages
+        style_features = [style_feature[:-1][::-1] for style_feature in style_features]
+
+        stylized_image = self.decoder(transformed_feature, style_features, masks, interpolation_weights)

-        return c
+        return stylized_image

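For reference, a minimal sketch of the new multi-style inference interface. This is a sketch only: the 256x256 sizes, the 0.7/0.3 weights, and the commented checkpoint path are illustrative, and StyleDecorator is assumed to behave as exposed by the repo's style_decorator module.

model = AvatarNet().eval()
# model.load_state_dict(torch.load('check_point.pth'))   # hypothetical checkpoint path
content = torch.rand(1, 3, 256, 256)
styles = [torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256)]
with torch.no_grad():
    stylized = model(content, styles, style_strength=0.8,
                     interpolation_weights=[0.7, 0.3])    # omit to weight all styles equally

Omitting masks applies every style over the whole image; passing a 1x1xHxW mask tensor per style restricts it spatially, since each mask is resized to the feature resolution with F.interpolate before the weighted sum.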
-# Adaptive Instance Normalization
-## ref: https://arxiv.org/abs/1703.06868
 class AdaIN(nn.Module):
-    def __init__(self, ):
+    def __init__(self):
         super(AdaIN, self).__init__()
+
+    def forward(self, content, style, style_strength=1.0, eps=1e-5):
+        b, c, h, w = content.size()

-    def forward(self, x, t, eps=1e-5):
-        b, c, h, w = x.size()
+        content_std, content_mean = torch.std_mean(content.view(b, c, -1), dim=2, keepdim=True)
+        style_std, style_mean = torch.std_mean(style.view(b, c, -1), dim=2, keepdim=True)
+
+        normalized_content = (content.view(b, c, -1) - content_mean)/(content_std+eps)

-        x_mean = torch.mean(x.view(b, c, h*w), dim=2, keepdim=True)
-        x_std = torch.std(x.view(b, c, h*w), dim=2, keepdim=True)
+        stylized_content = (normalized_content * style_std) + style_mean
+
+        output = (1-style_strength)*content + style_strength*stylized_content.view(b, c, h, w)
+        return output
+
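As a quick illustration of what the rewritten AdaIN does (tensor shapes here are arbitrary): with the default style_strength of 1.0, the output keeps the content feature's spatial arrangement but takes on the style feature's per-channel mean and standard deviation.

adain = AdaIN()
content_feat = torch.rand(2, 64, 32, 32)
style_feat = torch.rand(2, 64, 32, 32) * 3 + 1
out = adain(content_feat, style_feat)                  # style_strength defaults to 1.0
print(out.view(2, 64, -1).mean(dim=2)[0, :3])          # ~ per-channel means of style_feat
print(style_feat.view(2, 64, -1).mean(dim=2)[0, :3])

With style_strength between 0 and 1, the result is a linear blend between the untouched content feature and the fully re-normalized one.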
+class Encoder(nn.Module):
+    def __init__(self, layers=[1, 6, 11, 20]):
+        super(Encoder, self).__init__()
+        vgg = torchvision.models.vgg19(pretrained=True).features

-        t_b, t_c, t_h, t_w = t.size()
-        t_mean = torch.mean(t.view(t_b, t_c, t_h*t_w), dim=2, keepdim=True)
-        t_std = torch.std(t.view(t_b, t_c, t_h*t_w), dim=2, keepdim=True)
+        self.encoder = nn.ModuleList()
+        temp_seq = nn.Sequential()
+        for i in range(max(layers)+1):
+            temp_seq.add_module(str(i), vgg[i])
+            if i in layers:
+                self.encoder.append(temp_seq)
+                temp_seq = nn.Sequential()
+
+    def forward(self, x):
+        features = []
+        for layer in self.encoder:
+            x = layer(x)
+            features.append(x)
+        return features
+
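With the default layers=[1, 6, 11, 20], the slices end at relu1_1, relu2_1, relu3_1 and relu4_1 of torchvision's VGG-19, so the encoder returns four feature maps of increasing depth. A small shape check (the 256x256 input is only an example):

encoder = Encoder()                          # downloads the ImageNet VGG-19 weights
with torch.no_grad():
    feats = encoder(torch.rand(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])
# [(1, 64, 256, 256), (1, 128, 128, 128), (1, 256, 64, 64), (1, 512, 32, 32)]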
+class Decoder(nn.Module):
+    def __init__(self, layers=[1, 6, 11, 20], transformers=[AdaIN(), AdaIN(), AdaIN(), None]):
+        super(Decoder, self).__init__()
+        vgg = torchvision.models.vgg19(pretrained=False).features
+        self.transformers = transformers

-        x_ = ((x.view(b, c, h*w) - x_mean)/(x_std + eps))*t_std + t_mean
+        self.decoder = nn.ModuleList()
+        temp_seq = nn.Sequential()
+        count = 0
+        for i in range(max(layers)-1, -1, -1):
+            if isinstance(vgg[i], nn.Conv2d):
+                # get number of in/out channels
+                out_channels = vgg[i].in_channels
+                in_channels = vgg[i].out_channels
+                kernel_size = vgg[i].kernel_size
+
+                # make a [reflection pad + convolution + relu] layer
+                temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1)))
+                count += 1
+                temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
+                count += 1
+                temp_seq.add_module(str(count), nn.ReLU())
+                count += 1
+
+            # change down-sampling(MaxPooling) --> upsampling
+            elif isinstance(vgg[i], nn.MaxPool2d):
+                temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
+                count += 1
+
+            if i in layers:
+                self.decoder.append(temp_seq)
+                temp_seq = nn.Sequential()
+
+        # append last conv layers without ReLU activation
+        self.decoder.append(temp_seq[:-1])

-        return x_.view(b, c, h, w)
-
-
-# get network from vgg feature network
-def get_encoder(vgg, layers):
-    encoder = nn.ModuleList()
-    temp_seq = nn.Sequential()
-    for i in range(max(layers)+1):
-        temp_seq.add_module(str(i), vgg[i])
-        if i in layers:
-            encoder.append(temp_seq)
-            temp_seq = nn.Sequential()
-
-    return encoder
-
-# get mirroed vgg feature network
-def get_decoder(vgg, layers):
-    decoder = nn.ModuleList()
-    temp_seq = nn.Sequential()
-    count = 0
-    for i in range(max(layers)-1, -1, -1):
-        if isinstance(vgg[i], nn.Conv2d):
-            # get number of in/out channels
-            out_channels = vgg[i].in_channels
-            in_channels = vgg[i].out_channels
-            kernel_size = vgg[i].kernel_size
-
-            # make a [reflection pad + convolution + relu] layer
-            temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1)))
-            count += 1
-            temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
-            count += 1
-            temp_seq.add_module(str(count), nn.ReLU())
-            count += 1
-
-        # change down-sampling(MaxPooling) --> upsampling
-        elif isinstance(vgg[i], nn.MaxPool2d):
-            temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
-            count += 1
-
-        if i in layers:
-            decoder.append(temp_seq)
-            temp_seq = nn.Sequential()
-
-    # append last conv layers without ReLU activation
-    decoder.append(temp_seq[:-1])
-    return decoder
+    def forward(self, x, styles, masks=None, interpolation_weights=None):
+        if interpolation_weights is None:
+            interpolation_weights = [1/len(styles)] * len(styles)
+        if masks is None:
+            masks = [1] * len(styles)
+
+        y = x
+        for i, layer in enumerate(self.decoder):
+            y = layer(y)
+
+            if self.transformers[i]:
+                transformed_feature = []
+                for style, interpolation_weight, mask in zip(styles, interpolation_weights, masks):
+                    if isinstance(mask, torch.Tensor):
+                        b, c, h, w = y.size()
+                        mask = F.interpolate(mask, size=(h, w))
+                    transformed_feature.append(self.transformers[i](y, style[i]) * interpolation_weight * mask)
+                y = sum(transformed_feature)
+
+        return y
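Because Decoder mirrors the encoder slices (convolutions with swapped in/out channels, max-pooling replaced by 2x upsampling, and the final ReLU dropped), a feature pyramid from Encoder decodes back to an image-sized tensor. A hedged round-trip shape check, with untrained decoder weights and a single "style" taken from the image itself:

enc, dec = Encoder(), Decoder()
image = torch.rand(1, 3, 256, 256)
feats = enc(image)                           # [relu1_1, relu2_1, relu3_1, relu4_1]
styles = [feats[:-1][::-1]]                  # deepest feature dropped, order reversed
out = dec(feats[-1], styles)                 # AdaIN here reuses the image's own statistics
print(out.shape)                             # torch.Size([1, 3, 256, 256])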