Skip to content

Commit 9557ff7

Browse files
authored
Fix docs (#515)
* Update super resolution from function to class
* Fix docs
* Fix docs
1 parent 9693e3d commit 9557ff7

File tree

3 files changed

+54
-72
lines changed

3 files changed

+54
-72
lines changed

docs/modules/layers.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,11 +519,11 @@ Super-Resolution layer
519519

520520
1D Subpixel Convolution
521521
^^^^^^^^^^^^^^^^^^^^^^^^^^
522-
.. autofunction:: SubpixelConv1d
522+
.. autoclass:: SubpixelConv1d
523523

524524
2D Subpixel Convolution
525525
^^^^^^^^^^^^^^^^^^^^^^^^^^
526-
.. autofunction:: SubpixelConv2d
526+
.. autoclass:: SubpixelConv2d
527527

528528

529529
Spatial Transformer

tensorlayer/layers/core.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,13 +348,16 @@ class Layer(object):
348348
349349
Examples
350350
---------
351-
Define model
351+
352+
- Define model
353+
352354
>>> x = tf.placeholder("float32", [None, 100])
353355
>>> n = tl.layers.InputLayer(x, name='in')
354356
>>> n = tl.layers.DenseLayer(n, 80, name='d1')
355357
>>> n = tl.layers.DenseLayer(n, 80, name='d2')
356358
357-
Get information
359+
- Get information
360+
358361
>>> print(n)
359362
... Last layer is: DenseLayer (d2) [None, 80]
360363
>>> n.print_layers()
@@ -369,12 +372,14 @@ class Layer(object):
369372
>>> n.count_params()
370373
... 14560
371374
372-
Slicing the outputs
375+
- Slicing the outputs
376+
373377
>>> n2 = n[:, :30]
374378
>>> print(n2)
375379
... Last layer is: Layer (d2) [None, 30]
376380
377-
Iterating the outputs
381+
- Iterating the outputs
382+
378383
>>> for l in n:
379384
>>> print(l)
380385
... Tensor("d1/Identity:0", shape=(?, 80), dtype=float32)

tensorlayer/layers/super_resolution.py

Lines changed: 43 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
]
1414

1515

