Skip to content

Commit 3fe989a

Browse files
committed
release Spatial Transformer Net for 2D Affine Transformation
1 parent 9d95661 commit 3fe989a

File tree

2 files changed

+265
-0
lines changed

2 files changed

+265
-0
lines changed

docs/modules/layers.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ Layer list
286286

287287
SubpixelConv2d
288288

289+
SpatialTransformer2dAffineLayer
290+
transformer
291+
batch_transformer
292+
289293
BatchNormLayer
290294
LocalResponseNormLayer
291295

@@ -497,6 +501,23 @@ Super-resolution layer
497501
^^^^^^^^^^^^^^^^^^^^^^^^^^
498502
.. autofunction:: SubpixelConv2d
499503

504+
505+
Spatial Transformer
506+
-----------------------
507+
508+
2D Affine Transformation layer
509+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
510+
.. autoclass:: SpatialTransformer2dAffineLayer
511+
512+
2D Affine Transformation function
513+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
514+
.. autofunction:: transformer
515+
516+
Batch 2D Affine Transformation function
517+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
518+
.. autofunction:: batch_transformer
519+
520+
500521
Pooling layer
501522
----------------
502523

tensorlayer/layers.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2207,6 +2207,250 @@ def _PS(X, r, n_out_channel):
22072207
return net_new
22082208

22092209

