Skip to content

Commit 677cbde

Browse files
committed
Merge branch 'Add_load_ckpt' of github.com:Laicheng0830/tensorlayer into Add_load_ckpt
2 parents ae38662 + 9765f0c commit 677cbde

File tree

5 files changed

+52
-5
lines changed

5 files changed

+52
-5
lines changed

docs/modules/activation.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ For more complex activation, TensorFlow API will be required.
3535
sign
3636
hard_tanh
3737
pixel_wise_softmax
38+
mish
3839

3940
Ramp
4041
------
@@ -68,6 +69,10 @@ Pixel-wise softmax
6869
--------------------
6970
.. autofunction:: pixel_wise_softmax
7071

72+
mish
73+
---------
74+
.. autofunction:: mish
75+
7176
Parametric activation
7277
------------------------------
7378
See ``tensorlayer.layers``.

docs/user/contributing.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ For TensorLayer 1.x, it was actively developed and maintained by the following p
4040
- **Hao Dong** (`@zsdonghao <https://github.com/zsdonghao>`_) - `<https://zsdonghao.github.io>`_
4141
- **Jonathan Dekhtiar** (`@DEKHTIARJonathan <https://github.com/DEKHTIARJonathan>`_) - `<https://www.jonathandekhtiar.eu>`_
4242
- **Luo Mai** (`@luomai <https://github.com/luomai>`_) - `<http://www.doc.ic.ac.uk/~lm111/>`_
43+
- **Pan Wang** (`@FerociousPanda <http://github.com/FerociousPanda>`_) - `<http://github.com/FerociousPanda>`_ (UI)
4344
- **Simiao Yu** (`@nebulaV <https://github.com/nebulaV>`_) - `<https://nebulav.github.io>`_
4445

46+
4547
Numerous other contributors can be found in the `Github Contribution Graph <https://github.com/tensorlayer/tensorlayer/graphs/contributors>`_.
4648

4749

tensorlayer/activation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'htanh',
2020
'hard_tanh',
2121
'pixel_wise_softmax',
22+
'mish',
2223
]
2324

2425

@@ -339,6 +340,25 @@ def pixel_wise_softmax(x, name='pixel_wise_softmax'):
339340
return tf.nn.softmax(x)
340341

341342

343+
def mish(x):
    """Mish activation function.

    Computes ``x * tanh(softplus(x))``.

    Reference: `Mish: A Self Regularized Non-Monotonic Neural Activation
    Function (Diganta Misra, 2019) <https://arxiv.org/abs/1908.08681>`__

    Parameters
    ----------
    x : Tensor
        input.

    Returns
    -------
    Tensor
        A ``Tensor`` in the same type as ``x``.

    """
    softplus_x = tf.math.softplus(x)
    return x * tf.math.tanh(softplus_x)
342362
# Alias
343363
lrelu = leaky_relu
344364
lrelu6 = leaky_relu6

tensorlayer/layers/normalization.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ def _bias_add(x, b, data_format):
107107
raise ValueError('invalid data_format: %s' % data_format)
108108

109109

110+
def _compute_shape(tensors):
111+
if isinstance(tensors, list):
112+
shape_mem = [t.get_shape().as_list() for t in tensors]
113+
else:
114+
shape_mem = tensors.get_shape().as_list()
115+
return shape_mem
116+
117+
110118
def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
111119
"""Data Format aware version of tf.nn.batch_normalization."""
112120
if data_format == 'channels_last':
@@ -256,7 +264,8 @@ def _get_param_shape(self, inputs_shape):
256264
return params_shape
257265

