Skip to content

Commit b088cb3

Browse files
authored
Update depthwise parameters compatiable with multiple backends (#12)
1 parent c08d3fe commit b088cb3

File tree

6 files changed

+56
-70
lines changed

6 files changed

+56
-70
lines changed

tensorlayerx/backend/ops/mindspore_nn.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,21 +1000,25 @@ def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_
10001000

10011001
class DepthwiseConv2d(Cell):
10021002

1003-
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1):
1003+
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1, in_channels=None):
10041004
super(DepthwiseConv2d, self).__init__()
1005-
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
1006-
self.ms_stride = strides[1]
1007-
self.ms_dilation = dilations[1]
1008-
self.depthwise_conv2d = P.DepthwiseConv2dNative(
1009-
channel_multiplier=channel_multiplier, kernel_size=ksize, stride=self.ms_stride, dilation=self.ms_dilation
1010-
)
1005+
self.data_format, self.pad_mode = preprocess_2d_format(data_format, padding)
1006+
self.ms_stride = strides
1007+
self.ms_dilation = dilations
1008+
self.padding = 0
1009+
1010+
if isinstance(self.pad_mode, int) or isinstance(self.pad_mode, tuple):
1011+
self.padding = preprocess_padding(self.pad_mode, '2d')
1012+
self.pad_mode = "pad"
1013+
1014+
self.depth_conv = P.Conv2D(stride=self.ms_stride, pad_mode=self.pad_mode, pad=self.padding, kernel_size=ksize,
1015+
data_format=self.data_format, dilation=self.ms_dilation, group=in_channels, out_channel=in_channels)
1016+
self.point_conv = P.Conv2D(pad_mode=self.pad_mode, pad=self.padding, dilation=self.ms_dilation, kernel_size=1,
1017+
out_channel=channel_multiplier*in_channels, data_format=self.data_format)
10111018

10121019
def construct(self, input, filter, point_filter=None):
1013-
if self.data_format == 'NHWC':
1014-
input = nhwc_to_nchw(input)
1015-
outputs = self.depthwise_conv2d(input, filter)
1016-
if self.data_format == 'NHWC':
1017-
outputs = nchw_to_nhwc(outputs)
1020+
outputs = self.depth_conv(input, filter)
1021+
outputs = self.point_conv(outputs, point_filter)
10181022
return outputs
10191023

10201024

tensorlayerx/backend/ops/paddle_nn.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -897,22 +897,15 @@ def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_
897897

898898
class DepthwiseConv2d(object):
899899

900-
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1):
900+
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1, in_channels=None):
901901
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
902-
if self.data_format == 'NHWC':
903-
self._stride = (strides[1], strides[2])
904-
if self.data_format == 'NCHW':
905-
self._stride = (strides[2], strides[3])
902+
self._stride = strides
906903
self.dilations = dilations
907-
self.channel_multiplier = channel_multiplier
904+
self.in_channel = in_channels
908905

909906
def __call__(self, input, filter, point_filter=None):
910-
if self.data_format == 'NHWC':
911-
channel = input.shape[-1]
912-
elif self.data_format == 'NCHW':
913-
channel = input.shape[1]
914907
depthwise_conv = F.conv2d(
915-
input, filter, data_format=self.data_format, groups=channel, dilation=self.dilations, stride=self._stride,
908+
input, filter, data_format=self.data_format, groups=self.in_channel, dilation=self.dilations, stride=self._stride,
916909
padding=self.padding
917910
)
918911
pointwise_conv = F.conv2d(depthwise_conv, point_filter, data_format=self.data_format, padding=self.padding)

tensorlayerx/backend/ops/tensorflow_nn.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,23 +1098,29 @@ def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_
10981098

10991099
class DepthwiseConv2d(object):
11001100

1101-
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1):
1101+
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1, in_channels=None):
11021102
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
11031103
self.strides = strides
11041104
self.dilations = dilations
1105+
self.depthwise_conv = GroupConv2D(strides=self.strides, padding=self.padding, data_format=self.data_format,
1106+
out_channel=None, k_size=None, dilations=self.dilations, groups=in_channels)
11051107

11061108
def __call__(self, input, filter, point_filter=None):
1107-
outputs = tf.nn.depthwise_conv2d(
1108-
input=input,
1109-
filter=filter,
1110-
strides=self.strides,
1111-
padding=self.padding,
1112-
data_format=self.data_format,
1113-
dilations=self.dilations,
1114-
)
1109+
depthwise = self.depthwise_conv(input, filter)
1110+
outputs = tf.nn.conv2d(depthwise, point_filter, strides=1, padding=self.padding,
1111+
data_format=self.data_format, dilations=self.dilations)
1112+
# outputs = tf.nn.depthwise_conv2d(
1113+
# input=input,
1114+
# filter=filter,
1115+
# strides=self.strides,
1116+
# padding=self.padding,
1117+
# data_format=self.data_format,
1118+
# dilations=self.dilations,
1119+
# )
11151120
return outputs
11161121

11171122

