Skip to content

Commit d90ed50

Browse files
committed
check new subpixel conv2d
1 parent a31ead9 commit d90ed50

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

tensorlayer/layers.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2198,6 +2198,110 @@ def _PS(X, r, n_out_channel):
21982198
net_new.all_layers.extend( [net_new.outputs] )
21992199
return net_new
22002200

2201+
def SubpixelConv2d_old(net, scale=2, n_out_channel=None, act=tf.identity, name='subpixel_conv2d'):
2202+
"""The :class:`SubpixelConv2d` class is a sub-pixel 2d convolutional ayer, usually be used
2203+
for Super-Resolution applications, `example code <https://github.com/zsdonghao/SRGAN/>`_.
2204+
2205+
Parameters
2206+
------------
2207+
net : TensorLayer layer.
2208+
scale : int, upscaling ratio, a wrong setting will lead to Dimension size error.
2209+
n_out_channel : int or None, the number of output channels.
2210+
Note that, the number of input channels == (scale x scale) x The number of output channels.
2211+
If None, automatically set n_out_channel == the number of input channels / (scale x scale).
2212+
act : activation function.
2213+
name : string.
2214+
An optional name to attach to this layer.
2215+
2216+
Examples
2217+
---------
2218+
>>> # examples here just want to tell you how to set the n_out_channel.
2219+
>>> x = np.random.rand(2, 16, 16, 4)
2220+
>>> X = tf.placeholder("float32", shape=(2, 16, 16, 4), name="X")
2221+
>>> net = InputLayer(X, name='input')
2222+
>>> net = SubpixelConv2d(net, scale=2, n_out_channel=1, name='subpixel_conv2d')
2223+
>>> y = sess.run(net.outputs, feed_dict={X: x})
2224+
>>> print(x.shape, y.shape)
2225+
... (2, 16, 16, 4) (2, 32, 32, 1)
2226+
>>>
2227+
>>> x = np.random.rand(2, 16, 16, 4*10)
2228+
>>> X = tf.placeholder("float32", shape=(2, 16, 16, 4*10), name="X")
2229+
>>> net = InputLayer(X, name='input2')
2230+
>>> net = SubpixelConv2d(net, scale=2, n_out_channel=10, name='subpixel_conv2d2')
2231+
>>> y = sess.run(net.outputs, feed_dict={X: x})
2232+
>>> print(x.shape, y.shape)
2233+
... (2, 16, 16, 40) (2, 32, 32, 10)
2234+
>>>
2235+
>>> x = np.random.rand(2, 16, 16, 25*10)
2236+
>>> X = tf.placeholder("float32", shape=(2, 16, 16, 25*10), name="X")
2237+
>>> net = InputLayer(X, name='input3')
2238+
>>> net = SubpixelConv2d(net, scale=5, n_out_channel=None, name='subpixel_conv2d3')
2239+
>>> y = sess.run(net.outputs, feed_dict={X: x})
2240+
>>> print(x.shape, y.shape)
2241+
... (2, 16, 16, 250) (2, 80, 80, 10)
2242+
2243+
References
2244+
------------
2245+
- `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/pdf/1609.05158.pdf>`_
2246+
"""
2247+
# github/Tetrachrome/subpixel https://github.com/Tetrachrome/subpixel/blob/master/subpixel.py
2248+
2249+
_err_log = "SubpixelConv2d: The number of input channels == (scale x scale) x The number of output channels"
2250+
2251+
scope_name = tf.get_variable_scope().name
2252+
if scope_name:
2253+
name = scope_name + '/' + name
2254+
2255+
def _phase_shift(I, r):
2256+
if tf.__version__ < '1.0':
2257+
raise Exception("Only support TF1.0+")
2258+
bsize, a, b, c = I.get_shape().as_list()
2259+
bsize = tf.shape(I)[0] # Handling Dimension(None) type for undefined batch dim
2260+
X = tf.reshape(I, (bsize, a, b, r, r))
2261+
X = tf.transpose(X, (0, 1, 2, 4, 3)) # bsize, a, b, 1, 1 # tf 0.12
2262+
# X = tf.split(1, a, X) # a, [bsize, b, r, r] # tf 0.12
2263+
X = tf.split(X, a, 1)
2264+
# X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, b, a*r, r # tf 0.12
2265+
X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)
2266+
# X = tf.split(1, b, X) # b, [bsize, a*r, r] # tf 0.12
2267+
X = tf.split(X, b, 1)
2268+
# X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, a*r, b*r # tf 0.12
2269+
X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)
2270+
return tf.reshape(X, (bsize, a*r, b*r, 1))
2271+
2272+
def _PS(X, r, n_out_channel):
2273+
if n_out_channel > 1:
2274+
assert int(X.get_shape()[-1]) == (r ** 2) * n_out_channel, _err_log
2275+
Xc = tf.split(X, n_out_channel, 3)
2276+
X = tf.concat([_phase_shift(x, r) for x in Xc], 3)
2277+
elif n_out_channel == 1:
2278+
assert int(X.get_shape()[-1]) == (r ** 2), _err_log
2279+
X = _phase_shift(X, r)
2280+
else:
2281+
print(_err_log)
2282+
return X
2283+
2284+
inputs = net.outputs
2285+
2286+
if n_out_channel is None:
2287+
assert int(inputs.get_shape()[-1])/ (scale ** 2) % 1 == 0, _err_log
2288+
n_out_channel = int(int(inputs.get_shape()[-1])/ (scale ** 2))
2289+
2290+
print(" [TL] SubpixelConv2d %s: scale: %d n_out_channel: %s act: %s" % (name, scale, n_out_channel, act.__name__))
2291+
2292+
net_new = Layer(inputs, name=name)
2293+
# with tf.name_scope(name):
2294+
with tf.variable_scope(name) as vs:
2295+
net_new.outputs = act(_PS(inputs, r=scale, n_out_channel=n_out_channel))
2296+
2297+
net_new.all_layers = list(net.all_layers)
2298+
net_new.all_params = list(net.all_params)
2299+
net_new.all_drop = dict(net.all_drop)
2300+
net_new.all_layers.extend( [net_new.outputs] )
2301+
return net_new
2302+
2303+
2304+
22012305

22022306
## Spatial Transformer Nets
22032307
def transformer(U, theta, out_size, name='SpatialTransformer2dAffine', **kwargs):

0 commit comments

Comments
 (0)