
Commit 5dda663

zsdonghao authored and DEKHTIARJonathan committed

release seperableconv1d / update deconv3d (#526)

* release seperableconv1d / update deconv3d
* add test
1 parent 7d9506d commit 5dda663

File tree: 3 files changed, +145 −122 lines

docs/modules/layers.rst

Lines changed: 5 additions & 0 deletions
@@ -260,6 +260,7 @@ Layer list
   DeConv2d
   DeConv3d
   DepthwiseConv2d
+  SeparableConv1d
   SeparableConv2d
   DeformableConv2d
   GroupConv2d
@@ -502,6 +503,10 @@ APIs may better for you.
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autoclass:: DepthwiseConv2d
 
+1D Depthwise Separable Conv
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autoclass:: SeparableConv1d
+
 2D Depthwise Separable Conv
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autoclass:: SeparableConv2d

tensorlayer/layers/convolution.py

Lines changed: 121 additions & 121 deletions
@@ -24,6 +24,7 @@
     'DeConv2d',
     'DeConv3d',
     'DepthwiseConv2d',
+    'SeparableConv1d',
     'SeparableConv2d',
     'GroupConv2d',
 ]
@@ -1152,115 +1153,6 @@ def __init__(
         self.all_params.append(filters)
 
 
-class _SeparableConv2dLayer(Layer):  # TODO
-    """The :class:`SeparableConv2dLayer` class is 2D convolution with separable filters, see `tf.layers.separable_conv2d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv2d>`__.
-
-    This layer has not been fully tested yet.
-
-    Parameters
-    ----------
-    prev_layer : :class:`Layer`
-        Previous layer with a 4D output tensor in the shape of [batch, height, width, channels].
-    n_filter : int
-        The number of filters.
-    filter_size : tuple of int
-        The filter size (height, width).
-    strides : tuple of int
-        The strides (height, width).
-        This can be a single integer if you want to specify the same value for all spatial dimensions.
-        Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.
-    padding : str
-        The type of padding algorithm: "SAME" or "VALID"
-    data_format : str
-        One of channels_last (Default) or channels_first.
-        The order must match the input dimensions.
-        channels_last corresponds to inputs with shape data_format = 'NWHC' (batch, width, height, channels) while
-        channels_first corresponds to inputs with shape [batch, channels, width, height].
-    dilation_rate : int or tuple of ints
-        The dilation rate of the convolution.
-        It can be a single integer if you want to specify the same value for all spatial dimensions.
-        Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.
-    depth_multiplier : int
-        The number of depthwise convolution output channels for each input channel.
-        The total number of depthwise convolution output channels will be equal to num_filters_in * depth_multiplier.
-    act : activation function
-        The activation function of this layer.
-    use_bias : boolean
-        Whether the layer uses a bias
-    depthwise_initializer : initializer
-        The initializer for the depthwise convolution kernel.
-    pointwise_initializer : initializer
-        The initializer for the pointwise convolution kernel.
-    bias_initializer : initializer
-        The initializer for the bias vector. If None, skip bias.
-    depthwise_regularizer : regularizer
-        Optional regularizer for the depthwise convolution kernel.
-    pointwise_regularizer : regularizer
-        Optional regularizer for the pointwise convolution kernel.
-    bias_regularizer : regularizer
-        Optional regularizer for the bias vector.
-    activity_regularizer : regularizer
-        Regularizer function for the output.
-    name : str
-        A unique layer name.
-
-    """
-
-    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
-    def __init__(
-            self, prev_layer, n_filter, filter_size=5, strides=(1, 1), padding='valid', data_format='channels_last',
-            dilation_rate=(1, 1), depth_multiplier=1, act=tf.identity, use_bias=True, depthwise_initializer=None,
-            pointwise_initializer=None, bias_initializer=tf.zeros_initializer, depthwise_regularizer=None,
-            pointwise_regularizer=None, bias_regularizer=None, activity_regularizer=None, name='atrou2d'
-    ):
-
-        super(_SeparableConv2dLayer, self).__init__(prev_layer=prev_layer, name=name)
-        logging.info(
-            "SeparableConv2dLayer %s: n_filter:%d filter_size:%s strides:%s padding:%s dilation_rate:%s depth_multiplier:%s act:%s"
-            % (
-                name, n_filter, filter_size, str(strides), padding, str(dilation_rate), str(depth_multiplier),
-                act.__name__
-            )
-        )
-
-        self.inputs = prev_layer.outputs
-
-        if tf.__version__ > "0.12.1":
-            raise Exception("This layer only supports for TF 1.0+")
-
-        bias_initializer = bias_initializer()
-
-        with tf.variable_scope(name) as vs:
-            self.outputs = tf.layers.separable_conv2d(
-                self.inputs,
-                filters=n_filter,
-                kernel_size=filter_size,
-                strides=strides,
-                padding=padding,
-                data_format=data_format,
-                dilation_rate=dilation_rate,
-                depth_multiplier=depth_multiplier,
-                activation=act,
-                use_bias=use_bias,
-                depthwise_initializer=depthwise_initializer,
-                pointwise_initializer=pointwise_initializer,
-                bias_initializer=bias_initializer,
-                depthwise_regularizer=depthwise_regularizer,
-                pointwise_regularizer=pointwise_regularizer,
-                bias_regularizer=bias_regularizer,
-                activity_regularizer=activity_regularizer,
-            )
-            # trainable=True, name=None, reuse=None)
-
-            variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
-
-        # self.all_layers = list(layer.all_layers)
-        # self.all_params = list(layer.all_params)
-        # self.all_drop = dict(layer.all_drop)
-        self.all_layers.append(self.outputs)
-        self.all_params.extend(variables)
-
-
 def deconv2d_bilinear_upsampling_initializer(shape):
     """Returns the initializer that can be passed to DeConv2dLayer for initializing the
     weights in correspondence to channel-wise bilinear up-sampling.
@@ -1762,18 +1654,18 @@ def __init__(
         self.inputs = prev_layer.outputs
 
         with tf.variable_scope(name) as vs:
-            self.outputs = tf.contrib.layers.conv3d_transpose(
-                inputs=self.inputs,
-                num_outputs=n_filter,
+            nn = tf.layers.Conv3DTranspose(
+                filters=n_filter,
                 kernel_size=filter_size,
-                stride=strides,
+                strides=strides,
                 padding=padding,
-                activation_fn=act,
-                weights_initializer=W_init,
-                biases_initializer=b_init,
-                scope=name,
+                activation=act,
+                kernel_initializer=W_init,
+                bias_initializer=b_init,
+                name=None,
             )
-            new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+            self.outputs = nn(self.inputs)
+            new_variables = nn.weights  # tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
 
         self.all_layers.append(self.outputs)
         self.all_params.extend(new_variables)
@@ -1908,6 +1800,113 @@ def __init__(
         self.all_params.append(W)
 
 
+class SeparableConv1d(Layer):
+    """The :class:`SeparableConv1d` class is a 1D depthwise separable convolutional layer, see `tf.layers.separable_conv1d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv1d>`__.
+
+    This layer performs a depthwise convolution that acts separately on channels, followed by a pointwise convolution that mixes channels.
+
+    Parameters
+    ------------
+    prev_layer : :class:`Layer`
+        Previous layer.
+    n_filter : int
+        The dimensionality of the output space (i.e. the number of filters in the convolution).
+    filter_size : int
+        Specifying the spatial dimensions of the filters. Can be a single integer to specify the same value for all spatial dimensions.
+    strides : int
+        Specifying the stride of the convolution. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.
+    padding : str
+        One of "valid" or "same" (case-insensitive).
+    data_format : str
+        One of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch, length, channels) while channels_first corresponds to inputs with shape (batch, channels, length).
+    dilation_rate : int
+        Specifying the dilation rate to use for dilated convolution. Can be a single integer to specify the same value for all spatial dimensions. Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.
+    depth_multiplier : int
+        The number of depthwise convolution output channels for each input channel. The total number of depthwise convolution output channels will be equal to num_filters_in * depth_multiplier.
+    depthwise_init : initializer
+        The initializer for the depthwise convolution kernel.
+    pointwise_init : initializer
+        The initializer for the pointwise convolution kernel.
+    b_init : initializer
+        The initializer for the bias vector. If None, ignore bias in the pointwise part only.
+    name : str
+        A unique layer name.
+
+    """
+
+    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
+    def __init__(
+            self,
+            prev_layer,
+            n_filter=100,
+            filter_size=3,
+            strides=1,
+            act=tf.identity,
+            padding='valid',
+            data_format='channels_last',
+            dilation_rate=1,
+            depth_multiplier=1,
+            # activation=None,
+            # use_bias=True,
+            depthwise_init=None,
+            pointwise_init=None,
+            b_init=tf.zeros_initializer(),
+            # depthwise_regularizer=None,
+            # pointwise_regularizer=None,
+            # bias_regularizer=None,
+            # activity_regularizer=None,
+            # depthwise_constraint=None,
+            # pointwise_constraint=None,
+            # W_init=tf.truncated_normal_initializer(stddev=0.1),
+            # b_init=tf.constant_initializer(value=0.0),
+            # W_init_args=None,
+            # b_init_args=None,
+            name='seperable1d',
+    ):
+        # if W_init_args is None:
+        #     W_init_args = {}
+        # if b_init_args is None:
+        #     b_init_args = {}
+
+        super(SeparableConv1d, self).__init__(prev_layer=prev_layer, name=name)
+        logging.info(
+            "SeparableConv1d %s: n_filter:%d filter_size:%s strides:%s depth_multiplier:%d act:%s" %
+            (self.name, n_filter, str(filter_size), str(strides), depth_multiplier, act.__name__)
+        )
+
+        self.inputs = prev_layer.outputs
+
+        with tf.variable_scope(name) as vs:
+            nn = tf.layers.SeparableConv1D(
+                filters=n_filter,
+                kernel_size=filter_size,
+                strides=strides,
+                padding=padding,
+                data_format=data_format,
+                dilation_rate=dilation_rate,
+                depth_multiplier=depth_multiplier,
+                activation=act,
+                use_bias=(True if b_init is not None else False),
+                depthwise_initializer=depthwise_init,
+                pointwise_initializer=pointwise_init,
+                bias_initializer=b_init,
+                # depthwise_regularizer=None,
+                # pointwise_regularizer=None,
+                # bias_regularizer=None,
+                # activity_regularizer=None,
+                # depthwise_constraint=None,
+                # pointwise_constraint=None,
+                # bias_constraint=None,
+                trainable=True,
+                name=None
+            )
+            self.outputs = nn(self.inputs)
+            new_variables = nn.weights
+
+        self.all_layers.append(self.outputs)
+        self.all_params.extend(new_variables)
+
+
 class SeparableConv2d(Layer):
     """The :class:`SeparableConv2d` class is a 2D depthwise separable convolutional layer, see `tf.layers.separable_conv2d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv2d>`__.
 
@@ -1986,8 +1985,7 @@ def __init__(
         self.inputs = prev_layer.outputs
 
         with tf.variable_scope(name) as vs:
-            self.outputs = tf.layers.separable_conv2d(
-                inputs=self.inputs,
+            nn = tf.layers.SeparableConv2D(
                 filters=n_filter,
                 kernel_size=filter_size,
                 strides=strides,
@@ -2010,7 +2008,9 @@ def __init__(
                 trainable=True,
                 name=None
             )
-            new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+            self.outputs = nn(self.inputs)
+            new_variables = nn.weights
+            # new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
 
         self.all_layers.append(self.outputs)
         self.all_params.extend(new_variables)
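
Usage note (editor's sketch, not part of the commit): the new SeparableConv1d follows the same object-style pattern this commit applies to DeConv3d and SeparableConv2d — build a tf.layers object, call it on the input tensor, and register nn.weights as the layer parameters. A minimal example mirroring the test added below; the placeholder shape and layer name are assumptions:

    import tensorflow as tf
    import tensorlayer as tl

    # Assumed input: a batch of 100-step sequences with a single channel.
    x = tf.placeholder(tf.float32, [None, 100, 1])
    net = tl.layers.InputLayer(x, name='input')
    net = tl.layers.SeparableConv1d(
        net, n_filter=32, filter_size=3, strides=1, padding='VALID',
        act=tf.nn.relu, name='seperable1d'
    )
    # 'VALID' padding: output length = 100 - 3 + 1 = 98
    print(net.outputs.get_shape().as_list())  # [None, 98, 32]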

tests/test_layers_convolution.py

Lines changed: 19 additions & 1 deletion
@@ -23,6 +23,14 @@ def setUpClass(cls):
         n2 = tl.layers.Conv1d(nin1, n_filter=32, filter_size=5, stride=2)
         cls.shape_n2 = n2.outputs.get_shape().as_list()
 
+        n2_1 = tl.layers.SeparableConv1d(
+            nin1, n_filter=32, filter_size=3, strides=1, padding='VALID', act=tf.nn.relu, name='seperable1d1'
+        )
+        cls.shape_n2_1 = n2_1.outputs.get_shape().as_list()
+        cls.n2_1_all_layers = n2_1.all_layers
+        cls.n2_1_params = n2_1.all_params
+        cls.n2_1_count_params = n2_1.count_params()
+
         ############
         # 2D #
         ############
@@ -65,7 +73,7 @@ def setUpClass(cls):
         cls.shape_n9 = n9.outputs.get_shape().as_list()
 
         n10 = tl.layers.SeparableConv2d(
-            nin2, n_filter=32, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, name='seperable1'
+            nin2, n_filter=32, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, name='seperable2d1'
         )
         cls.shape_n10 = n10.outputs.get_shape().as_list()
         cls.n10_all_layers = n10.all_layers
@@ -101,6 +109,10 @@ def test_shape_n2(self):
         self.assertEqual(self.shape_n2[1], 50)
         self.assertEqual(self.shape_n2[2], 32)
 
+    def test_shape_n2_1(self):
+        self.assertEqual(self.shape_n2_1[1], 98)
+        self.assertEqual(self.shape_n2_1[2], 32)
+
     def test_shape_n3(self):
         self.assertEqual(self.shape_n3[1], 50)
         self.assertEqual(self.shape_n3[2], 50)
@@ -151,6 +163,9 @@ def test_shape_n12(self):
         self.assertEqual(self.shape_n12[3], 200)
         self.assertEqual(self.shape_n12[4], 32)
 
+    def test_params_n2_1(self):
+        self.assertEqual(len(self.n2_1_params), 3)
+
     def test_params_n4(self):
         self.assertEqual(len(self.n4_params), 2)
 
@@ -161,6 +176,9 @@ def test_params_n10(self):
         self.assertEqual(len(self.n10_params), 3)
         self.assertEqual(self.n10_count_params, 155)
 
+    def test_layers_n2_1(self):
+        self.assertEqual(len(self.n2_1_all_layers), 1)
+
     def test_layers_n10(self):
         self.assertEqual(len(self.n10_all_layers), 1)

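Note on the new test values (editor's sketch, not part of the commit): with filter_size=3, strides=1 and 'VALID' padding over the 100-step nin1 input (inferred from the existing shape_n2 checks), the output length is 100 - 3 + 1 = 98, and the layer exposes exactly three parameter tensors — the depthwise kernel, the pointwise kernel, and the bias (present because b_init is not None).

    # Where test_shape_n2_1 and test_params_n2_1 get their numbers.
    input_len, filter_size, stride = 100, 3, 1
    output_len = (input_len - filter_size) // stride + 1  # 'VALID' padding
    assert output_len == 98                               # matches test_shape_n2_1

    # Three variables regardless of the input channel count:
    params = ['depthwise_kernel', 'pointwise_kernel', 'bias']
    assert len(params) == 3                               # matches test_params_n2_1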