Commit 16a2798

update the Avatar-Net for masked stylization and multi-stylization.
1 parent a9fd25a commit 16a2798

1 file changed: +118 −85 lines


network.py

Lines changed: 118 additions & 85 deletions
@@ -1,110 +1,143 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 import torchvision
 import torchvision.models as models

+from wct import batch_wct
 from style_decorator import StyleDecorator

 class AvatarNet(nn.Module):
-    def __init__(self, layers):
+    def __init__(self, layers=[1, 6, 11, 20]):
         super(AvatarNet, self).__init__()
-        vgg = models.vgg19(pretrained=True).features
-
-        # get network layers
-        self.encoders = get_encoder(vgg, layers)
-        self.decoders = get_decoder(vgg, layers)
+        self.encoder = Encoder(layers)
+        self.decoder = Decoder(layers)

         self.adain = AdaIN()
         self.decorator = StyleDecorator()

-    def forward(self, c, s, train_flag=False, style_strength=1.0, patch_size=3, patch_stride=1):
+    def forward(self, content, styles, style_strength=1.0, patch_size=3, patch_stride=1, masks=None, interpolation_weights=None, preserve_color=False, train=False):
+        if interpolation_weights is None:
+            interpolation_weights = [1/len(styles)] * len(styles)
+        if masks is None:
+            masks = [1] * len(styles)

         # encode content image
-        for encoder in self.encoders:
-            c = encoder(c)
-
-        # encode style image and extract multi-scale feature for AdaIN in decoder network
-        features = []
-        for encoder in self.encoders:
-            s = encoder(s)
-            features.append(s)
-
-        # delete last style feature
-        del features[-1]
-
-        if not train_flag:
-            c = self.decorator(c, s, style_strength, patch_size, patch_stride)
-
-        for decoder in self.decoders:
-            c = decoder(c)
-            if features:
-                c = self.adain(c, features.pop())
+        content_feature = self.encoder(content)
+        style_features = []
+        for style in styles:
+            style_features.append(self.encoder(style))
+
+        if not train:
+            transformed_feature = []
+            for style_feature, interpolation_weight, mask in zip(style_features, interpolation_weights, masks):
+                if isinstance(mask, torch.Tensor):
+                    b, c, h, w = content_feature[-1].size()
+                    mask = F.interpolate(mask, size=(h, w))
+                transformed_feature.append(self.decorator(content_feature[-1], style_feature[-1], style_strength, patch_size, patch_stride) * interpolation_weight * mask)
+            transformed_feature = sum(transformed_feature)
+
+        else:
+            transformed_feature = content_feature[-1]
+
+        # re-ordering style features for transferring feature during decoding
+        style_features = [style_feature[:-1][::-1] for style_feature in style_features]
+
+        stylized_image = self.decoder(transformed_feature, style_features, masks, interpolation_weights)

-        return c
+        return stylized_image

-# Adaptive Instance Normalization
-## ref: https://arxiv.org/abs/1703.06868
 class AdaIN(nn.Module):
-    def __init__(self, ):
+    def __init__(self):
         super(AdaIN, self).__init__()
+
+    def forward(self, content, style, style_strength=1.0, eps=1e-5):
+        b, c, h, w = content.size()

-    def forward(self, x, t, eps=1e-5):
-        b, c, h, w = x.size()
+        content_std, content_mean = torch.std_mean(content.view(b, c, -1), dim=2, keepdim=True)
+        style_std, style_mean = torch.std_mean(style.view(b, c, -1), dim=2, keepdim=True)
+
+        normalized_content = (content.view(b, c, -1) - content_mean)/(content_std+eps)

-        x_mean = torch.mean(x.view(b, c, h*w), dim=2, keepdim=True)
-        x_std = torch.std(x.view(b, c, h*w), dim=2, keepdim=True)
+        stylized_content = (normalized_content * style_std) + style_mean
+
+        output = (1-style_strength)*content + style_strength*stylized_content.view(b, c, h, w)
+        return output
+
+class Encoder(nn.Module):
+    def __init__(self, layers=[1, 6, 11, 20]):
+        super(Encoder, self).__init__()
+        vgg = torchvision.models.vgg19(pretrained=True).features

-        t_b, t_c, t_h, t_w = t.size()
-        t_mean = torch.mean(t.view(t_b, t_c, t_h*t_w), dim=2, keepdim=True)
-        t_std = torch.std(t.view(t_b, t_c, t_h*t_w), dim=2, keepdim=True)
+        self.encoder = nn.ModuleList()
+        temp_seq = nn.Sequential()
+        for i in range(max(layers)+1):
+            temp_seq.add_module(str(i), vgg[i])
+            if i in layers:
+                self.encoder.append(temp_seq)
+                temp_seq = nn.Sequential()
+
+    def forward(self, x):
+        features = []
+        for layer in self.encoder:
+            x = layer(x)
+            features.append(x)
+        return features
+
+class Decoder(nn.Module):
+    def __init__(self, layers=[1, 6, 11, 20], transformers=[AdaIN(), AdaIN(), AdaIN(), None]):
+        super(Decoder, self).__init__()
+        vgg = torchvision.models.vgg19(pretrained=False).features
+        self.transformers = transformers

-        x_ = ((x.view(b, c, h*w) - x_mean)/(x_std + eps))*t_std + t_mean
+        self.decoder = nn.ModuleList()
+        temp_seq = nn.Sequential()
+        count = 0
+        for i in range(max(layers)-1, -1, -1):
+            if isinstance(vgg[i], nn.Conv2d):
+                # get number of in/out channels
+                out_channels = vgg[i].in_channels
+                in_channels = vgg[i].out_channels
+                kernel_size = vgg[i].kernel_size
+
+                # make a [reflection pad + convolution + relu] layer
+                temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1)))
+                count += 1
+                temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
+                count += 1
+                temp_seq.add_module(str(count), nn.ReLU())
+                count += 1
+
+            # change down-sampling(MaxPooling) --> upsampling
+            elif isinstance(vgg[i], nn.MaxPool2d):
+                temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
+                count += 1
+
+            if i in layers:
+                self.decoder.append(temp_seq)
+                temp_seq = nn.Sequential()
+
+        # append last conv layers without ReLU activation
+        self.decoder.append(temp_seq[:-1])

-        return x_.view(b, c, h, w)
-
-
-# get network from vgg feature network
-def get_encoder(vgg, layers):
-    encoder = nn.ModuleList()
-    temp_seq = nn.Sequential()
-    for i in range(max(layers)+1):
-        temp_seq.add_module(str(i), vgg[i])
-        if i in layers:
-            encoder.append(temp_seq)
-            temp_seq = nn.Sequential()
-
-    return encoder
-
-# get mirroed vgg feature network
-def get_decoder(vgg, layers):
-    decoder = nn.ModuleList()
-    temp_seq = nn.Sequential()
-    count = 0
-    for i in range(max(layers)-1, -1, -1):
-        if isinstance(vgg[i], nn.Conv2d):
-            # get number of in/out channels
-            out_channels = vgg[i].in_channels
-            in_channels = vgg[i].out_channels
-            kernel_size = vgg[i].kernel_size
-
-            # make a [reflection pad + convolution + relu] layer
-            temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1)))
-            count += 1
-            temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
-            count += 1
-            temp_seq.add_module(str(count), nn.ReLU())
-            count += 1
-
-        # change down-sampling(MaxPooling) --> upsampling
-        elif isinstance(vgg[i], nn.MaxPool2d):
-            temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
-            count += 1
-
-        if i in layers:
-            decoder.append(temp_seq)
-            temp_seq = nn.Sequential()
-
-    # append last conv layers without ReLU activation
-    decoder.append(temp_seq[:-1])
-    return decoder
+    def forward(self, x, styles, masks=None, interpolation_weights=None):
+        if interpolation_weights is None:
+            interpolation_weights = [1/len(styles)] * len(styles)
+        if masks is None:
+            masks = [1] * len(styles)
+
+        y = x
+        for i, layer in enumerate(self.decoder):
+            y = layer(y)
+
+            if self.transformers[i]:
+                transformed_feature = []
+                for style, interpolation_weight, mask in zip(styles, interpolation_weights, masks):
+                    if isinstance(mask, torch.Tensor):
+                        b, c, h, w = y.size()
+                        mask = F.interpolate(mask, size=(h, w))
+                    transformed_feature.append(self.transformers[i](y, style[i]) * interpolation_weight * mask)
+                y = sum(transformed_feature)
+
+        return y
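
For reference, here is a minimal usage sketch of the updated interface introduced by this commit; it is not part of the committed file. The call signature follows the new `AvatarNet.forward`, but the input shapes and the placeholder tensors are assumptions: content and style images as `[1, 3, H, W]` float tensors, and masks as `[1, 1, H, W]` float tensors that `F.interpolate` resizes to each feature map and that broadcast over the channel dimension.

```python
# Sketch only: random tensors stand in for real, preprocessed images.
import torch

from network import AvatarNet

# Constructing the model loads the pretrained VGG-19 encoder.
model = AvatarNet(layers=[1, 6, 11, 20]).eval()

content = torch.rand(1, 3, 256, 256)   # placeholder content image
style_a = torch.rand(1, 3, 256, 256)   # placeholder style image A
style_b = torch.rand(1, 3, 256, 256)   # placeholder style image B

# Binary masks selecting the left and right halves of the image.
mask_a = torch.zeros(1, 1, 256, 256)
mask_a[..., :128] = 1.0
mask_b = 1.0 - mask_a

with torch.no_grad():
    # Multi-stylization: blend both styles over the whole image.
    blended = model(content, [style_a, style_b],
                    interpolation_weights=[0.7, 0.3])

    # Masked stylization: style A on the left half, style B on the right half.
    masked = model(content, [style_a, style_b],
                   masks=[mask_a, mask_b],
                   interpolation_weights=[1.0, 1.0])

print(blended.shape, masked.shape)  # expected: torch.Size([1, 3, 256, 256]) each
```

Note that each stylized feature is multiplied by both its interpolation weight and its mask before the sum, so with complementary masks a weight of 1.0 per style keeps each region at full strength; the default of `1/len(styles)` is suited to the unmasked blending case.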
