Skip to content

Commit fc18318

Browse files
authored
Simplify SubpixelConv2d
Reduce the number of ops to do subpixel convolution
1 parent f4ab8ff commit fc18318

File tree

1 file changed

+6
-23
lines changed

1 file changed

+6
-23
lines changed

tensorlayer/layers.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2167,31 +2167,14 @@ def SubpixelConv2d(net, scale=2, n_out_channel=None, act=tf.identity, name='subp
21672167
if scope_name:
21682168
name = scope_name + '/' + name
21692169

2170-
def _phase_shift(I, r):
2171-
if tf.__version__ < '1.0':
2172-
raise Exception("Only support TF1.0+")
2173-
bsize, a, b, c = I.get_shape().as_list()
2174-
bsize = tf.shape(I)[0] # Handling Dimension(None) type for undefined batch dim
2175-
X = tf.reshape(I, (bsize, a, b, r, r))
2176-
X = tf.transpose(X, (0, 1, 2, 4, 3)) # bsize, a, b, 1, 1 # tf 0.12
2177-
# X = tf.split(1, a, X) # a, [bsize, b, r, r] # tf 0.12
2178-
X = tf.split(X, a, 1)
2179-
# X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, b, a*r, r # tf 0.12
2180-
X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)
2181-
# X = tf.split(1, b, X) # b, [bsize, a*r, r] # tf 0.12
2182-
X = tf.split(X, b, 1)
2183-
# X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, a*r, b*r # tf 0.12
2184-
X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)
2185-
return tf.reshape(X, (bsize, a*r, b*r, 1))
2186-
21872170
def _PS(X, r, n_out_channel):
2188-
if n_out_channel > 1:
2171+
if n_out_channel >= 1:
21892172
assert int(X.get_shape()[-1]) == (r ** 2) * n_out_channel, _err_log
2190-
Xc = tf.split(X, n_out_channel, 3)
2191-
X = tf.concat([_phase_shift(x, r) for x in Xc], 3)
2192-
elif n_out_channel == 1:
2193-
assert int(X.get_shape()[-1]) == (r ** 2), _err_log
2194-
X = _phase_shift(X, r)
2173+
bsize, a, b, c = X.get_shape().as_list()
2174+
bsize = tf.shape(X)[0] # Handling Dimension(None) type for undefined batch dim
2175+
Xs=tf.split(X,r,3) #b*h*w*r*r
2176+
Xr=tf.concat(Xs,2) #b*h*(r*w)*r
2177+
X=tf.reshape(Xr,(b,r*h,r*w,c)) # b*(r*h)*(r*w)*c
21952178
else:
21962179
print(_err_log)
21972180
return X

0 commit comments

Comments
 (0)