1313]
1414
1515
16- @deprecated_alias (net = 'prev_layer' , end_support_version = 1.9 ) # TODO remove this line for the 1.9 release
17- def subpixel_conv2d (prev_layer , scale = 2 , n_out_channel = None , act = tf .identity , name = 'subpixel_conv2d' ):
16+ class SubpixelConv2d (Layer ):
1817 """It is a 2D sub-pixel up-sampling layer, usually be used
1918 for Super-Resolution applications, see `SRGAN <https://github.com/zsdonghao/SRGAN/>`__ for example.
2019
@@ -33,11 +32,6 @@ def subpixel_conv2d(prev_layer, scale=2, n_out_channel=None, act=tf.identity, na
3332 name : str
3433 A unique layer name.
3534
36- Returns
37- -------
38- :class:`Layer`
39- A 2D sub-pixel up-sampling layer
40-
4135 Examples
4236 ---------
4337 >>> # examples here just want to tell you how to set the n_out_channel.
@@ -71,51 +65,41 @@ def subpixel_conv2d(prev_layer, scale=2, n_out_channel=None, act=tf.identity, na
7165
7266 """
7367 # github/Tetrachrome/subpixel https://github.com/Tetrachrome/subpixel/blob/master/subpixel.py
68+ @deprecated_alias (net = 'prev_layer' , end_support_version = 1.9 ) # TODO remove this line for the 1.9 release
69+ def __init__ (self , prev_layer , scale = 2 , n_out_channel = None , act = tf .identity , name = 'subpixel_conv2d' ):
70+ _err_log = "SubpixelConv2d: The number of input channels == (scale x scale) x The number of output channels"
7471
75- _err_log = "SubpixelConv2d: The number of input channels == (scale x scale) x The number of output channels"
76-
77- # scope_name = tf.get_variable_scope().name
78- # if scope_name:
79- # whole_name = scope_name + '/' + name
80- # else:
81- # whole_name = name
82-
83- def _PS (X , r , n_out_channels ):
84- if n_out_channels >= 1 :
85- assert int (X .get_shape ()[- 1 ]) == (r ** 2 ) * n_out_channels , _err_log
72+ super (SubpixelConv2d , self ).__init__ (prev_layer = prev_layer , name = name )
73+ logging .info ("SubpixelConv2d %s: scale: %d n_out_channel: %s act: %s" % (name , scale , n_out_channel , act .__name__ ))
8674
87- # bsize, a, b, c = X.get_shape().as_list()
88- # bsize = tf.shape(X)[0] # Handling Dimension(None) type for undefined batch dim
89- # Xs=tf.split(X,r,3) #b*h*w*r*r
90- # Xr=tf.concat(Xs,2) #b*h*(r*w)*r
91- # X=tf.reshape(Xr,(bsize,r*a,r*b,n_out_channel)) # b*(r*h)*(r*w)*c
75+ def _PS (X , r , n_out_channels ):
76+ if n_out_channels >= 1 :
77+ if int (X .get_shape ()[- 1 ]) != (r ** 2 ) * n_out_channels :
78+ raise Exception (_err_log )
79+ # bsize, a, b, c = X.get_shape().as_list()
80+ # bsize = tf.shape(X)[0] # Handling Dimension(None) type for undefined batch dim
81+ # Xs=tf.split(X,r,3) #b*h*w*r*r
82+ # Xr=tf.concat(Xs,2) #b*h*(r*w)*r
83+ # X=tf.reshape(Xr,(bsize,r*a,r*b,n_out_channel)) # b*(r*h)*(r*w)*c
9284
93- X = tf .depth_to_space (X , r )
94- else :
95- logging .info (_err_log )
96- return X
85+ X = tf .depth_to_space (X , r )
86+ else :
87+ logging .info (_err_log )
88+ return X
9789
98- inputs = prev_layer .outputs
99- if n_out_channel is None :
100- assert int (inputs .get_shape ()[- 1 ]) / (scale ** 2 ) % 1 == 0 , _err_log
101- n_out_channel = int (int (inputs .get_shape ()[- 1 ]) / (scale ** 2 ))
90+ self .inputs = prev_layer .outputs
91+ if n_out_channel is None :
92+ if int (self .inputs .get_shape ()[- 1 ]) / (scale ** 2 ) % 1 != 0 :
93+ raise Exception (_err_log )
94+ n_out_channel = int (int (self .inputs .get_shape ()[- 1 ]) / (scale ** 2 ))
10295
103- logging .info ("SubpixelConv2d %s: scale: %d n_out_channel: %s act: %s" % (name , scale , n_out_channel , act .__name__ ))
96+ with tf .variable_scope (name ):
97+ self .outputs = act (_PS (self .inputs , r = scale , n_out_channels = n_out_channel ))
10498
105- net_new = Layer (prev_layer = prev_layer , name = name )
106- # with tf.name_scope(name):
107- with tf .variable_scope (name ):
108- net_new .outputs = act (_PS (inputs , r = scale , n_out_channels = n_out_channel ))
99+ self .all_layers .append (self .outputs )
109100
110- # net_new.all_layers = list(prev_layer.all_layers)
111- # net_new.all_params = list(prev_layer.all_params)
112- # net_new.all_drop = dict(prev_layer.all_drop)
113- net_new .all_layers .append (net_new .outputs )
114- return net_new
115101
116-
117- @deprecated_alias (net = 'prev_layer' , end_support_version = 1.9 ) # TODO remove this line for the 1.9 release
118- def subpixel_conv1d (prev_layer , scale = 2 , act = tf .identity , name = 'subpixel_conv1d' ):
102+ class SubpixelConv1d (Layer ):
119103 """It is a 1D sub-pixel up-sampling layer.
120104
121105 Calls a TensorFlow function that directly implements this functionality.
@@ -132,11 +116,6 @@ def subpixel_conv1d(prev_layer, scale=2, act=tf.identity, name='subpixel_conv1d'
132116 name : str
133117 A unique layer name.
134118
135- Returns
136- -------
137- :class:`Layer`
138- A 1D sub-pixel up-sampling layer
139-
140119 Examples
141120 ----------
142121 >>> t_signal = tf.placeholder('float32', [10, 100, 4], name='x')
@@ -151,26 +130,24 @@ def subpixel_conv1d(prev_layer, scale=2, act=tf.identity, name='subpixel_conv1d'
151130
152131 """
153132
154- def _PS (I , r ):
155- X = tf .transpose (I , [2 , 1 , 0 ]) # (r, w, b)
156- X = tf .batch_to_space_nd (X , [r ], [[0 , 0 ]]) # (1, r*w, b)
157- X = tf .transpose (X , [2 , 1 , 0 ])
158- return X
133+ @deprecated_alias (net = 'prev_layer' , end_support_version = 1.9 ) # TODO remove this line for the 1.9 release
134+ def __init__ (self , prev_layer , scale = 2 , act = tf .identity , name = 'subpixel_conv1d' ):
135+ def _PS (I , r ):
136+ X = tf .transpose (I , [2 , 1 , 0 ]) # (r, w, b)
137+ X = tf .batch_to_space_nd (X , [r ], [[0 , 0 ]]) # (1, r*w, b)
138+ X = tf .transpose (X , [2 , 1 , 0 ])
139+ return X
159140
160- logging .info ("SubpixelConv1d %s: scale: %d act: %s" % (name , scale , act .__name__ ))
141+ super (SubpixelConv1d , self ).__init__ (prev_layer = prev_layer , name = name )
142+ logging .info ("SubpixelConv1d %s: scale: %d act: %s" % (name , scale , act .__name__ ))
161143
162- inputs = prev_layer .outputs
163- net_new = Layer (prev_layer = prev_layer , name = name )
164- with tf .name_scope (name ):
165- net_new .outputs = act (_PS (inputs , r = scale ))
144+ self .inputs = prev_layer .outputs
145+ with tf .name_scope (name ):
146+ self .outputs = act (_PS (self .inputs , r = scale ))
166147
167- # net_new.all_layers = list(prev_layer.all_layers)
168- # net_new.all_params = list(prev_layer.all_params)
169- # net_new.all_drop = dict(prev_layer.all_drop)
170- net_new .all_layers .append (net_new .outputs )
171- return net_new
148+ self .all_layers .append (self .outputs )
172149
173150
174151# Alias
175- SubpixelConv2d = subpixel_conv2d
176- SubpixelConv1d = subpixel_conv1d
152+ # SubpixelConv2d = subpixel_conv2d
153+ # SubpixelConv1d = subpixel_conv1d
0 commit comments