Skip to content

Commit 8bd4f8a

Browse files
committed
Update whitening and coloring transform
1 parent 383bd78 commit 8bd4f8a

File tree

2 files changed

+44
-53
lines changed

2 files changed

+44
-53
lines changed

style_decorator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn.functional as F
33

4-
from wct import batch_whitening, batch_coloring
4+
from wct import whitening, coloring
55

66
def extract_patches(feature, patch_size, stride, padding='zero'):
77
kh, kw = patch_size
@@ -136,17 +136,17 @@ def reassemble_feature(self, normalized_content_feature, normalized_style_featur
136136
def forward(self, content_feature, style_feature, style_strength=1.0, patch_size=3, patch_stride=1):
137137

138138
# 1-1. content feature projection
139-
normalized_content_feature = batch_whitening(content_feature)
139+
normalized_content_feature = whitening(content_feature)
140140

141141
# 1-2. style feature projection
142-
normalized_style_feature = batch_whitening(style_feature)
142+
normalized_style_feature = whitening(style_feature)
143143

144144
# 2. swap content and style features
145145
reassembled_feature = self.reassemble_feature(normalized_content_feature, normalized_style_feature,
146146
patch_size=patch_size, patch_stride=patch_stride)
147147

148148
# 3. reconstruction feature with style mean and covariance matrix
149-
stylized_feature = batch_coloring(reassembled_feature, style_feature)
149+
stylized_feature = coloring(reassembled_feature, style_feature)
150150

151151
# 4. content and style interpolation
152152
result_feature = (1-style_strength) * content_feature + style_strength * stylized_feature

wct.py

Lines changed: 40 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,52 @@
11
import torch
22

3-
def mean_covsqrt(f, inverse=False, eps=1e-5):
4-
c, h, w = f.size()
5-
6-
f_mean = torch.mean(f.view(c, h*w), dim=1, keepdim=True)
7-
f_zeromean = f.view(c, h*w) - f_mean
8-
f_cov = torch.mm(f_zeromean, f_zeromean.t())
9-
10-
u, s, v = torch.svd(f_cov)
11-
12-
k = c
13-
for i in range(c):
14-
if s[i] < eps:
15-
k = i
16-
break
17-
18-
if inverse:
19-
p = -0.5
20-
else:
21-
p = 0.5
22-
23-
f_covsqrt = torch.mm(torch.mm(v[:, 0:k], torch.diag(s[0:k].pow(p))), v[:, 0:k].t())
24-
return f_mean, f_covsqrt
3+
def covsqrt_mean(feature, inverse=False, tolerance=1e-14):
4+
# I referenced the default svd tolerance value in matlab.
255

26-
def whitening(f):
27-
c, h, w = f.size()
6+
b, c, h, w = feature.size()
287

29-
f_mean, f_inv_covsqrt = mean_covsqrt(f, inverse=True)
30-
31-
whiten_f = torch.mm(f_inv_covsqrt, f.view(c, h*w) - f_mean)
32-
33-
return whiten_f.view(c, h, w)
8+
mean = torch.mean(feature.view(b, c, -1), dim=2, keepdim=True)
9+
zeromean = feature.view(b, c, -1) - mean
10+
cov = torch.bmm(zeromean, zeromean.transpose(1, 2))
3411

35-
def coloring(f, t):
36-
f_c, f_h, f_w = f.size()
37-
t_c, t_h, t_w = t.size()
38-
39-
t_mean, t_covsqrt = mean_covsqrt(t)
12+
evals, evects = torch.symeig(cov, eigenvectors=True)
4013

41-
colored_f = torch.mm(t_covsqrt, f.view(f_c, f_h*f_w)) + t_mean
14+
p = 0.5
15+
if inverse:
16+
p *= -1
17+
18+
covsqrt = []
19+
for i in range(b):
20+
k = 0
21+
for j in range(c):
22+
if evals[i][j] > tolerance:
23+
k = j
24+
break
25+
covsqrt.append(torch.mm(evects[i][:, k:],
26+
torch.mm(evals[i][k:].pow(p).diag_embed(),
27+
evects[i][:, k:].t())).unsqueeze(0))
28+
covsqrt = torch.cat(covsqrt, dim=0)
29+
30+
u, s, v = torch.svd(cov)
31+
32+
return covsqrt, mean
4233

43-
return colored_f.view(f_c, f_h, f_w)
4434

45-
def batch_whitening(f):
46-
b, c, h, w = f.size()
35+
def whitening(feature):
36+
b, c, h, w = feature.size()
37+
38+
inv_covsqrt, mean = covsqrt_mean(feature, inverse=True)
4739

48-
whiten_f = torch.Tensor(b, c, h, w).type_as(f)
49-
for i, f_ in enumerate(torch.split(f, 1)):
50-
whiten_f[i] = whitening(f_.squeeze())
51-
52-
return whiten_f
40+
normalized_feature = torch.matmul(inv_covsqrt, feature.view(b, c, -1)-mean)
41+
42+
return normalized_feature.view(b, c, h, w)
5343

54-
def batch_coloring(f, t):
55-
b, c, h, w = f.size()
5644

57-
colored_f = torch.Tensor(b, c, h, w).type_as(f)
58-
for i, (f_, t_) in enumerate(zip(torch.split(f, 1), torch.split(t, 1))):
59-
colored_f[i] = coloring(f_.squeeze(), t_.squeeze())
45+
def coloring(feature, target):
46+
b, c, h, w = feature.size()
6047

61-
return colored_f
48+
covsqrt, mean = covsqrt_mean(target)
49+
50+
colored_feature = torch.matmul(covsqrt, feature.view(b, c, -1)) + mean
51+
52+
return colored_feature.view(b, c, h, w)

0 commit comments

Comments
 (0)