1123+
11181124
def depthwise_conv2d(input, filter, strides, padding, data_format=None, dilations=None, name=None):
11191125
"""
11201126
Depthwise 2-D convolution.

tensorlayerx/backend/ops/torch_nn.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,25 +1104,22 @@ def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_
11041104

11051105
class DepthwiseConv2d(object):
11061106

1107-
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1):
1107+
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1, in_channels=None):
11081108
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
1109-
if self.data_format == 'NHWC':
1110-
self._stride = (strides[1], strides[2])
1111-
if self.data_format == 'NCHW':
1112-
self._stride = (strides[2], strides[3])
1113-
self.dilations = dilations
1109+
if self.data_format is 'NHWC':
1110+
self.strides = (1, strides[0], strides[1], 1)
1111+
self.dilations = (1, dilations[0], dilations[1], 1)
1112+
elif self.data_format is 'NCHW':
1113+
self.strides = (1, 1, strides[0], strides[1])
1114+
self.dilations = (1, 1, dilations[0], dilations[1])
1115+
self.depthwise = Conv2D(padding=self.padding, strides=self.strides, data_format=self.data_format,
1116+
dilations=self.dilations, groups=in_channels)
1117+
self.pointwise = Conv2D(strides=(1, 1, 1, 1), padding=self.padding, data_format=self.data_format, dilations=self.dilations, k_size=1)
11141118

11151119
def __call__(self, input, filter, point_filter=None):
1116-
if self.data_format == 'NHWC':
1117-
input = nhwc_to_nchw(input)
1118-
channel = input.shape[1]
1120+
depthwise_conv = self.depthwise(input, filter)
1121+
pointwise_conv = self.pointwise(depthwise_conv, point_filter)
11191122

1120-
depthwise_conv = F.conv2d(input, filter, bias=None, stride=self._stride, padding=self.padding,
1121-
dilation=self.dilations, groups=channel)
1122-
pointwise_conv = F.conv2d(depthwise_conv, point_filter, padding=self.padding)
1123-
1124-
if self.data_format == 'NHWC':
1125-
pointwise_conv = nchw_to_nhwc(pointwise_conv)
11261123
return pointwise_conv
11271124

11281125

tensorlayerx/nn/layers/convolution/depthwise_conv.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import tensorlayerx as tlx
55
from tensorlayerx import logging
66
from tensorlayerx.nn.core import Module
7-
from tensorlayerx.backend import BACKEND
87

98
__all__ = [
109
'DepthwiseConv2d',
@@ -80,9 +79,9 @@ def __init__(
8079
):
8180
super().__init__(name, act=act)
8281
self.kernel_size = self.check_param(kernel_size)
83-
self.stride = self._strides = self.check_param(stride)
82+
self.stride = self.check_param(stride)
8483
self.padding = padding
85-
self.dilation = self._dilation = self.check_param(dilation)
84+
self.dilation = self.check_param(dilation)
8685
self.data_format = data_format
8786
self.depth_multiplier = depth_multiplier
8887
self.W_init = self.str_to_init(W_init)
@@ -122,34 +121,21 @@ def build(self, inputs_shape):
122121
if self.data_format == 'channels_last':
123122
if self.in_channels is None:
124123
self.in_channels = inputs_shape[-1]
125-
self._strides = [1, self._strides[0], self._strides[1], 1]
126124
elif self.data_format == 'channels_first':
127125
if self.in_channels is None:
128126
self.in_channels = inputs_shape[1]
129-
self._strides = [1, 1, self._strides[0], self._strides[1]]
130127
else:
131128
raise Exception("data_format should be either channels_last or channels_first")
132129

133-
self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, self.depth_multiplier)
130+
self.filter_depthwise = (self.kernel_size[0], self.kernel_size[1], 1, self.in_channels)
131+
self.filter_pointwise = (1, 1, self.in_channels, self.in_channels * self.depth_multiplier)
134132

135-
# Set the size of kernel as (K1,K2), then the shape is (K,Cin,K1,K2), K must be 1.
136-
if BACKEND == 'mindspore':
137-
self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, 1)
138-
139-
if BACKEND in ['tensorflow', 'mindspore']:
140-
self.filters = self._get_weights("filters", shape=self.filter_shape, init=self.W_init, transposed=True)
141-
self.point_filter = None
142-
# TODO The number of parameters on multiple backends is not equal.
143-
# TODO It might be better to use deepwise convolution and pointwise convolution for other backends as well.
144-
if BACKEND in ['paddle', 'torch']:
145-
self.filter_depthwise = (self.in_channels, 1, self.kernel_size[0], self.kernel_size[1])
146-
self.filter_pointwise = (self.in_channels * self.depth_multiplier, self.in_channels, 1, 1)
147-
self.filters = self._get_weights("filters", shape=self.filter_depthwise, init=self.W_init, order=True)
148-
self.point_filter = self._get_weights("point_filter", shape=self.filter_pointwise, init=self.W_init, order=True)
133+
self.filters = self._get_weights("filters", shape=self.filter_depthwise, init=self.W_init)
134+
self.point_filter = self._get_weights("point_filter", shape=self.filter_pointwise, init=self.W_init)
149135

150136
self.depthwise_conv2d = tlx.ops.DepthwiseConv2d(
151-
strides=self._strides, padding=self.padding, data_format=self.data_format, dilations=self._dilation,
152-
ksize=self.kernel_size, channel_multiplier=self.depth_multiplier
137+
strides=self.stride, padding=self.padding, data_format=self.data_format, dilations=self.dilation,
138+
ksize=self.kernel_size, channel_multiplier=self.depth_multiplier, in_channels=self.in_channels
153139
)
154140

155141
self.b_init_flag = False

tests/layers/test_layers_convolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_layer_n4(self):
157157
self.assertEqual(tlx.get_tensor_shape(self.n4), [self.batch_size, 100, 100, 32])
158158

159159
def test_layer_n5(self):
160-
self.assertEqual(len(self.dwconv2dlayer.all_weights), 2)
160+
self.assertEqual(len(self.dwconv2dlayer.all_weights), 3)
161161
self.assertEqual(tlx.get_tensor_shape(self.n5), [self.batch_size, 100, 100, 64])
162162

163163
def test_layer_n6(self):

0 commit comments

Comments
 (0)