2210+
## Spatial Transformer Nets
2211+
def transformer(U, theta, out_size, name='SpatialTransformer2dAffine', **kwargs):
2212+
"""Spatial Transformer Layer for `2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`_
2213+
, see :class:`SpatialTransformer2dAffineLayer` class.
2214+
2215+
Parameters
2216+
----------
2217+
U : float
2218+
The output of a convolutional net should have the
2219+
shape [num_batch, height, width, num_channels].
2220+
theta: float
2221+
The output of the localisation network should be [num_batch, 6], value range should be [0, 1] (via tanh).
2222+
out_size: tuple of two ints
2223+
The size of the output of the network (height, width)
2224+
2225+
References
2226+
----------
2227+
- `Spatial Transformer Networks <https://arxiv.org/abs/1506.02025>`_
2228+
- `TensorFlow/Models <https://github.com/tensorflow/models/tree/master/transformer>`_
2229+
2230+
Notes
2231+
-----
2232+
- To initialize the network to the identity transform init.
2233+
>>> ``theta`` to
2234+
>>> identity = np.array([[1., 0., 0.],
2235+
... [0., 1., 0.]])
2236+
>>> identity = identity.flatten()
2237+
>>> theta = tf.Variable(initial_value=identity)
2238+
"""
2239+
2240+
def _repeat(x, n_repeats):
2241+
with tf.variable_scope('_repeat'):
2242+
rep = tf.transpose(
2243+
tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
2244+
rep = tf.cast(rep, 'int32')
2245+
x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
2246+
return tf.reshape(x, [-1])
2247+
2248+
def _interpolate(im, x, y, out_size):
2249+
with tf.variable_scope('_interpolate'):
2250+
# constants
2251+
num_batch = tf.shape(im)[0]
2252+
height = tf.shape(im)[1]
2253+
width = tf.shape(im)[2]
2254+
channels = tf.shape(im)[3]
2255+
2256+
x = tf.cast(x, 'float32')
2257+
y = tf.cast(y, 'float32')
2258+
height_f = tf.cast(height, 'float32')
2259+
width_f = tf.cast(width, 'float32')
2260+
out_height = out_size[0]
2261+
out_width = out_size[1]
2262+
zero = tf.zeros([], dtype='int32')
2263+
max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
2264+
max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
2265+
2266+
# scale indices from [-1, 1] to [0, width/height]
2267+
x = (x + 1.0)*(width_f) / 2.0
2268+
y = (y + 1.0)*(height_f) / 2.0
2269+
2270+
# do sampling
2271+
x0 = tf.cast(tf.floor(x), 'int32')
2272+
x1 = x0 + 1
2273+
y0 = tf.cast(tf.floor(y), 'int32')
2274+
y1 = y0 + 1
2275+
2276+
x0 = tf.clip_by_value(x0, zero, max_x)
2277+
x1 = tf.clip_by_value(x1, zero, max_x)
2278+
y0 = tf.clip_by_value(y0, zero, max_y)
2279+
y1 = tf.clip_by_value(y1, zero, max_y)
2280+
dim2 = width
2281+
dim1 = width*height
2282+
base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
2283+
base_y0 = base + y0*dim2
2284+
base_y1 = base + y1*dim2
2285+
idx_a = base_y0 + x0
2286+
idx_b = base_y1 + x0
2287+
idx_c = base_y0 + x1
2288+
idx_d = base_y1 + x1
2289+
2290+
# use indices to lookup pixels in the flat image and restore
2291+
# channels dim
2292+
im_flat = tf.reshape(im, tf.stack([-1, channels]))
2293+
im_flat = tf.cast(im_flat, 'float32')
2294+
Ia = tf.gather(im_flat, idx_a)
2295+
Ib = tf.gather(im_flat, idx_b)
2296+
Ic = tf.gather(im_flat, idx_c)
2297+
Id = tf.gather(im_flat, idx_d)
2298+
2299+
# and finally calculate interpolated values
2300+
x0_f = tf.cast(x0, 'float32')
2301+
x1_f = tf.cast(x1, 'float32')
2302+
y0_f = tf.cast(y0, 'float32')
2303+
y1_f = tf.cast(y1, 'float32')
2304+
wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
2305+
wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
2306+
wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
2307+
wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
2308+
output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
2309+
return output
2310+
2311+
def _meshgrid(height, width):
2312+
with tf.variable_scope('_meshgrid'):
2313+
# This should be equivalent to:
2314+
# x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
2315+
# np.linspace(-1, 1, height))
2316+
# ones = np.ones(np.prod(x_t.shape))
2317+
# grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
2318+
x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
2319+
tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
2320+
y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
2321+
tf.ones(shape=tf.stack([1, width])))
2322+
2323+
x_t_flat = tf.reshape(x_t, (1, -1))
2324+
y_t_flat = tf.reshape(y_t, (1, -1))
2325+
2326+
ones = tf.ones_like(x_t_flat)
2327+
grid = tf.concat(axis=0, values=[x_t_flat, y_t_flat, ones])
2328+
return grid
2329+
2330+
def _transform(theta, input_dim, out_size):
2331+
with tf.variable_scope('_transform'):
2332+
num_batch = tf.shape(input_dim)[0]
2333+
height = tf.shape(input_dim)[1]
2334+
width = tf.shape(input_dim)[2]
2335+
num_channels = tf.shape(input_dim)[3]
2336+
theta = tf.reshape(theta, (-1, 2, 3))
2337+
theta = tf.cast(theta, 'float32')
2338+
2339+
# grid of (x_t, y_t, 1), eq (1) in ref [1]
2340+
height_f = tf.cast(height, 'float32')
2341+
width_f = tf.cast(width, 'float32')
2342+
out_height = out_size[0]
2343+
out_width = out_size[1]
2344+
grid = _meshgrid(out_height, out_width)
2345+
grid = tf.expand_dims(grid, 0)
2346+
grid = tf.reshape(grid, [-1])
2347+
grid = tf.tile(grid, tf.stack([num_batch]))
2348+
grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
2349+
2350+
# Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
2351+
T_g = tf.matmul(theta, grid)
2352+
x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
2353+
y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
2354+
x_s_flat = tf.reshape(x_s, [-1])
2355+
y_s_flat = tf.reshape(y_s, [-1])
2356+
2357+
input_transformed = _interpolate(
2358+
input_dim, x_s_flat, y_s_flat,
2359+
out_size)
2360+
2361+
output = tf.reshape(
2362+
input_transformed, tf.stack([num_batch, out_height, out_width, num_channels]))
2363+
return output
2364+
2365+
with tf.variable_scope(name):
2366+
output = _transform(theta, U, out_size)
2367+
return output
2368+
2369+
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer2dAffine'):
2370+
"""Batch Spatial Transformer function for `2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`_.
2371+
2372+
Parameters
2373+
----------
2374+
U : float
2375+
tensor of inputs [batch, height, width, num_channels]
2376+
thetas : float
2377+
a set of transformations for each input [batch, num_transforms, 6]
2378+
out_size : int
2379+
the size of the output [out_height, out_width]
2380+
Returns: float
2381+
Tensor of size [batch * num_transforms, out_height, out_width, num_channels]
2382+
"""
2383+
with tf.variable_scope(name):
2384+
num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
2385+
indices = [[i]*num_transforms for i in xrange(num_batch)]
2386+
input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
2387+
return transformer(input_repeated, thetas, out_size)
2388+
2389+
class SpatialTransformer2dAffineLayer(Layer):
2390+
"""The :class:`SpatialTransformer2dAffineLayer` class is a
2391+
`Spatial Transformer Layer <https://arxiv.org/abs/1506.02025>`_ for
2392+
`2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`_.
2393+
2394+
Parameters
2395+
-----------
2396+
layer : a layer class with 4-D Tensor of shape [batch, height, width, channels]
2397+
theta_layer : a layer class for the localisation network.
2398+
This layer will use a :class:`DenseLayer` to make the theta size to [batch, 6], value range to [0, 1] (via tanh).
2399+
out_size : tuple of two ints.
2400+
The size of the output of the network (height, width)
2401+
2402+
References
2403+
-----------
2404+
- `Spatial Transformer Networks <https://arxiv.org/abs/1506.02025>`_
2405+
- `TensorFlow/Models <https://github.com/tensorflow/models/tree/master/transformer>`_
2406+
"""
2407+
def __init__(
2408+
self,
2409+
layer = None,
2410+
theta_layer = None,
2411+
out_size = [40, 40],
2412+
name ='sapatial_trans_2d_affine',
2413+
):
2414+
Layer.__init__(self, name=name)
2415+
self.inputs = layer.outputs
2416+
self.theta_layer = theta_layer
2417+
print(" [TL] SpatialTransformer2dAffineLayer %s: in_size:%s out_size:%s" %
2418+
(name, self.inputs.get_shape().as_list(), out_size))
2419+
2420+
with tf.variable_scope(name) as vs:
2421+
## 1. make the localisation network to [batch, 6] via Flatten and Dense.
2422+
if self.theta_layer.outputs.get_shape().ndims > 2:
2423+
self.theta_layer.outputs = flatten_reshape(self.theta_layer.outputs, 'flatten')
2424+
## 2. To initialize the network to the identity transform init.
2425+
# 2.1 W
2426+
n_in = int(self.theta_layer.outputs.get_shape()[-1])
2427+
shape = (n_in, 6)
2428+
W = tf.get_variable(name='W', initializer=tf.zeros(shape))
2429+
# 2.2 b
2430+
identity = tf.constant(np.array([[1., 0, 0], [0, 1., 0]]).astype('float32').flatten())
2431+
b = tf.get_variable(name='b', initializer=identity)
2432+
# 2.3 transformation matrix
2433+
self.theta = tf.nn.tanh(tf.matmul(self.theta_layer.outputs, W) + b)
2434+
## 3. Spatial Transformer Sampling
2435+
self.outputs = transformer(self.inputs, self.theta, out_size=out_size)
2436+
## 4. Get all parameters
2437+
variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
2438+
2439+
## fixed
2440+
self.all_layers = list(layer.all_layers)
2441+
self.all_params = list(layer.all_params)
2442+
self.all_drop = dict(layer.all_drop)
2443+
2444+
## theta_layer
2445+
self.all_layers.extend(theta_layer.all_layers)
2446+
self.all_params.extend(theta_layer.all_params)
2447+
self.all_drop.update(theta_layer.all_drop)
2448+
2449+
## this layer
2450+
self.all_layers.extend( [self.outputs] )
2451+
self.all_params.extend( variables )
2452+
2453+
22102454
# ## Normalization layer
22112455
class LocalResponseNormLayer(Layer):
22122456
"""The :class:`LocalResponseNormLayer` class is for Local Response Normalization, see ``tf.nn.local_response_normalization`` or ``tf.nn.lrn`` for new TF version.

0 commit comments

Comments
 (0)