@@ -2207,6 +2207,250 @@ def _PS(X, r, n_out_channel):
22072207 return net_new
22082208
22092209
2210+ ## Spatial Transformer Nets
2211+ def transformer (U , theta , out_size , name = 'SpatialTransformer2dAffine' , ** kwargs ):
2212+ """Spatial Transformer Layer for `2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`_
2213+ , see :class:`SpatialTransformer2dAffineLayer` class.
2214+
2215+ Parameters
2216+ ----------
2217+ U : float
2218+ The output of a convolutional net should have the
2219+ shape [num_batch, height, width, num_channels].
2220+ theta: float
2221+ The output of the localisation network should be [num_batch, 6], value range should be [0, 1] (via tanh).
2222+ out_size: tuple of two ints
2223+ The size of the output of the network (height, width)
2224+
2225+ References
2226+ ----------
2227+ - `Spatial Transformer Networks <https://arxiv.org/abs/1506.02025>`_
2228+ - `TensorFlow/Models <https://github.com/tensorflow/models/tree/master/transformer>`_
2229+
2230+ Notes
2231+ -----
2232+ - To initialize the network to the identity transform init.
2233+ >>> ``theta`` to
2234+ >>> identity = np.array([[1., 0., 0.],
2235+ ... [0., 1., 0.]])
2236+ >>> identity = identity.flatten()
2237+ >>> theta = tf.Variable(initial_value=identity)
2238+ """
2239+
2240+ def _repeat (x , n_repeats ):
2241+ with tf .variable_scope ('_repeat' ):
2242+ rep = tf .transpose (
2243+ tf .expand_dims (tf .ones (shape = tf .stack ([n_repeats , ])), 1 ), [1 , 0 ])
2244+ rep = tf .cast (rep , 'int32' )
2245+ x = tf .matmul (tf .reshape (x , (- 1 , 1 )), rep )
2246+ return tf .reshape (x , [- 1 ])
2247+
2248+ def _interpolate (im , x , y , out_size ):
2249+ with tf .variable_scope ('_interpolate' ):
2250+ # constants
2251+ num_batch = tf .shape (im )[0 ]
2252+ height = tf .shape (im )[1 ]
2253+ width = tf .shape (im )[2 ]
2254+ channels = tf .shape (im )[3 ]
2255+
2256+ x = tf .cast (x , 'float32' )
2257+ y = tf .cast (y , 'float32' )
2258+ height_f = tf .cast (height , 'float32' )
2259+ width_f = tf .cast (width , 'float32' )
2260+ out_height = out_size [0 ]
2261+ out_width = out_size [1 ]
2262+ zero = tf .zeros ([], dtype = 'int32' )
2263+ max_y = tf .cast (tf .shape (im )[1 ] - 1 , 'int32' )
2264+ max_x = tf .cast (tf .shape (im )[2 ] - 1 , 'int32' )
2265+
2266+ # scale indices from [-1, 1] to [0, width/height]
2267+ x = (x + 1.0 )* (width_f ) / 2.0
2268+ y = (y + 1.0 )* (height_f ) / 2.0
2269+
2270+ # do sampling
2271+ x0 = tf .cast (tf .floor (x ), 'int32' )
2272+ x1 = x0 + 1
2273+ y0 = tf .cast (tf .floor (y ), 'int32' )
2274+ y1 = y0 + 1
2275+
2276+ x0 = tf .clip_by_value (x0 , zero , max_x )
2277+ x1 = tf .clip_by_value (x1 , zero , max_x )
2278+ y0 = tf .clip_by_value (y0 , zero , max_y )
2279+ y1 = tf .clip_by_value (y1 , zero , max_y )
2280+ dim2 = width
2281+ dim1 = width * height
2282+ base = _repeat (tf .range (num_batch )* dim1 , out_height * out_width )
2283+ base_y0 = base + y0 * dim2
2284+ base_y1 = base + y1 * dim2
2285+ idx_a = base_y0 + x0
2286+ idx_b = base_y1 + x0
2287+ idx_c = base_y0 + x1
2288+ idx_d = base_y1 + x1
2289+
2290+ # use indices to lookup pixels in the flat image and restore
2291+ # channels dim
2292+ im_flat = tf .reshape (im , tf .stack ([- 1 , channels ]))
2293+ im_flat = tf .cast (im_flat , 'float32' )
2294+ Ia = tf .gather (im_flat , idx_a )
2295+ Ib = tf .gather (im_flat , idx_b )
2296+ Ic = tf .gather (im_flat , idx_c )
2297+ Id = tf .gather (im_flat , idx_d )
2298+
2299+ # and finally calculate interpolated values
2300+ x0_f = tf .cast (x0 , 'float32' )
2301+ x1_f = tf .cast (x1 , 'float32' )
2302+ y0_f = tf .cast (y0 , 'float32' )
2303+ y1_f = tf .cast (y1 , 'float32' )
2304+ wa = tf .expand_dims (((x1_f - x ) * (y1_f - y )), 1 )
2305+ wb = tf .expand_dims (((x1_f - x ) * (y - y0_f )), 1 )
2306+ wc = tf .expand_dims (((x - x0_f ) * (y1_f - y )), 1 )
2307+ wd = tf .expand_dims (((x - x0_f ) * (y - y0_f )), 1 )
2308+ output = tf .add_n ([wa * Ia , wb * Ib , wc * Ic , wd * Id ])
2309+ return output
2310+
2311+ def _meshgrid (height , width ):
2312+ with tf .variable_scope ('_meshgrid' ):
2313+ # This should be equivalent to:
2314+ # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
2315+ # np.linspace(-1, 1, height))
2316+ # ones = np.ones(np.prod(x_t.shape))
2317+ # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
2318+ x_t = tf .matmul (tf .ones (shape = tf .stack ([height , 1 ])),
2319+ tf .transpose (tf .expand_dims (tf .linspace (- 1.0 , 1.0 , width ), 1 ), [1 , 0 ]))
2320+ y_t = tf .matmul (tf .expand_dims (tf .linspace (- 1.0 , 1.0 , height ), 1 ),
2321+ tf .ones (shape = tf .stack ([1 , width ])))
2322+
2323+ x_t_flat = tf .reshape (x_t , (1 , - 1 ))
2324+ y_t_flat = tf .reshape (y_t , (1 , - 1 ))
2325+
2326+ ones = tf .ones_like (x_t_flat )
2327+ grid = tf .concat (axis = 0 , values = [x_t_flat , y_t_flat , ones ])
2328+ return grid
2329+
2330+ def _transform (theta , input_dim , out_size ):
2331+ with tf .variable_scope ('_transform' ):
2332+ num_batch = tf .shape (input_dim )[0 ]
2333+ height = tf .shape (input_dim )[1 ]
2334+ width = tf .shape (input_dim )[2 ]
2335+ num_channels = tf .shape (input_dim )[3 ]
2336+ theta = tf .reshape (theta , (- 1 , 2 , 3 ))
2337+ theta = tf .cast (theta , 'float32' )
2338+
2339+ # grid of (x_t, y_t, 1), eq (1) in ref [1]
2340+ height_f = tf .cast (height , 'float32' )
2341+ width_f = tf .cast (width , 'float32' )
2342+ out_height = out_size [0 ]
2343+ out_width = out_size [1 ]
2344+ grid = _meshgrid (out_height , out_width )
2345+ grid = tf .expand_dims (grid , 0 )
2346+ grid = tf .reshape (grid , [- 1 ])
2347+ grid = tf .tile (grid , tf .stack ([num_batch ]))
2348+ grid = tf .reshape (grid , tf .stack ([num_batch , 3 , - 1 ]))
2349+
2350+ # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
2351+ T_g = tf .matmul (theta , grid )
2352+ x_s = tf .slice (T_g , [0 , 0 , 0 ], [- 1 , 1 , - 1 ])
2353+ y_s = tf .slice (T_g , [0 , 1 , 0 ], [- 1 , 1 , - 1 ])
2354+ x_s_flat = tf .reshape (x_s , [- 1 ])
2355+ y_s_flat = tf .reshape (y_s , [- 1 ])
2356+
2357+ input_transformed = _interpolate (
2358+ input_dim , x_s_flat , y_s_flat ,
2359+ out_size )
2360+
2361+ output = tf .reshape (
2362+ input_transformed , tf .stack ([num_batch , out_height , out_width , num_channels ]))
2363+ return output
2364+
2365+ with tf .variable_scope (name ):
2366+ output = _transform (theta , U , out_size )
2367+ return output
2368+
2369+ def batch_transformer (U , thetas , out_size , name = 'BatchSpatialTransformer2dAffine' ):
2370+ """Batch Spatial Transformer function for `2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`_.
2371+
2372+ Parameters
2373+ ----------
2374+ U : float
2375+ tensor of inputs [batch, height, width, num_channels]
2376+ thetas : float
2377+ a set of transformations for each input [batch, num_transforms, 6]
2378+ out_size : int
2379+ the size of the output [out_height, out_width]
2380+ Returns: float
2381+ Tensor of size [batch * num_transforms, out_height, out_width, num_channels]
2382+ """
2383+ with tf .variable_scope (name ):
2384+ num_batch , num_transforms = map (int , thetas .get_shape ().as_list ()[:2 ])
2385+ indices = [[i ]* num_transforms for i in xrange (num_batch )]
2386+ input_repeated = tf .gather (U , tf .reshape (indices , [- 1 ]))
2387+ return transformer (input_repeated , thetas , out_size )
2388+
2389+ class SpatialTransformer2dAffineLayer (Layer ):
2390+ """The :class:`SpatialTransformer2dAffineLayer` class is a
2391+ `Spatial Transformer Layer <https://arxiv.org/abs/1506.02025>`_ for
2392+ `2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`_.
2393+
2394+ Parameters
2395+ -----------
2396+ layer : a layer class with 4-D Tensor of shape [batch, height, width, channels]
2397+ theta_layer : a layer class for the localisation network.
2398+ This layer will use a :class:`DenseLayer` to make the theta size to [batch, 6], value range to [0, 1] (via tanh).
2399+ out_size : tuple of two ints.
2400+ The size of the output of the network (height, width)
2401+
2402+ References
2403+ -----------
2404+ - `Spatial Transformer Networks <https://arxiv.org/abs/1506.02025>`_
2405+ - `TensorFlow/Models <https://github.com/tensorflow/models/tree/master/transformer>`_
2406+ """
2407+ def __init__ (
2408+ self ,
2409+ layer = None ,
2410+ theta_layer = None ,
2411+ out_size = [40 , 40 ],
2412+ name = 'sapatial_trans_2d_affine' ,
2413+ ):
2414+ Layer .__init__ (self , name = name )
2415+ self .inputs = layer .outputs
2416+ self .theta_layer = theta_layer
2417+ print (" [TL] SpatialTransformer2dAffineLayer %s: in_size:%s out_size:%s" %
2418+ (name , self .inputs .get_shape ().as_list (), out_size ))
2419+
2420+ with tf .variable_scope (name ) as vs :
2421+ ## 1. make the localisation network to [batch, 6] via Flatten and Dense.
2422+ if self .theta_layer .outputs .get_shape ().ndims > 2 :
2423+ self .theta_layer .outputs = flatten_reshape (self .theta_layer .outputs , 'flatten' )
2424+ ## 2. To initialize the network to the identity transform init.
2425+ # 2.1 W
2426+ n_in = int (self .theta_layer .outputs .get_shape ()[- 1 ])
2427+ shape = (n_in , 6 )
2428+ W = tf .get_variable (name = 'W' , initializer = tf .zeros (shape ))
2429+ # 2.2 b
2430+ identity = tf .constant (np .array ([[1. , 0 , 0 ], [0 , 1. , 0 ]]).astype ('float32' ).flatten ())
2431+ b = tf .get_variable (name = 'b' , initializer = identity )
2432+ # 2.3 transformation matrix
2433+ self .theta = tf .nn .tanh (tf .matmul (self .theta_layer .outputs , W ) + b )
2434+ ## 3. Spatial Transformer Sampling
2435+ self .outputs = transformer (self .inputs , self .theta , out_size = out_size )
2436+ ## 4. Get all parameters
2437+ variables = tf .get_collection (TF_GRAPHKEYS_VARIABLES , scope = vs .name )
2438+
2439+ ## fixed
2440+ self .all_layers = list (layer .all_layers )
2441+ self .all_params = list (layer .all_params )
2442+ self .all_drop = dict (layer .all_drop )
2443+
2444+ ## theta_layer
2445+ self .all_layers .extend (theta_layer .all_layers )
2446+ self .all_params .extend (theta_layer .all_params )
2447+ self .all_drop .update (theta_layer .all_drop )
2448+
2449+ ## this layer
2450+ self .all_layers .extend ( [self .outputs ] )
2451+ self .all_params .extend ( variables )
2452+
2453+
22102454# ## Normalization layer
22112455class LocalResponseNormLayer (Layer ):
22122456 """The :class:`LocalResponseNormLayer` class is for Local Response Normalization, see ``tf.nn.local_response_normalization`` or ``tf.nn.lrn`` for new TF version.
0 commit comments