Skip to content

Commit c493ac2

Browse files
author
Jonathan DEKHTIAR
authored
Ternary conv fix (#659)
* ternary conv fix * _compute_threshold deprecation warning fixed * Ternary Conv Layer added in unittest + Conv Tests Refactored * Changelog Updated * Codacy Error Fix * TL Logging Prefix for Warning/Error/Fatal added
1 parent dbba3ba commit c493ac2

File tree

6 files changed

+208
-155
lines changed

6 files changed

+208
-155
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ To release a new version, please update the changelog as followed:
107107
- `tl.files` refactored into a directory with numerous files (by @DEKHTIARJonathan in #657)
108108
- `tl.files.voc_dataset` fixed because of original Pascal VOC website was down (by @DEKHTIARJonathan in #657)
109109
- extra requirements hidden inside the library added in the project requirements (by @DEKHTIARJonathan in #657)
110-
- requirements files refactored in `requirements/` directory (by @DEKHTIARJonathan in #657)
110+
- requirements files refactored in `requirements/` directory (by @DEKHTIARJonathan in #657)
111+
- Ternary Convolution Layer added in unittest (by @DEKHTIARJonathan in #658)
112+
- Convolution Layers unittests have been cleaned & refactored (by @DEKHTIARJonathan in #658)
111113

112114
### Deprecated
113115

@@ -119,6 +121,8 @@ To release a new version, please update the changelog as followed:
119121
- Issue #565 related to `tl.utils.predict` fixed - `np.hstack` problem in which the results for multiple batches are stacked along `axis=1` (by @2wins in #566)
120122
- Issue #572 with `tl.layers.DeformableConv2d` fixed (by @DEKHTIARJonathan in #573)
121123
- Typo of the document of ElementwiseLambdaLayer (by @zsdonghao in #588)
124+
- Error in `tl.layers.TernaryConv2d` fixed - self.inputs not defined (by @DEKHTIARJonathan in #658)
125+
- Deprecation warning fixed in `tl.layers.binary._compute_threshold()` (by @DEKHTIARJonathan in #658)
122126

123127
### Security
124128

tensorlayer/files/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2087,7 +2087,9 @@ def _dlProgress(count, blockSize, totalSize, pbar=progress_bar):
20872087
pbar.update(count, force=True)
20882088

20892089
filepath = os.path.join(working_directory, filename)
2090-
sys.stdout.write('Downloading %s...\n' % filename)
2090+
2091+
logging.info('Downloading %s...\n' % filename)
2092+
20912093
urlretrieve(url_source + filename, filepath, reporthook=_dlProgress)
20922094

20932095
exists_or_mkdir(working_directory, verbose=False)

tensorlayer/layers/binary.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _compute_threshold(x):
6565
ref: https://github.com/XJTUWYD/TWN
6666
Computing the threshold.
6767
"""
68-
x_sum = tf.reduce_sum(tf.abs(x), reduction_indices=None, keep_dims=False, name=None)
68+
x_sum = tf.reduce_sum(tf.abs(x), reduction_indices=None, keepdims=False, name=None)
6969
threshold = tf.div(x_sum, tf.cast(tf.size(x), tf.float32), name=None)
7070
threshold = tf.multiply(0.7, threshold, name=None)
7171
return threshold
@@ -408,6 +408,7 @@ def __init__(
408408
# self.outputs = act(xnor_gemm(self.inputs, W)) # TODO
409409

410410
self.all_layers.append(self.outputs)
411+
411412
if b_init is not None:
412413
self.all_params.extend([W, b])
413414
else:
@@ -500,31 +501,43 @@ def __init__(
500501
(name, n_filter, str(filter_size), str(strides), padding, act.__name__)
501502
)
502503

504+
if len(strides) != 2:
505+
raise ValueError("len(strides) should be 2.")
506+
507+
if use_gemm:
508+
raise Exception("TODO. The current version use tf.matmul for inferencing.")
509+
503510
if W_init_args is None:
504511
W_init_args = {}
512+
505513
if b_init_args is None:
506514
b_init_args = {}
515+
507516
if act is None:
508517
act = tf.identity
509-
if use_gemm:
510-
raise Exception("TODO. The current version use tf.matmul for inferencing.")
511518

512-
if len(strides) != 2:
513-
raise ValueError("len(strides) should be 2.")
514519
try:
515520
pre_channel = int(prev_layer.outputs.get_shape()[-1])
516521
except Exception: # if pre_channel is ?, it happens when using Spatial Transformer Net
517522
pre_channel = 1
518523
logging.info("[warnings] unknow input channels, set to 1")
524+
519525
shape = (filter_size[0], filter_size[1], pre_channel, n_filter)
520526
strides = (1, strides[0], strides[1], 1)
527+
528+
self.inputs = prev_layer.outputs
529+
521530
with tf.variable_scope(name):
531+
522532
W = tf.get_variable(
523533
name='W_conv2d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args
524534
)
535+
525536
alpha = _compute_alpha(W)
537+
526538
W = _ternary_operation(W)
527539
W = tf.multiply(alpha, W)
540+
528541
if b_init:
529542
b = tf.get_variable(
530543
name='b_conv2d', shape=(shape[-1]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args
@@ -535,6 +548,7 @@ def __init__(
535548
data_format=data_format
536549
) + b
537550
)
551+
538552
else:
539553
self.outputs = act(
540554
tf.nn.conv2d(
@@ -544,6 +558,7 @@ def __init__(
544558
)
545559

546560
self.all_layers.append(self.outputs)
561+
547562
if b_init:
548563
self.all_params.extend([W, b])
549564
else:

tensorlayer/tl_logging.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,24 +86,24 @@ def debug(msg, *args, **kwargs):
8686
_get_logger().debug(msg, *args, **kwargs)
8787

8888

89-
def error(msg, *args, **kwargs):
90-
_get_logger().error(msg, *args, **kwargs)
89+
def info(msg, *args, **kwargs):
90+
_get_logger().info(msg, *args, **kwargs)
9191

9292

93-
def fatal(msg, *args, **kwargs):
94-
_get_logger().fatal(msg, *args, **kwargs)
93+
def error(msg, *args, **kwargs):
94+
_get_logger().error("ERROR: %s" % msg, *args, **kwargs)
9595

9696

97-
def info(msg, *args, **kwargs):
98-
_get_logger().info(msg, *args, **kwargs)
97+
def fatal(msg, *args, **kwargs):
98+
_get_logger().fatal("FATAL: %s" % msg, *args, **kwargs)
9999

100100

101101
def warn(msg, *args, **kwargs):
102-
_get_logger().warning(msg, *args, **kwargs)
102+
warning(msg, *args, **kwargs)
103103

104104

105105
def warning(msg, *args, **kwargs):
106-
_get_logger().warning(msg, *args, **kwargs)
106+
_get_logger().warning("WARNING: %s" % msg, *args, **kwargs)
107107

108108

109109
# Mask to convert integer thread ids to unsigned quantities for logging

tensorlayer/visualize.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,21 @@
22

33
import os
44

5-
try:
6-
import cv2
7-
except ImportError:
8-
import warnings
9-
warnings.simplefilter('default', ImportWarning)
10-
warnings.warn(
11-
message='[TL] Warning: OpenCV Library is not installed.\n' \
12-
'The function `tl.visualize.draw_boxes_and_labels_to_image` will not be able to work.',
13-
category=ImportWarning
14-
)
15-
165
import imageio
176

187
import numpy as np
198

209
from tensorlayer import tl_logging as logging
2110
from tensorlayer import prepro
2211

12+
try:
13+
import cv2
14+
except ImportError:
15+
logging.warn(
16+
'OpenCV Library is not installed.'
17+
'The function `tl.visualize.draw_boxes_and_labels_to_image` will not be able to work.'
18+
)
19+
2320
# Uncomment the following line if you got: _tkinter.TclError: no display name and no $DISPLAY environment variable
2421
# import matplotlib
2522
# matplotlib.use('Agg')

0 commit comments

Comments
 (0)