Skip to content

Commit a55daa6

Browse files
zsdonghao authored and Jonathan DEKHTIAR committed
Fix bug for tf.layers when reuse=True (#685)
* [WIP] rearrange readme and example list
* fixing Conv2d reuse bug
* remove print
* fix bug !
* fix bug
* fix all CNN layers use tf.layers
* run yapf
* changelog
* reposition activation parameter
* activation parameter moved in doc
* fix bug with new function
* fix bug with new function
* test TF Layers for reuse=True added
* 1.8.6rc5 release
* Codacy Cleaning
1 parent b5b3b22 commit a55daa6

File tree

6 files changed

+318
-78
lines changed

6 files changed

+318
-78
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ To release a new version, please update the changelog as followed:
9595
- Layer:
9696
- ElementwiseLambdaLayer added to use custom function to connect multiple layer inputs (by @One-sixth in #579)
9797
- AtrousDeConv2dLayer added (by @2wins in #662)
98+
- Fix bugs of using `tf.layers` in CNN (by @zsdonghao in #686)
9899
- Optimizer:
99100
- AMSGrad Optimizer added based on `On the Convergence of Adam and Beyond (ICLR 2018)` (by @DEKHTIARJonathan in #636)
100101
- Setup:
@@ -306,5 +307,5 @@ To release a new version, please update the changelog as followed:
306307
@zsdonghao @luomai @DEKHTIARJonathan
307308

308309
[Unreleased]: https://github.com/tensorlayer/tensorlayer/compare/1.8.5...master
309-
[1.8.6]: https://github.com/tensorlayer/tensorlayer/compare/1.8.6rc4...1.8.5
310+
[1.8.6]: https://github.com/tensorlayer/tensorlayer/compare/1.8.6rc5...1.8.5
310311
[1.8.5]: https://github.com/tensorlayer/tensorlayer/compare/1.8.4...1.8.5

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
![PyPI Stable Version](http://ec2-35-178-47-120.eu-west-2.compute.amazonaws.com/github/release/tensorlayer/tensorlayer.svg?label=PyPI%20-%20Release)
1616
![PyPI RC Version](http://ec2-35-178-47-120.eu-west-2.compute.amazonaws.com/github/release/tensorlayer/tensorlayer/all.svg?label=PyPI%20-%20Pre-Release)
17-
[![Github commits (since latest release)](http://ec2-35-178-47-120.eu-west-2.compute.amazonaws.com/github/commits-since/tensorlayer/tensorlayer/latest.svg)](https://github.com/tensorlayer/tensorlayer/compare/1.8.6rc4...master)
17+
[![Github commits (since latest release)](http://ec2-35-178-47-120.eu-west-2.compute.amazonaws.com/github/commits-since/tensorlayer/tensorlayer/latest.svg)](https://github.com/tensorlayer/tensorlayer/compare/1.8.6rc5...master)
1818
[![PyPI - Python Version](http://ec2-35-178-47-120.eu-west-2.compute.amazonaws.com/pypi/pyversions/tensorlayer.svg)](https://pypi.org/project/tensorlayer/)
1919
[![Supported TF Version](http://ec2-35-178-47-120.eu-west-2.compute.amazonaws.com/badge/tensorflow-1.6.0+-blue.svg)](https://github.com/tensorflow/tensorflow/releases)
2020

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
:target: https://pypi.org/project/tensorlayer/
4141

4242
.. image:: http://ec2-35-178-47-120.eu-west-2.compute.amazonaws.com/github/commits-since/tensorlayer/tensorlayer/latest.svg
43-
:target: https://github.com/tensorlayer/tensorlayer/compare/1.8.6rc4...master
43+
:target: https://github.com/tensorlayer/tensorlayer/compare/1.8.6rc5...master
4444

4545
.. image:: http://ec2-35-178-47-120.eu-west-2.compute.amazonaws.com/pypi/pyversions/tensorlayer.svg
4646
:target: https://pypi.org/project/tensorlayer/

tensorlayer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
global_dict = {}
5757

5858
# Use the following formating: (major, minor, patch, prerelease)
59-
VERSION = (1, 8, 6, "rc4")
59+
VERSION = (1, 8, 6, "rc5")
6060
__shortversion__ = '.'.join(map(str, VERSION[:3]))
6161
__version__ = '.'.join(map(str, VERSION[:3])) + "".join(VERSION[3:])
6262

tensorlayer/layers/convolution.py

Lines changed: 98 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@
3636
]
3737

3838

39+
def _get_collection_trainable(name=''):
40+
variables = []
41+
for p in tf.trainable_variables():
42+
# print(p.name.rpartition('/')[0], self.name)
43+
if p.name.rpartition('/')[0] == name:
44+
variables.append(p)
45+
return variables
46+
47+
3948
class Conv1dLayer(Layer):
4049
"""
4150
The :class:`Conv1dLayer` class is a 1D CNN layer, see `tf.nn.convolution <https://www.tensorflow.org/api_docs/python/tf/nn/convolution>`__.
@@ -382,15 +391,15 @@ class Conv3dLayer(Layer):
382391
----------
383392
prev_layer : :class:`Layer`
384393
Previous layer.
385-
act : activation function
386-
The activation function of this layer.
387394
shape : tuple of int
388395
Shape of the filters: (filter_depth, filter_height, filter_width, in_channels, out_channels).
389396
strides : tuple of int
390397
The sliding window strides for corresponding input dimensions.
391398
Must be in the same order as the shape dimension.
392399
padding : str
393400
The padding algorithm type: "SAME" or "VALID".
401+
act : activation function
402+
The activation function of this layer.
394403
W_init : initializer
395404
The initializer for the weight matrix.
396405
b_init : initializer or None
@@ -414,10 +423,10 @@ class Conv3dLayer(Layer):
414423
def __init__(
415424
self,
416425
prev_layer,
417-
act=None,
418426
shape=(2, 2, 2, 3, 32),
419427
strides=(1, 2, 2, 2, 1),
420428
padding='SAME',
429+
act=None,
421430
W_init=tf.truncated_normal_initializer(stddev=0.02),
422431
b_init=tf.constant_initializer(value=0.0),
423432
W_init_args=None,
@@ -1335,7 +1344,9 @@ def __init__(
13351344

13361345
# _conv1d.dtype = LayersConfig.tf_dtype # unsupport, it will use the same dtype of inputs
13371346
self.outputs = _conv1d(self.inputs)
1338-
new_variables = _conv1d.weights # new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
1347+
# new_variables = _conv1d.weights # new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
1348+
# new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=self.name) #vs.name)
1349+
new_variables = _get_collection_trainable(self.name)
13391350

13401351
self._add_layers(self.outputs)
13411352
self._add_params(new_variables)
@@ -1455,11 +1466,23 @@ def __init__(
14551466
name=name,
14561467
# reuse=None,
14571468
)
1458-
1459-
self.outputs = conv2d(self.inputs)
1469+
self.outputs = conv2d(self.inputs) # must put before ``new_variables``
1470+
# new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=self.name) #vs.name)
1471+
new_variables = _get_collection_trainable(self.name)
1472+
# new_variables = []
1473+
# for p in tf.trainable_variables():
1474+
# # print(p.name.rpartition('/')[0], self.name)
1475+
# if p.name.rpartition('/')[0] == self.name:
1476+
# new_variables.append(p)
1477+
# exit()
1478+
# TF_GRAPHKEYS_VARIABLES TF_GRAPHKEYS_VARIABLES
1479+
# print(self.name, name)
1480+
# print(tf.trainable_variables())#tf.GraphKeys.TRAINABLE_VARIABLES)
1481+
# print(new_variables)
1482+
# print(conv2d.weights)
14601483

14611484
self._add_layers(self.outputs)
1462-
self._add_params(conv2d.weights)
1485+
self._add_params(new_variables) #conv2d.weights)
14631486

14641487

14651488
class DeConv2d(Layer):
@@ -1535,7 +1558,9 @@ def __init__(
15351558
)
15361559

15371560
self.outputs = conv2d_transpose(self.inputs)
1538-
new_variables = conv2d_transpose.weights # new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
1561+
# new_variables = conv2d_transpose.weights # new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
1562+
# new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=self.name) #vs.name)
1563+
new_variables = _get_collection_trainable(self.name)
15391564

15401565
self._add_layers(self.outputs)
15411566
self._add_params(new_variables)
@@ -1597,21 +1622,16 @@ def __init__(
15971622
)
15981623
)
15991624

1600-
with tf.variable_scope(name) as vs:
1601-
1602-
nn = tf.layers.Conv3DTranspose(
1603-
filters=n_filter,
1604-
kernel_size=filter_size,
1605-
strides=strides,
1606-
padding=padding,
1607-
activation=self.act,
1608-
kernel_initializer=W_init,
1609-
bias_initializer=b_init,
1610-
name=None,
1611-
)
1625+
# with tf.variable_scope(name) as vs:
1626+
nn = tf.layers.Conv3DTranspose(
1627+
filters=n_filter, kernel_size=filter_size, strides=strides, padding=padding, activation=self.act,
1628+
kernel_initializer=W_init, bias_initializer=b_init, name=name
1629+
)
16121630

1613-
self.outputs = nn(self.inputs)
1614-
new_variables = nn.weights # tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
1631+
self.outputs = nn(self.inputs)
1632+
# new_variables = nn.weights # tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
1633+
# new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=self.name) #vs.name)
1634+
new_variables = _get_collection_trainable(self.name)
16151635

16161636
self._add_layers(self.outputs)
16171637
self._add_params(new_variables)
@@ -1814,33 +1834,35 @@ def __init__(
18141834
if self.act is not None else 'No Activation'
18151835
)
18161836
)
1817-
with tf.variable_scope(name) as vs:
1818-
nn = tf.layers.SeparableConv1D(
1819-
filters=n_filter,
1820-
kernel_size=filter_size,
1821-
strides=strides,
1822-
padding=padding,
1823-
data_format=data_format,
1824-
dilation_rate=dilation_rate,
1825-
depth_multiplier=depth_multiplier,
1826-
activation=self.act,
1827-
use_bias=(True if b_init is not None else False),
1828-
depthwise_initializer=depthwise_init,
1829-
pointwise_initializer=pointwise_init,
1830-
bias_initializer=b_init,
1831-
# depthwise_regularizer=None,
1832-
# pointwise_regularizer=None,
1833-
# bias_regularizer=None,
1834-
# activity_regularizer=None,
1835-
# depthwise_constraint=None,
1836-
# pointwise_constraint=None,
1837-
# bias_constraint=None,
1838-
trainable=True,
1839-
name=None
1840-
)
1837+
# with tf.variable_scope(name) as vs:
1838+
nn = tf.layers.SeparableConv1D(
1839+
filters=n_filter,
1840+
kernel_size=filter_size,
1841+
strides=strides,
1842+
padding=padding,
1843+
data_format=data_format,
1844+
dilation_rate=dilation_rate,
1845+
depth_multiplier=depth_multiplier,
1846+
activation=self.act,
1847+
use_bias=(True if b_init is not None else False),
1848+
depthwise_initializer=depthwise_init,
1849+
pointwise_initializer=pointwise_init,
1850+
bias_initializer=b_init,
1851+
# depthwise_regularizer=None,
1852+
# pointwise_regularizer=None,
1853+
# bias_regularizer=None,
1854+
# activity_regularizer=None,
1855+
# depthwise_constraint=None,
1856+
# pointwise_constraint=None,
1857+
# bias_constraint=None,
1858+
trainable=True,
1859+
name=name
1860+
)
18411861

1842-
self.outputs = nn(self.inputs)
1843-
new_variables = nn.weights
1862+
self.outputs = nn(self.inputs)
1863+
# new_variables = nn.weights
1864+
# new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=self.name) #vs.name)
1865+
new_variables = _get_collection_trainable(self.name)
18441866

18451867
self._add_layers(self.outputs)
18461868
self._add_params(new_variables)
@@ -1925,33 +1947,35 @@ def __init__(
19251947
)
19261948
)
19271949

1928-
with tf.variable_scope(name) as vs:
1929-
nn = tf.layers.SeparableConv2D(
1930-
filters=n_filter,
1931-
kernel_size=filter_size,
1932-
strides=strides,
1933-
padding=padding,
1934-
data_format=data_format,
1935-
dilation_rate=dilation_rate,
1936-
depth_multiplier=depth_multiplier,
1937-
activation=self.act,
1938-
use_bias=(True if b_init is not None else False),
1939-
depthwise_initializer=depthwise_init,
1940-
pointwise_initializer=pointwise_init,
1941-
bias_initializer=b_init,
1942-
# depthwise_regularizer=None,
1943-
# pointwise_regularizer=None,
1944-
# bias_regularizer=None,
1945-
# activity_regularizer=None,
1946-
# depthwise_constraint=None,
1947-
# pointwise_constraint=None,
1948-
# bias_constraint=None,
1949-
trainable=True,
1950-
name=None
1951-
)
1950+
# with tf.variable_scope(name) as vs:
1951+
nn = tf.layers.SeparableConv2D(
1952+
filters=n_filter,
1953+
kernel_size=filter_size,
1954+
strides=strides,
1955+
padding=padding,
1956+
data_format=data_format,
1957+
dilation_rate=dilation_rate,
1958+
depth_multiplier=depth_multiplier,
1959+
activation=self.act,
1960+
use_bias=(True if b_init is not None else False),
1961+
depthwise_initializer=depthwise_init,
1962+
pointwise_initializer=pointwise_init,
1963+
bias_initializer=b_init,
1964+
# depthwise_regularizer=None,
1965+
# pointwise_regularizer=None,
1966+
# bias_regularizer=None,
1967+
# activity_regularizer=None,
1968+
# depthwise_constraint=None,
1969+
# pointwise_constraint=None,
1970+
# bias_constraint=None,
1971+
trainable=True,
1972+
name=name
1973+
)
19521974

1953-
self.outputs = nn(self.inputs)
1954-
new_variables = nn.weights
1975+
self.outputs = nn(self.inputs)
1976+
# new_variables = nn.weights
1977+
# new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=self.name) #vs.name)
1978+
new_variables = _get_collection_trainable(self.name)
19551979

19561980
self._add_layers(self.outputs)
19571981
self._add_params(new_variables)

0 commit comments

Comments
 (0)