Skip to content

Commit c018283

Browse files
authored
Update TLX training models can be imported into any backend. (#17)
1 parent 57eae1b commit c018283

File tree

6 files changed

+118
-146
lines changed

6 files changed

+118
-146
lines changed

examples/basic_tutorials/tutorial_tensorlayer_model_load.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,14 @@
22
# -*- coding: utf-8 -*-
33

44
import os
5-
# os.environ['TL_BACKEND'] = 'tensorflow'
6-
os.environ['TL_BACKEND'] = 'paddle'
5+
os.environ['TL_BACKEND'] = 'tensorflow'
6+
# os.environ['TL_BACKEND'] = 'paddle'
77
# os.environ['TL_BACKEND'] = 'mindspore'
88
# os.environ['TL_BACKEND'] = 'torch'
99

1010
import tensorlayerx as tlx
1111
from tensorlayerx.nn import Module
1212
from tensorlayerx.nn import Linear, Dropout, Conv2d, MaxPool2d, Flatten
13-
from tensorlayerx.dataflow import Dataset
14-
15-
X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_mnist_dataset(shape=(-1, 784))
16-
17-
18-
class mnistdataset(Dataset):
19-
20-
def __init__(self, data=X_train, label=y_train):
21-
self.data = data
22-
self.label = label
23-
24-
def __getitem__(self, index):
25-
data = self.data[index].astype('float32')
26-
label = self.label[index].astype('int64')
27-
return data, label
28-
29-
def __len__(self):
30-
return len(self.data)
31-
3213

3314
class CustomModel(Module):
3415

@@ -92,26 +73,23 @@ def forward(self, x):
9273
return z
9374

9475

95-
# TODO The MLP model was saved to the standard npz_dict format after training at the TensorFlow backend
96-
# and imported into TensorFlow/PyTorch/PaddlePaddle/MindSpore.
76+
# # TODO The MLP model was saved to the standard npz_dict format after training at the TensorFlow backend
77+
# # and imported into TensorFlow/PyTorch/PaddlePaddle/MindSpore.
9778
# MLP = CustomModel()
98-
# MLP.save_standard_weights('./model.npz')
99-
# # MLP.load_standard_weights('./model.npz', skip=True)
79+
# # MLP.save_standard_weights('./model.npz')
80+
# MLP.load_standard_weights('./model.npz', weights_from='tensorflow', weights_to='mindspore')
10081
# MLP.set_eval()
10182
# inputs = tlx.layers.Input(shape=(10, 784))
102-
# print(MLP(inputs))
83+
# output = MLP(inputs)
84+
# print(output)
10385

10486
# TODO The CNN model was saved to the standard npz_dict format after training at the TensorFlow backend
10587
# and imported into TensorFlow/PyTorch/PaddlePaddle/MindSpore.
10688
cnn = CNN()
107-
# cnn.save_standard_weights('./model.npz')
108-
# TODO Tensorflow trained parameters are imported to the TensorFlow backend.
109-
cnn.load_standard_weights('./model.npz', skip=False, reshape=True)
110-
111-
# TODO Tensorflow backend trained parameters imported to PaddlePaddle/PyTorch/MindSpore to
112-
# set reshape to True parameter to convert convolution shape.
113-
# cnn.load_standard_weights('./model.npz', skip=True, reshape=True)
89+
# cnn.save_standard_weights('./cnn.npz')
90+
cnn.load_standard_weights('./cnn.npz', weights_from='torch', weights_to='tensorflow')
11491
cnn.set_eval()
92+
11593
inputs = tlx.nn.Input(shape=(10, 28, 28, 3), dtype=tlx.float32)
11694
outputs = cnn(inputs)
117-
print(outputs)
95+
# print(outputs)

tensorlayerx/nn/core/common.py

Lines changed: 92 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -268,38 +268,20 @@ def _save_standard_weights_dict(net, file_path):
268268

269269
def encode_list_name(list_name):
270270
# TensorFlow weights format: conv1.weight:0, conv1.bias:0
271-
# Paddle weights format: conv1.weight, conv1.bias
272-
# PyTorch weights format: conv1.W, conv1.W
271+
# Paddle weights format: conv1.weights, conv1.bias
272+
# PyTorch weights format: conv1.weights, conv1.bias
273273
# MindSpore weights format: conv1.weights, conv1.bias
274274
# standard weights format: conv1.weights, conv1.bias
275275

276276
for i in range(len(list_name)):
277277
if tlx.BACKEND == 'tensorflow':
278278
list_name[i] = list_name[i][:-2]
279-
if tlx.BACKEND == 'torch':
280-
if list_name[i][-1] == 'W' and 'conv' not in list_name[i]:
281-
list_name[i] = list_name[i][:-2] + str('/weights')
282-
elif list_name[i][-1] == 'W' and 'conv' in list_name[i]:
283-
list_name[i] = list_name[i][:-2] + str('/filters')
284-
elif list_name[i][-1] == 'b':
285-
list_name[i] = list_name[i][:-2] + str('/biases')
286-
elif list_name[i].split('.')[-1] in ['beta', 'gamma', 'moving_mean', 'moving_var']:
287-
pass
288-
else:
289-
raise NotImplementedError('This weights cannot be converted.')
290279
return list_name
291280

292281

293282
def decode_key_name(key_name):
294283
if tlx.BACKEND == 'tensorflow':
295284
key_name = key_name + str(':0')
296-
if tlx.BACKEND == 'torch':
297-
if key_name.split('/')[-1] in ['weights', 'filters']:
298-
key_name = key_name[:-8] + str('.W')
299-
elif key_name.split('/')[-1] == 'biases':
300-
key_name = key_name[:-7] + str('.b')
301-
else:
302-
raise NotImplementedError('This weights cannot be converted.')
303285
return key_name
304286

305287

@@ -347,11 +329,30 @@ def save_standard_npz_dict(save_list=None, name='model.npz'):
347329
logging.info("[*] Model saved in npz_dict %s" % name)
348330

349331

350-
def _load_standard_weights_dict(net, file_path, skip=False, reshape=False, format='npz_dict'):
351-
if format == 'npz_dict':
352-
load_and_assign_standard_npz_dict(net, file_path, skip, reshape)
353-
elif format == 'npz':
354-
load_and_assign_standard_npz(file_path, net, reshape)
332+
def _load_standard_weights_dict(net, file_path, skip=False, weights_from='tensorflow', weights_to='tensorflow'):
333+
"""
334+
335+
Parameters
336+
----------
337+
file_path : str
338+
Name of the saved file.
339+
skip : boolean
340+
If 'skip' == True, loaded layer whose name is not found in 'layers' will be skipped. If 'skip' is False,
341+
error will be raised when mismatch is found. Default False.
342+
weights_from : string
343+
The weights file is saved by which framework training. It has to be one of tensorflow,mindspore,paddle or torch.
344+
weights_to : string
345+
Which framework the weights file imports.It has to be one of tensorflow,mindspore,paddle or torch.
346+
"""
347+
if weights_from == weights_to:
348+
reshape = False
349+
if weights_from == 'tensorflow' and weights_to != 'tensorflow':
350+
reshape = True
351+
if weights_from != 'tensorflow' and weights_to == 'tensorflow':
352+
reshape = True
353+
if weights_from !='tensorflow' and weights_to != 'tensorflow':
354+
reshape = False
355+
load_and_assign_standard_npz_dict(net, file_path, skip, reshape)
355356

356357

357358
def load_and_assign_standard_npz_dict(net, file_path, skip=False, reshape=False):
@@ -382,101 +383,96 @@ def load_and_assign_standard_npz_dict(net, file_path, skip=False, reshape=False)
382383
else:
383384
if tlx.BACKEND == 'tensorflow':
384385
reshape_weights = weight_reshape(weights[key], reshape)
385-
check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
386+
# check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
386387
utils.assign_tf_variable(net.all_weights[net_weights_name.index(de_key)], reshape_weights)
387388
elif tlx.BACKEND == 'mindspore':
388389
reshape_weights = weight_reshape(weights[key], reshape)
389-
import mindspore as ms
390390
assign_param = ms.Tensor(reshape_weights, dtype=ms.float32)
391-
check_reshape(assign_param, net.all_weights[net_weights_name.index(de_key)])
391+
# check_reshape(assign_param, net.all_weights[net_weights_name.index(de_key)])
392392
utils.assign_ms_variable(net.all_weights[net_weights_name.index(de_key)], assign_param)
393393
elif tlx.BACKEND == 'paddle':
394394
reshape_weights = weight_reshape(weights[key], reshape)
395-
check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
395+
# check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
396396
utils.assign_pd_variable(net.all_weights[net_weights_name.index(de_key)], reshape_weights)
397397
elif tlx.BACKEND == 'torch':
398398
reshape_weights = weight_reshape(weights[key], reshape)
399-
check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
399+
# check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
400400
utils.assign_th_variable(torch_weights_dict[de_key], reshape_weights)
401401
else:
402402
raise NotImplementedError('Not implemented')
403403

404404
logging.info("[*] Model restored from npz_dict %s" % file_path)
405405

406406

407-
def load_and_assign_standard_npz(file_path=None, network=None, reshape=False):
408-
if network is None:
409-
raise ValueError("network is None.")
410-
411-
if not os.path.exists(file_path):
412-
logging.error("file {} doesn't exist.".format(file_path))
413-
return False
414-
else:
415-
weights = utils.load_npz(name=file_path)
416-
ops = []
417-
if tlx.BACKEND == 'tensorflow':
418-
for idx, param in enumerate(weights):
419-
param = weight_reshape(param, reshape)
420-
check_reshape(param, network.all_weights[idx])
421-
ops.append(network.all_weights[idx].assign(param))
422-
423-
elif tlx.BACKEND == 'mindspore':
424-
425-
class Assign_net(Cell):
426-
427-
def __init__(self, y):
428-
super(Assign_net, self).__init__()
429-
self.y = y
430-
431-
def construct(self, x):
432-
Assign()(self.y, x)
433-
434-
for idx, param in enumerate(weights):
435-
assign_param = Tensor(param, dtype=ms.float32)
436-
assign_param = weight_reshape(assign_param, reshape)
437-
check_reshape(assign_param, network.all_weights[idx])
438-
Assign()(network.all_weights[idx], assign_param)
439-
440-
elif tlx.BACKEND == 'paddle':
441-
for idx, param in enumerate(weights):
442-
param = weight_reshape(param, reshape)
443-
check_reshape(param, network.all_weights[idx])
444-
utils.assign_pd_variable(network.all_weights[idx], param)
445-
446-
elif tlx.BACKEND == 'torch':
447-
for idx, param in enumerate(weights):
448-
param = weight_reshape(param, reshape)
449-
check_reshape(param, network.all_weights[idx])
450-
utils.assign_th_variable(network.all_weights[idx], param)
451-
else:
452-
raise NotImplementedError("This backend is not supported")
453-
return ops
454-
455-
logging.info("[*] Load {} SUCCESS!".format(file_path))
456-
457-
458-
def check_reshape(weight, shape_weights):
459-
if len(weight.shape) >= 4 and weight.shape[::-1] == tuple(shape_weights.shape):
460-
if tlx.BACKEND == 'tensorflow':
461-
462-
raise Warning(
463-
'Set reshape to True only when importing weights from MindSpore/PyTorch/PaddlePaddle to TensorFlow.'
464-
)
465-
if tlx.BACKEND == 'torch':
466-
raise Warning('Set reshape to True only when importing weights from TensorFlow to PyTorch.')
467-
if tlx.BACKEND == 'paddle':
468-
raise Warning('Set reshape to True only when importing weights from TensorFlow to PaddlePaddle.')
469-
if tlx.BACKEND == 'mindspore':
470-
raise Warning('Set reshape to True only when importing weights from TensorFlow to MindSpore.')
407+
# def load_and_assign_standard_npz(file_path=None, network=None, reshape=False):
408+
# if network is None:
409+
# raise ValueError("network is None.")
410+
#
411+
# if not os.path.exists(file_path):
412+
# logging.error("file {} doesn't exist.".format(file_path))
413+
# return False
414+
# else:
415+
# weights = utils.load_npz(name=file_path)
416+
# ops = []
417+
# if tlx.BACKEND == 'tensorflow':
418+
# for idx, param in enumerate(weights):
419+
# param = weight_reshape(param, reshape)
420+
# check_reshape(param, network.all_weights[idx])
421+
# ops.append(network.all_weights[idx].assign(param))
422+
#
423+
# elif tlx.BACKEND == 'mindspore':
424+
# for idx, param in enumerate(weights):
425+
# assign_param = Tensor(param, dtype=ms.float32)
426+
# assign_param = weight_reshape(assign_param, reshape)
427+
# check_reshape(assign_param, network.all_weights[idx])
428+
# utils.assign_ms_variable(network.all_weights[idx], assign_param)
429+
#
430+
# elif tlx.BACKEND == 'paddle':
431+
# for idx, param in enumerate(weights):
432+
# param = weight_reshape(param, reshape)
433+
# check_reshape(param, network.all_weights[idx])
434+
# utils.assign_pd_variable(network.all_weights[idx], param)
435+
#
436+
# elif tlx.BACKEND == 'torch':
437+
# for idx, param in enumerate(weights):
438+
# param = weight_reshape(param, reshape)
439+
# check_reshape(param, network.all_weights[idx])
440+
# utils.assign_th_variable(network.all_weights[idx], param)
441+
# else:
442+
# raise NotImplementedError("This backend is not supported")
443+
# return ops
444+
#
445+
# logging.info("[*] Load {} SUCCESS!".format(file_path))
446+
447+
448+
# def check_reshape(weight, shape_weights):
449+
# if len(weight.shape) >= 4 and weight.shape[::-1] == tuple(shape_weights.shape):
450+
# if tlx.BACKEND == 'tensorflow':
451+
#
452+
# raise Warning(
453+
# 'Set reshape to True only when importing weights from MindSpore/PyTorch/PaddlePaddle to TensorFlow.'
454+
# )
455+
# if tlx.BACKEND == 'torch':
456+
# raise Warning('Set reshape to True only when importing weights from TensorFlow to PyTorch.')
457+
# if tlx.BACKEND == 'paddle':
458+
# raise Warning('Set reshape to True only when importing weights from TensorFlow to PaddlePaddle.')
459+
# if tlx.BACKEND == 'mindspore':
460+
# raise Warning('Set reshape to True only when importing weights from TensorFlow to MindSpore.')
471461

472462

473463
def weight_reshape(weight, reshape=False):
474464
# TODO In this case only 2D convolution is considered. 3D convolution tests need to be supplemented.
475465
if reshape:
476466
if len(weight.shape) == 4:
477-
weight = np.moveaxis(weight, (2, 3), (1, 0))
467+
if tlx.BACKEND == 'tensorflow':
468+
weight = np.moveaxis(weight, (1, 0), (2, 3))
469+
else:
470+
weight = np.moveaxis(weight, (2, 3), (1, 0))
478471
if len(weight.shape) == 5:
479-
weight = np.moveaxis(weight, (3, 4), (1, 0))
472+
if tlx.BACKEND == 'tensorflow':
473+
weight = np.moveaxis(weight, (1, 0), (3, 4))
474+
else:
475+
weight = np.moveaxis(weight, (3, 4), (1, 0))
480476
return weight
481477

482478
def tolist(tensors):

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def load_weights(self, file_path, format=None, in_order=True, skip=False):
117117
def save_standard_weights(self, file_path):
118118
_save_standard_weights_dict(self, file_path)
119119

120-
def load_standard_weights(self, file_path, skip=False, reshape=False, format='npz_dict'):
121-
_load_standard_weights_dict(self, file_path, skip, reshape, format)
120+
def load_standard_weights(self, file_path, weights_from, weights_to, skip=False):
121+
_load_standard_weights_dict(self, file_path, skip=skip, weights_from=weights_from, weights_to=weights_to)
122122

123123
@staticmethod
124124
def _compute_shape(tensors):
@@ -158,6 +158,8 @@ def set_eval(self):
158158
"""
159159
self._phase = 'predict'
160160
self.add_flags_recursive(training=False)
161+
for layer in self.cells():
162+
layer.is_train = False
161163
return self
162164

163165
def test(self):
@@ -760,7 +762,6 @@ def update(self, parameters):
760762
"ParameterDict update sequence element "
761763
"#" + str(j) + " should be Iterable; is" + type(p).__name__
762764
)
763-
print(p)
764765
if not len(p) == 2:
765766
raise ValueError(
766767
"ParameterDict update sequence element "

tensorlayerx/nn/core/core_paddle.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ def load_weights(self, file_path, format=None, in_order=True, skip=False):
270270
def save_standard_weights(self, file_path):
271271
_save_standard_weights_dict(self, file_path)
272272

273-
def load_standard_weights(self, file_path, skip=False, reshape=False, format='npz_dict'):
274-
_load_standard_weights_dict(self, file_path, skip, reshape, format)
273+
def load_standard_weights(self, file_path, weights_from, weights_to, skip=False):
274+
_load_standard_weights_dict(self, file_path, skip=skip, weights_from=weights_from, weights_to=weights_to)
275275

276276
def str_to_init(self, initializer):
277277
return str2init(initializer)
@@ -730,7 +730,6 @@ def update(self, parameters):
730730
"ParameterDict update sequence element "
731731
"#" + str(j) + " should be Iterable; is" + type(p).__name__
732732
)
733-
print(p)
734733
if not len(p) == 2:
735734
raise ValueError(
736735
"ParameterDict update sequence element "

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def save_standard_weights(self, file_path):
221221

222222
_save_standard_weights_dict(self, file_path)
223223

224-
def load_standard_weights(self, file_path, skip=False, reshape=False, format='npz_dict'):
224+
def load_standard_weights(self, file_path, weights_from, weights_to, skip=False):
225225
"""
226226
227227
Parameters
@@ -231,14 +231,13 @@ def load_standard_weights(self, file_path, skip=False, reshape=False, format='np
231231
skip : boolean
232232
If 'skip' == True, loaded layer whose name is not found in 'layers' will be skipped. If 'skip' is False,
233233
error will be raised when mismatch is found. Default False.
234-
reshape : boolean
235-
This parameter needs to be set to True when importing parameters from tensorflow training to paddle/mindspore/pytorch,
236-
and similarly when importing parameters from paddle/mindspore/pytorch training to tensorflow.
237-
This parameter does not need to be set between paddle/mindspore/pytorch.
238-
234+
weights_from : string
235+
The weights file is saved by which framework training. It has to be one of tensorflow,mindspore,paddle or torch.
236+
weights_to : string
237+
Which framework the weights file imports.It has to be one of tensorflow,mindspore,paddle or torch.
239238
"""
240239

241-
_load_standard_weights_dict(self, file_path, skip, reshape, format)
240+
_load_standard_weights_dict(self, file_path, skip=skip, weights_from=weights_from, weights_to=weights_to)
242241

243242
def _set_mode_for_layers(self, is_train):
244243
"""Set all layers of this network to a given mode.

0 commit comments

Comments
 (0)