Skip to content

Commit 559aca4

Browse files
committed
Merge branch 'main' of github.com:tensorlayer/TensorLayerX into main
2 parents 5ff8686 + 3781761 commit 559aca4

File tree

7 files changed

+4
-71
lines changed

7 files changed

+4
-71
lines changed

tensorlayerx/files/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2155,7 +2155,7 @@ def save_npz_dict(save_list=None, name='model.npz'):
21552155
save_list_var = []
21562156
for named, values in save_list:
21572157
save_list_names.append(named)
2158-
save_list_var.append(values.detach().numpy())
2158+
save_list_var.append(values.cpu().detach().numpy())
21592159
else:
21602160
raise NotImplementedError('Not implemented')
21612161
save_var_dict = {save_list_names[idx]: val for idx, val in enumerate(save_list_var)}
@@ -2678,7 +2678,7 @@ def th_variables_to_numpy(variables):
26782678
var_list = [variables]
26792679
else:
26802680
var_list = variables
2681-
results = [v.detach().numpy() for v in var_list]
2681+
results = [v.cpu().detach().numpy() for v in var_list]
26822682
return results
26832683

26842684

tensorlayerx/nn/core/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def save_standard_npz_dict(save_list=None, name='model.npz'):
330330
save_list_var = []
331331
for named, values in save_list:
332332
save_list_names.append(named)
333-
save_list_var.append(values.detach().numpy())
333+
save_list_var.append(values.cpu().detach().numpy())
334334
else:
335335
raise NotImplementedError('Not implemented')
336336

tensorlayerx/nn/layers/convolution/depthwise_conv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,10 @@ def __repr__(self):
120120

121121
def build(self, inputs_shape):
122122
if self.data_format == 'channels_last':
123-
self.data_format = 'NHWC'
124123
if self.in_channels is None:
125124
self.in_channels = inputs_shape[-1]
126125
self._strides = [1, self._strides[0], self._strides[1], 1]
127126
elif self.data_format == 'channels_first':
128-
self.data_format = 'NCHW'
129127
if self.in_channels is None:
130128
self.in_channels = inputs_shape[1]
131129
self._strides = [1, 1, self._strides[0], self._strides[1]]

tensorlayerx/nn/layers/convolution/group_conv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,11 @@ def __repr__(self):
109109

110110
def build(self, inputs_shape):
111111
if self.data_format == 'channels_last':
112-
self.data_format = 'NHWC'
113112
if self.in_channels is None:
114113
self.in_channels = inputs_shape[-1]
115114
self._stride = [1, self._stride[0], self._stride[1], 1]
116115
self._dilation_rate = [1, self._dilation_rate[0], self._dilation_rate[1], 1]
117116
elif self.data_format == 'channels_first':
118-
self.data_format = 'NCHW'
119117
if self.in_channels is None:
120118
self.in_channels = inputs_shape[1]
121119
self._stride = [1, 1, self._stride[0], self._stride[1]]

tensorlayerx/nn/layers/convolution/mask_conv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,11 @@ def __repr__(self):
113113

114114
def build(self, inputs_shape):
115115
if self.data_format == 'channels_last':
116-
self._data_format = 'NDHWC'
117116
if self.in_channels is None:
118117
self.in_channels = inputs_shape[-1]
119118
self._strides = [1, self.stride[0], self.stride[1], self.stride[2], 1]
120119
self._dilation_rate = [1, self.dilation[0], self.dilation[1], self.dilation[2], 1]
121120
elif self.data_format == 'channels_first':
122-
self._data_format = 'NCDHW'
123121
if self.in_channels is None:
124122
self.in_channels = inputs_shape[1]
125123
self._strides = [1, 1, self.stride[0], self.stride[1], self.stride[2]]

tensorlayerx/nn/layers/convolution/separable_conv.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,9 @@ def __repr__(self):
103103

104104
def build(self, inputs_shape):
105105
if self.data_format == 'channels_last':
106-
self.data_format = 'NWC'
107106
if self.in_channels is None:
108107
self.in_channels = inputs_shape[-1]
109108
elif self.data_format == 'channels_first':
110-
self.data_format = 'NCW'
111109
if self.in_channels is None:
112110
self.in_channels = inputs_shape[1]
113111
else:

tensorlayerx/nn/layers/pooling.py

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,6 @@ def __repr__(self):
159159

