Skip to content

Commit 979c1e0

Browse files
committed
update conv2d for use cudnn on gpu, data format
1 parent 40de1b2 commit 979c1e0

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

tensorlayer/layers.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,8 @@ class Conv2dLayer(Layer):
11981198
The arguments for the weights tf.get_variable().
11991199
b_init_args : dictionary
12001200
The arguments for the biases tf.get_variable().
1201+
use_cudnn_on_gpu : an optional string from: "NHWC", "NCHW". Defaults to "NHWC".
1202+
data_format : an optional bool. Defaults to True.
12011203
name : a string or None
12021204
An optional name to attach to this layer.
12031205
@@ -1245,6 +1247,8 @@ def __init__(
12451247
b_init = tf.constant_initializer(value=0.0),
12461248
W_init_args = {},
12471249
b_init_args = {},
1250+
use_cudnn_on_gpu = None,
1251+
data_format = None,
12481252
name ='cnn_layer',
12491253
):
12501254
Layer.__init__(self, name=name)
@@ -1256,9 +1260,9 @@ def __init__(
12561260
W = tf.get_variable(name='W_conv2d', shape=shape, initializer=W_init, **W_init_args )
12571261
if b_init:
12581262
b = tf.get_variable(name='b_conv2d', shape=(shape[-1]), initializer=b_init, **b_init_args )
1259-
self.outputs = act( tf.nn.conv2d(self.inputs, W, strides=strides, padding=padding) + b ) #1.2
1263+
self.outputs = act( tf.nn.conv2d(self.inputs, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format) + b )
12601264
else:
1261-
self.outputs = act( tf.nn.conv2d(self.inputs, W, strides=strides, padding=padding))
1265+
self.outputs = act( tf.nn.conv2d(self.inputs, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format))
12621266

12631267
self.all_layers = list(layer.all_layers)
12641268
self.all_params = list(layer.all_params)
@@ -1829,7 +1833,7 @@ def Conv1d(net, n_filter=32, filter_size=5, stride=1, act=None,
18291833

18301834
def Conv2d(net, n_filter=32, filter_size=(3, 3), strides=(1, 1), act = None,
18311835
padding='SAME', W_init = tf.truncated_normal_initializer(stddev=0.02), b_init = tf.constant_initializer(value=0.0),
1832-
W_init_args = {}, b_init_args = {}, name ='conv2d',):
1836+
W_init_args = {}, b_init_args = {}, use_cudnn_on_gpu = None, data_format = None,name ='conv2d',):
18331837
"""Wrapper for :class:`Conv2dLayer`, if you don't understand how to use :class:`Conv2dLayer`, this function may be easier.
18341838
18351839
Parameters
@@ -1865,6 +1869,8 @@ def Conv2d(net, n_filter=32, filter_size=(3, 3), strides=(1, 1), act = None,
18651869
W_init_args = W_init_args,
18661870
b_init = b_init,
18671871
b_init_args = b_init_args,
1872+
use_cudnn_on_gpu = use_cudnn_on_gpu,
1873+
data_format = data_format,
18681874
name = name)
18691875
return net
18701876

0 commit comments

Comments
 (0)