@@ -2167,31 +2167,14 @@ def SubpixelConv2d(net, scale=2, n_out_channel=None, act=tf.identity, name='subp
21672167 if scope_name :
21682168 name = scope_name + '/' + name
21692169
2170- def _phase_shift (I , r ):
2171- if tf .__version__ < '1.0' :
2172- raise Exception ("Only support TF1.0+" )
2173- bsize , a , b , c = I .get_shape ().as_list ()
2174- bsize = tf .shape (I )[0 ] # Handling Dimension(None) type for undefined batch dim
2175- X = tf .reshape (I , (bsize , a , b , r , r ))
2176- X = tf .transpose (X , (0 , 1 , 2 , 4 , 3 )) # bsize, a, b, 1, 1 # tf 0.12
2177- # X = tf.split(1, a, X) # a, [bsize, b, r, r] # tf 0.12
2178- X = tf .split (X , a , 1 )
2179- # X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, b, a*r, r # tf 0.12
2180- X = tf .concat ([tf .squeeze (x , axis = 1 ) for x in X ], 2 )
2181- # X = tf.split(1, b, X) # b, [bsize, a*r, r] # tf 0.12
2182- X = tf .split (X , b , 1 )
2183- # X = tf.concat(2, [tf.squeeze(x, axis=1) for x in X]) # bsize, a*r, b*r # tf 0.12
2184- X = tf .concat ([tf .squeeze (x , axis = 1 ) for x in X ], 2 )
2185- return tf .reshape (X , (bsize , a * r , b * r , 1 ))
2186-
21872170 def _PS (X , r , n_out_channel ):
2188- if n_out_channel > 1 :
2171+ if n_out_channel >= 1 :
21892172 assert int (X .get_shape ()[- 1 ]) == (r ** 2 ) * n_out_channel , _err_log
2190- Xc = tf . split ( X , n_out_channel , 3 )
2191- X = tf .concat ([ _phase_shift ( x , r ) for x in Xc ], 3 )
2192- elif n_out_channel == 1 :
2193- assert int ( X . get_shape ()[ - 1 ]) == ( r ** 2 ), _err_log
2194- X = _phase_shift ( X , r )
2173+ bsize , a , b , c = X . get_shape (). as_list ( )
2174+ bsize = tf .shape ( X )[ 0 ] # Handling Dimension(None) type for undefined batch dim
2175+ Xs = tf . split ( X , r , 3 ) #b*h*w*r*r
2176+ Xr = tf . concat ( Xs , 2 ) #b*h*(r*w)*r
2177+ X = tf . reshape ( Xr ,( b , r * h , r * w , c )) # b*(r*h)*(r*w)*c
21952178 else :
21962179 print (_err_log )
21972180 return X
0 commit comments