258266
def _check_input_shape(self, inputs):
    """Validate that ``inputs`` is at least 2-D.

    The rank is taken from the static shape via ``_compute_shape`` rather
    than ``inputs.ndim``, because ``inputs`` may be a list of tensors (or
    another object without an ``ndim`` attribute).

    Raises
    ------
    ValueError
        If the input has fewer than 2 dimensions.
    """
    inputs_shape = _compute_shape(inputs)
    if len(inputs_shape) <= 1:
        # Bug fix: the message previously formatted `inputs.ndim`, which
        # would raise AttributeError for exactly the inputs this rewrite
        # was made to support, masking the intended ValueError.
        raise ValueError('expected input at least 2D, but got {}D input'.format(len(inputs_shape)))
261270

262271
def build(self, inputs_shape):
@@ -318,7 +327,8 @@ class BatchNorm1d(BatchNorm):
318327
"""
319328

320329
def _check_input_shape(self, inputs):
    """Validate that ``inputs`` is 2-D or 3-D (BatchNorm1d).

    Rank is derived via ``_compute_shape`` so inputs without an ``ndim``
    attribute (e.g. a list of tensors) are handled.

    Raises
    ------
    ValueError
        If the input is not 2-D or 3-D.
    """
    inputs_shape = _compute_shape(inputs)
    ndim = len(inputs_shape)
    if ndim != 2 and ndim != 3:
        # Bug fix: format the computed rank, not `inputs.ndim`, which may
        # not exist for the inputs this check was rewritten to accept.
        raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(ndim))
323333

324334

@@ -341,7 +351,8 @@ class BatchNorm2d(BatchNorm):
341351
"""
342352

343353
def _check_input_shape(self, inputs):
    """Validate that ``inputs`` is 4-D (BatchNorm2d).

    Rank is derived via ``_compute_shape`` so inputs without an ``ndim``
    attribute (e.g. a list of tensors) are handled.

    Raises
    ------
    ValueError
        If the input is not 4-D.
    """
    inputs_shape = _compute_shape(inputs)
    ndim = len(inputs_shape)
    if ndim != 4:
        # Bug fix: format the computed rank, not `inputs.ndim`, which may
        # not exist for the inputs this check was rewritten to accept.
        raise ValueError('expected input to be 4D, but got {}D input'.format(ndim))
346357

347358

@@ -364,7 +375,8 @@ class BatchNorm3d(BatchNorm):
364375
"""
365376

366377
def _check_input_shape(self, inputs):
    """Validate that ``inputs`` is 5-D (BatchNorm3d).

    Rank is derived via ``_compute_shape`` so inputs without an ``ndim``
    attribute (e.g. a list of tensors) are handled.

    Raises
    ------
    ValueError
        If the input is not 5-D.
    """
    inputs_shape = _compute_shape(inputs)
    ndim = len(inputs_shape)
    if ndim != 5:
        # Bug fix: format the computed rank, not `inputs.ndim`, which may
        # not exist for the inputs this check was rewritten to accept.
        raise ValueError('expected input to be 5D, but got {}D input'.format(ndim))
369381

370382

tests/test_activations.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import unittest
66

77
import tensorflow as tf
8-
8+
import numpy as np
99
import tensorlayer as tl
1010
from tests.utils import CustomTestCase
1111

@@ -116,6 +116,14 @@ def test_swish(self):
116116

117117
self.assertAlmostEqual(computed_output.numpy(), good_output, places=5)
118118

119+
def test_mish(self):
    """mish(x) must equal x * tanh(softplus(x)) for a range of scalars."""
    for i in range(-5, 15):
        # softplus(x) = log(1 + exp(x)), computed as logaddexp(0, x) for
        # numerical stability. Note: `np.math` (an alias of the stdlib
        # `math` module) is deprecated and removed in NumPy 2.0, so the
        # reference value is computed with supported NumPy ufuncs.
        good_output = i * np.tanh(np.logaddexp(0.0, float(i)))

        computed_output = tl.act.mish(float(i))

        self.assertAlmostEqual(computed_output.numpy(), good_output, places=5)
126+
119127

120128
if __name__ == '__main__':
121129

0 commit comments

Comments
 (0)