16-
@deprecated_alias(net='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
17-
def subpixel_conv2d(prev_layer, scale=2, n_out_channel=None, act=tf.identity, name='subpixel_conv2d'):
16+
class SubpixelConv2d(Layer):
1817
"""It is a 2D sub-pixel up-sampling layer, usually be used
1918
for Super-Resolution applications, see `SRGAN <https://github.com/zsdonghao/SRGAN/>`__ for example.
2019
@@ -33,11 +32,6 @@ def subpixel_conv2d(prev_layer, scale=2, n_out_channel=None, act=tf.identity, na
3332
name : str
3433
A unique layer name.
3534
36-
Returns
37-
-------
38-
:class:`Layer`
39-
A 2D sub-pixel up-sampling layer
40-
4135
Examples
4236
---------
4337
>>> # examples here just want to tell you how to set the n_out_channel.
@@ -71,51 +65,41 @@ def subpixel_conv2d(prev_layer, scale=2, n_out_channel=None, act=tf.identity, na
7165
7266
"""
7367
# github/Tetrachrome/subpixel https://github.com/Tetrachrome/subpixel/blob/master/subpixel.py
68+
@deprecated_alias(net='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
69+
def __init__(self, prev_layer, scale=2, n_out_channel=None, act=tf.identity, name='subpixel_conv2d'):
70+
_err_log = "SubpixelConv2d: The number of input channels == (scale x scale) x The number of output channels"
7471

75-
_err_log = "SubpixelConv2d: The number of input channels == (scale x scale) x The number of output channels"
76-
77-
# scope_name = tf.get_variable_scope().name
78-
# if scope_name:
79-
# whole_name = scope_name + '/' + name
80-
# else:
81-
# whole_name = name
82-
83-
def _PS(X, r, n_out_channels):
84-
if n_out_channels >= 1:
85-
assert int(X.get_shape()[-1]) == (r**2) * n_out_channels, _err_log
72+
super(SubpixelConv2d, self).__init__(prev_layer=prev_layer, name=name)
73+
logging.info("SubpixelConv2d %s: scale: %d n_out_channel: %s act: %s" % (name, scale, n_out_channel, act.__name__))
8674

87-
# bsize, a, b, c = X.get_shape().as_list()
88-
# bsize = tf.shape(X)[0] # Handling Dimension(None) type for undefined batch dim
89-
# Xs=tf.split(X,r,3) #b*h*w*r*r
90-
# Xr=tf.concat(Xs,2) #b*h*(r*w)*r
91-
# X=tf.reshape(Xr,(bsize,r*a,r*b,n_out_channel)) # b*(r*h)*(r*w)*c
75+
def _PS(X, r, n_out_channels):
76+
if n_out_channels >= 1:
77+
if int(X.get_shape()[-1]) != (r**2) * n_out_channels:
78+
raise Exception(_err_log)
79+
# bsize, a, b, c = X.get_shape().as_list()
80+
# bsize = tf.shape(X)[0] # Handling Dimension(None) type for undefined batch dim
81+
# Xs=tf.split(X,r,3) #b*h*w*r*r
82+
# Xr=tf.concat(Xs,2) #b*h*(r*w)*r
83+
# X=tf.reshape(Xr,(bsize,r*a,r*b,n_out_channel)) # b*(r*h)*(r*w)*c
9284

93-
X = tf.depth_to_space(X, r)
94-
else:
95-
logging.info(_err_log)
96-
return X
85+
X = tf.depth_to_space(X, r)
86+
else:
87+
logging.info(_err_log)
88+
return X
9789

98-
inputs = prev_layer.outputs
99-
if n_out_channel is None:
100-
assert int(inputs.get_shape()[-1]) / (scale**2) % 1 == 0, _err_log
101-
n_out_channel = int(int(inputs.get_shape()[-1]) / (scale**2))
90+
self.inputs = prev_layer.outputs
91+
if n_out_channel is None:
92+
if int(self.inputs.get_shape()[-1]) / (scale**2) % 1 != 0:
93+
raise Exception(_err_log)
94+
n_out_channel = int(int(self.inputs.get_shape()[-1]) / (scale**2))
10295

103-
logging.info("SubpixelConv2d %s: scale: %d n_out_channel: %s act: %s" % (name, scale, n_out_channel, act.__name__))
96+
with tf.variable_scope(name):
97+
self.outputs = act(_PS(self.inputs, r=scale, n_out_channels=n_out_channel))
10498

105-
net_new = Layer(prev_layer=prev_layer, name=name)
106-
# with tf.name_scope(name):
107-
with tf.variable_scope(name):
108-
net_new.outputs = act(_PS(inputs, r=scale, n_out_channels=n_out_channel))
99+
self.all_layers.append(self.outputs)
109100

110-
# net_new.all_layers = list(prev_layer.all_layers)
111-
# net_new.all_params = list(prev_layer.all_params)
112-
# net_new.all_drop = dict(prev_layer.all_drop)
113-
net_new.all_layers.append(net_new.outputs)
114-
return net_new
115101

116-
117-
@deprecated_alias(net='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
118-
def subpixel_conv1d(prev_layer, scale=2, act=tf.identity, name='subpixel_conv1d'):
102+
class SubpixelConv1d(Layer):
119103
"""It is a 1D sub-pixel up-sampling layer.
120104
121105
Calls a TensorFlow function that directly implements this functionality.
@@ -132,11 +116,6 @@ def subpixel_conv1d(prev_layer, scale=2, act=tf.identity, name='subpixel_conv1d'
132116
name : str
133117
A unique layer name.
134118
135-
Returns
136-
-------
137-
:class:`Layer`
138-
A 1D sub-pixel up-sampling layer
139-
140119
Examples
141120
----------
142121
>>> t_signal = tf.placeholder('float32', [10, 100, 4], name='x')
@@ -151,26 +130,24 @@ def subpixel_conv1d(prev_layer, scale=2, act=tf.identity, name='subpixel_conv1d'
151130
152131
"""
153132

154-
def _PS(I, r):
155-
X = tf.transpose(I, [2, 1, 0]) # (r, w, b)
156-
X = tf.batch_to_space_nd(X, [r], [[0, 0]]) # (1, r*w, b)
157-
X = tf.transpose(X, [2, 1, 0])
158-
return X
133+
@deprecated_alias(net='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
134+
def __init__(self, prev_layer, scale=2, act=tf.identity, name='subpixel_conv1d'):
135+
def _PS(I, r):
136+
X = tf.transpose(I, [2, 1, 0]) # (r, w, b)
137+
X = tf.batch_to_space_nd(X, [r], [[0, 0]]) # (1, r*w, b)
138+
X = tf.transpose(X, [2, 1, 0])
139+
return X
159140

160-
logging.info("SubpixelConv1d %s: scale: %d act: %s" % (name, scale, act.__name__))
141+
super(SubpixelConv1d, self).__init__(prev_layer=prev_layer, name=name)
142+
logging.info("SubpixelConv1d %s: scale: %d act: %s" % (name, scale, act.__name__))
161143

162-
inputs = prev_layer.outputs
163-
net_new = Layer(prev_layer=prev_layer, name=name)
164-
with tf.name_scope(name):
165-
net_new.outputs = act(_PS(inputs, r=scale))
144+
self.inputs = prev_layer.outputs
145+
with tf.name_scope(name):
146+
self.outputs = act(_PS(self.inputs, r=scale))
166147

167-
# net_new.all_layers = list(prev_layer.all_layers)
168-
# net_new.all_params = list(prev_layer.all_params)
169-
# net_new.all_drop = dict(prev_layer.all_drop)
170-
net_new.all_layers.append(net_new.outputs)
171-
return net_new
148+
self.all_layers.append(self.outputs)
172149

173150

174151
# Alias
175-
SubpixelConv2d = subpixel_conv2d
176-
SubpixelConv1d = subpixel_conv1d
152+
# SubpixelConv2d = subpixel_conv2d
153+
# SubpixelConv1d = subpixel_conv1d

0 commit comments

Comments (0)