160160
def build(self, inputs_shape=None):
161161
# https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/nn/pool
162-
if self.data_format == 'channels_last':
163-
self.data_format = 'NWC'
164-
elif self.data_format == 'channels_first':
165-
self.data_format = 'NCW'
166-
else:
167-
raise Exception("unsupported data format")
168162
self._filter_size = [self.kernel_size]
169163
self._stride = [self.stride]
170164
self.max_pool = tlx.ops.MaxPool1d(
@@ -238,12 +232,6 @@ def __repr__(self):
238232

239233
def build(self, inputs_shape=None):
240234
# https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/nn/pool
241-
if self.data_format == 'channels_last':
242-
self.data_format = 'NWC'
243-
elif self.data_format == 'channels_first':
244-
self.data_format = 'NCW'
245-
else:
246-
raise Exception("unsupported data format")
247235
self._filter_size = [self.kernel_size]
248236
self._stride = [self.stride]
249237
self.avg_pool = tlx.ops.AvgPool1d(
@@ -318,10 +306,8 @@ def __repr__(self):
318306

319307
def build(self, inputs_shape=None):
320308
if self.data_format == 'channels_last':
321-
self.data_format = 'NHWC'
322309
self._stride = [1, self.stride[0], self.stride[1], 1]
323310
elif self.data_format == 'channels_first':
324-
self.data_format = 'NCHW'
325311
self._stride = [1, 1, self.stride[0], self.stride[1]]
326312
else:
327313
raise Exception("unsupported data format")
@@ -398,10 +384,8 @@ def __repr__(self):
398384

399385
def build(self, inputs_shape=None):
400386
if self.data_format == 'channels_last':
401-
self.data_format = 'NHWC'
402387
self._stride = [1, self.stride[0], self.stride[1], 1]
403388
elif self.data_format == 'channels_first':
404-
self.data_format = 'NCHW'
405389
self._stride = [1, 1, self.stride[0], self.stride[1]]
406390
else:
407391
raise Exception("unsupported data format")
@@ -480,10 +464,8 @@ def __repr__(self):
480464

481465
def build(self, inputs_shape=None):
482466
if self.data_format == 'channels_last':
483-
self.data_format = 'NDHWC'
484467
self._stride = [1, self.stride[0], self.stride[1], self.stride[2], 1]
485468
elif self.data_format == 'channels_first':
486-
self.data_format = 'NCDHW'
487469
self._stride = [1, 1, self.stride[0], self.stride[1], self.stride[2]]
488470
else:
489471
raise Exception("unsupported data format")
@@ -562,12 +544,6 @@ def __repr__(self):
562544

563545
def build(self, inputs_shape=None):
564546
self._stride = [1, self.stride[0], self.stride[1], self.stride[2], 1]
565-
if self.data_format == 'channels_last':
566-
self.data_format = 'NDHWC'
567-
elif self.data_format == 'channels_first':
568-
self.data_format = 'NCDHW'
569-
else:
570-
raise Exception("unsupported data format")
571547
self.avg_pool3d = tlx.ops.AvgPool3d(
572548
ksize=self.kernel_size, strides=self._stride, padding=self.padding, data_format=self.data_format
573549
)
@@ -1055,12 +1031,6 @@ def __repr__(self):
10551031
return s.format(classname=self.__class__.__name__, **self.__dict__)
10561032

10571033
def build(self, inputs_shape=None):
1058-
if self.data_format == 'channels_last':
1059-
self.data_format = 'NWC'
1060-
elif self.data_format == 'channels_first':
1061-
self.data_format = 'NCW'
1062-
else:
1063-
raise Exception("unsupported data format")
10641034

10651035
self.adaptivemeanpool1d = tlx.ops.AdaptiveMeanPool1D(output_size=self.output_size, data_format=self.data_format)
10661036

@@ -1113,12 +1083,6 @@ def __repr__(self):
11131083
return s.format(classname=self.__class__.__name__, **self.__dict__)
11141084

11151085
def build(self, inputs_shape=None):
1116-
if self.data_format == 'channels_last':
1117-
self.data_format = 'NHWC'
1118-
elif self.data_format == 'channels_first':
1119-
self.data_format = 'NCHW'
1120-
else:
1121-
raise Exception("unsupported data format")
11221086

11231087
if isinstance(self.output_size, int):
11241088
self.output_size = (self.output_size, ) * 2
@@ -1174,12 +1138,6 @@ def __repr__(self):
11741138
return s.format(classname=self.__class__.__name__, **self.__dict__)
11751139

11761140
def build(self, inputs_shape=None):
1177-
if self.data_format == 'channels_last':
1178-
self.data_format = 'NDHWC'
1179-
elif self.data_format == 'channels_first':
1180-
self.data_format = 'NCDHW'
1181-
else:
1182-
raise Exception("unsupported data format")
11831141

11841142
if isinstance(self.output_size, int):
11851143
self.output_size = (self.output_size, ) * 3
@@ -1235,12 +1193,6 @@ def __repr__(self):
12351193
return s.format(classname=self.__class__.__name__, **self.__dict__)
12361194

12371195
def build(self, inputs_shape=None):
1238-
if self.data_format == 'channels_last':
1239-
self.data_format = 'NWC'
1240-
elif self.data_format == 'channels_first':
1241-
self.data_format = 'NCW'
1242-
else:
1243-
raise Exception("unsupported data format")
12441196

12451197
self.adaptivemaxpool1d = tlx.ops.AdaptiveMaxPool1D(output_size=self.output_size, data_format=self.data_format)
12461198

@@ -1293,12 +1245,7 @@ def __repr__(self):
12931245
return s.format(classname=self.__class__.__name__, **self.__dict__)
12941246

12951247
def build(self, inputs_shape=None):
1296-
if self.data_format == 'channels_last':
1297-
self.data_format = 'NHWC'
1298-
elif self.data_format == 'channels_first':
1299-
self.data_format = 'NCHW'
1300-
else:
1301-
raise Exception("unsupported data format")
1248+
13021249
if isinstance(self.output_size, int):
13031250
self.output_size = (self.output_size, ) * 2
13041251

@@ -1353,12 +1300,6 @@ def __repr__(self):
13531300
return s.format(classname=self.__class__.__name__, **self.__dict__)
13541301

13551302
def build(self, inputs_shape=None):
1356-
if self.data_format == 'channels_last':
1357-
self.data_format = 'NDHWC'
1358-
elif self.data_format == 'channels_first':
1359-
self.data_format = 'NCDHW'
1360-
else:
1361-
raise Exception("unsupported data format")
13621303

13631304
if isinstance(self.output_size, int):
13641305
self.output_size = (self.output_size, ) * 3

0 commit comments

Comments
 (0)