4 | 4 | import tensorlayerx as tlx |
5 | 5 | from tensorlayerx import logging |
6 | 6 | from tensorlayerx.nn.core import Module |
7 | | -from tensorlayerx.backend import BACKEND |
8 | 7 |
9 | 8 | __all__ = [ |
10 | 9 | 'DepthwiseConv2d', |
@@ -80,9 +79,9 @@ def __init__( |
80 | 79 | ): |
81 | 80 | super().__init__(name, act=act) |
82 | 81 | self.kernel_size = self.check_param(kernel_size) |
83 | | - self.stride = self._strides = self.check_param(stride) |
| 82 | + self.stride = self.check_param(stride) |
84 | 83 | self.padding = padding |
85 | | - self.dilation = self._dilation = self.check_param(dilation) |
| 84 | + self.dilation = self.check_param(dilation) |
86 | 85 | self.data_format = data_format |
87 | 86 | self.depth_multiplier = depth_multiplier |
88 | 87 | self.W_init = self.str_to_init(W_init) |
@@ -122,34 +121,21 @@ def build(self, inputs_shape): |
122 | 121 | if self.data_format == 'channels_last': |
123 | 122 | if self.in_channels is None: |
124 | 123 | self.in_channels = inputs_shape[-1] |
125 | | - self._strides = [1, self._strides[0], self._strides[1], 1] |
126 | 124 | elif self.data_format == 'channels_first': |
127 | 125 | if self.in_channels is None: |
128 | 126 | self.in_channels = inputs_shape[1] |
129 | | - self._strides = [1, 1, self._strides[0], self._strides[1]] |
130 | 127 | else: |
131 | 128 | raise Exception("data_format should be either channels_last or channels_first") |
132 | 129 |
133 | | - self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, self.depth_multiplier) |
| 130 | + self.filter_depthwise = (self.kernel_size[0], self.kernel_size[1], 1, self.in_channels) |
| 131 | + self.filter_pointwise = (1, 1, self.in_channels, self.in_channels * self.depth_multiplier) |
134 | 132 |
135 | | - # Set the size of kernel as (K1,K2), then the shape is (K,Cin,K1,K2), K must be 1. |
136 | | - if BACKEND == 'mindspore': |
137 | | - self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, 1) |
138 | | - |
139 | | - if BACKEND in ['tensorflow', 'mindspore']: |
140 | | - self.filters = self._get_weights("filters", shape=self.filter_shape, init=self.W_init, transposed=True) |
141 | | - self.point_filter = None |
142 | | - # TODO The number of parameters on multiple backends is not equal. |
143 | | - # TODO It might be better to use deepwise convolution and pointwise convolution for other backends as well. |
144 | | - if BACKEND in ['paddle', 'torch']: |
145 | | - self.filter_depthwise = (self.in_channels, 1, self.kernel_size[0], self.kernel_size[1]) |
146 | | - self.filter_pointwise = (self.in_channels * self.depth_multiplier, self.in_channels, 1, 1) |
147 | | - self.filters = self._get_weights("filters", shape=self.filter_depthwise, init=self.W_init, order=True) |
148 | | - self.point_filter = self._get_weights("point_filter", shape=self.filter_pointwise, init=self.W_init, order=True) |
| 133 | + self.filters = self._get_weights("filters", shape=self.filter_depthwise, init=self.W_init) |
| 134 | + self.point_filter = self._get_weights("point_filter", shape=self.filter_pointwise, init=self.W_init) |
149 | 135 |
150 | 136 | self.depthwise_conv2d = tlx.ops.DepthwiseConv2d( |
151 | | - strides=self._strides, padding=self.padding, data_format=self.data_format, dilations=self._dilation, |
152 | | - ksize=self.kernel_size, channel_multiplier=self.depth_multiplier |
| 137 | + strides=self.stride, padding=self.padding, data_format=self.data_format, dilations=self.dilation, |
| 138 | + ksize=self.kernel_size, channel_multiplier=self.depth_multiplier, in_channels=self.in_channels |
153 | 139 | ) |
154 | 140 |
155 | 141 | self.b_init_flag = False |
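For reference, a minimal sketch of the depthwise/pointwise filter shapes this build() now creates, plus layer-level usage. The concrete numbers (batch 8, 200x200 input, in_channels=32, 3x3 kernel, depth_multiplier=2) and the layer/input names are illustrative assumptions, not part of this commit; the shapes follow the channels_last branch shown above.

    import tensorlayerx as tlx

    in_channels, depth_multiplier = 32, 2
    kernel_size = (3, 3)

    # Filter shapes as assembled in build() for channels_last:
    filter_depthwise = (kernel_size[0], kernel_size[1], 1, in_channels)      # (3, 3, 1, 32)
    filter_pointwise = (1, 1, in_channels, in_channels * depth_multiplier)   # (1, 1, 32, 64)

    # Layer-level usage with the constructor arguments set above; the layer is
    # expected to raise the output channels to in_channels * depth_multiplier.
    ni = tlx.nn.Input([8, 200, 200, in_channels], name='input')
    nn = tlx.nn.DepthwiseConv2d(
        kernel_size=kernel_size, stride=(1, 1), padding='SAME',
        depth_multiplier=depth_multiplier, in_channels=in_channels,
        data_format='channels_last', name='depthwise1'
    )(ni)
    print(nn.shape)  # expected: (8, 200, 200, 64)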