|
1 | 1 | import torch |
2 | 2 |
|
3 | | -def mean_covsqrt(f, inverse=False, eps=1e-5): |
4 | | - c, h, w = f.size() |
5 | | - |
6 | | - f_mean = torch.mean(f.view(c, h*w), dim=1, keepdim=True) |
7 | | - f_zeromean = f.view(c, h*w) - f_mean |
8 | | - f_cov = torch.mm(f_zeromean, f_zeromean.t()) |
9 | | - |
10 | | - u, s, v = torch.svd(f_cov) |
11 | | - |
12 | | - k = c |
13 | | - for i in range(c): |
14 | | - if s[i] < eps: |
15 | | - k = i |
16 | | - break |
17 | | - |
18 | | - if inverse: |
19 | | - p = -0.5 |
20 | | - else: |
21 | | - p = 0.5 |
22 | | - |
23 | | - f_covsqrt = torch.mm(torch.mm(v[:, 0:k], torch.diag(s[0:k].pow(p))), v[:, 0:k].t()) |
24 | | - return f_mean, f_covsqrt |
| 3 | +def covsqrt_mean(feature, inverse=False, tolerance=1e-14): |
| 4 | + # I referenced the default svd tolerance value in matlab. |
25 | 5 |
|
26 | | -def whitening(f): |
27 | | - c, h, w = f.size() |
| 6 | + b, c, h, w = feature.size() |
28 | 7 |
|
29 | | - f_mean, f_inv_covsqrt = mean_covsqrt(f, inverse=True) |
30 | | - |
31 | | - whiten_f = torch.mm(f_inv_covsqrt, f.view(c, h*w) - f_mean) |
32 | | - |
33 | | - return whiten_f.view(c, h, w) |
| 8 | + mean = torch.mean(feature.view(b, c, -1), dim=2, keepdim=True) |
| 9 | + zeromean = feature.view(b, c, -1) - mean |
| 10 | + cov = torch.bmm(zeromean, zeromean.transpose(1, 2)) |
34 | 11 |
|
35 | | -def coloring(f, t): |
36 | | - f_c, f_h, f_w = f.size() |
37 | | - t_c, t_h, t_w = t.size() |
38 | | - |
39 | | - t_mean, t_covsqrt = mean_covsqrt(t) |
| 12 | + evals, evects = torch.symeig(cov, eigenvectors=True) |
40 | 13 |
|
41 | | - colored_f = torch.mm(t_covsqrt, f.view(f_c, f_h*f_w)) + t_mean |
| 14 | + p = 0.5 |
| 15 | + if inverse: |
| 16 | + p *= -1 |
| 17 | + |
| 18 | + covsqrt = [] |
| 19 | + for i in range(b): |
| 20 | + k = 0 |
| 21 | + for j in range(c): |
| 22 | + if evals[i][j] > tolerance: |
| 23 | + k = j |
| 24 | + break |
| 25 | + covsqrt.append(torch.mm(evects[i][:, k:], |
| 26 | + torch.mm(evals[i][k:].pow(p).diag_embed(), |
| 27 | + evects[i][:, k:].t())).unsqueeze(0)) |
| 28 | + covsqrt = torch.cat(covsqrt, dim=0) |
| 29 | + |
| 30 | + u, s, v = torch.svd(cov) |
| 31 | + |
| 32 | + return covsqrt, mean |
42 | 33 |
|
43 | | - return colored_f.view(f_c, f_h, f_w) |
44 | 34 |
|
45 | | -def batch_whitening(f): |
46 | | - b, c, h, w = f.size() |
| 35 | +def whitening(feature): |
| 36 | + b, c, h, w = feature.size() |
| 37 | + |
| 38 | + inv_covsqrt, mean = covsqrt_mean(feature, inverse=True) |
47 | 39 |
|
48 | | - whiten_f = torch.Tensor(b, c, h, w).type_as(f) |
49 | | - for i, f_ in enumerate(torch.split(f, 1)): |
50 | | - whiten_f[i] = whitening(f_.squeeze()) |
51 | | - |
52 | | - return whiten_f |
| 40 | + normalized_feature = torch.matmul(inv_covsqrt, feature.view(b, c, -1)-mean) |
| 41 | + |
| 42 | + return normalized_feature.view(b, c, h, w) |
53 | 43 |
|
54 | | -def batch_coloring(f, t): |
55 | | - b, c, h, w = f.size() |
56 | 44 |
|
57 | | - colored_f = torch.Tensor(b, c, h, w).type_as(f) |
58 | | - for i, (f_, t_) in enumerate(zip(torch.split(f, 1), torch.split(t, 1))): |
59 | | - colored_f[i] = coloring(f_.squeeze(), t_.squeeze()) |
| 45 | +def coloring(feature, target): |
| 46 | + b, c, h, w = feature.size() |
60 | 47 |
|
61 | | - return colored_f |
| 48 | + covsqrt, mean = covsqrt_mean(target) |
| 49 | + |
| 50 | + colored_feature = torch.matmul(covsqrt, feature.view(b, c, -1)) + mean |
| 51 | + |
| 52 | + return colored_feature.view(b, c, h, w) |
0 commit comments