@@ -268,38 +268,20 @@ def _save_standard_weights_dict(net, file_path):
268268
269269def encode_list_name (list_name ):
270270 # TensorFlow weights format: conv1.weight:0, conv1.bias:0
271- # Paddle weights format: conv1.weight , conv1.bias
272- # PyTorch weights format: conv1.W , conv1.W
271+ # Paddle weights format: conv1.weights , conv1.bias
272+ # PyTorch weights format: conv1.weights , conv1.bias
273273 # MindSpore weights format: conv1.weights, conv1.bias
274274 # standard weights format: conv1.weights, conv1.bias
275275
276276 for i in range (len (list_name )):
277277 if tlx .BACKEND == 'tensorflow' :
278278 list_name [i ] = list_name [i ][:- 2 ]
279- if tlx .BACKEND == 'torch' :
280- if list_name [i ][- 1 ] == 'W' and 'conv' not in list_name [i ]:
281- list_name [i ] = list_name [i ][:- 2 ] + str ('/weights' )
282- elif list_name [i ][- 1 ] == 'W' and 'conv' in list_name [i ]:
283- list_name [i ] = list_name [i ][:- 2 ] + str ('/filters' )
284- elif list_name [i ][- 1 ] == 'b' :
285- list_name [i ] = list_name [i ][:- 2 ] + str ('/biases' )
286- elif list_name [i ].split ('.' )[- 1 ] in ['beta' , 'gamma' , 'moving_mean' , 'moving_var' ]:
287- pass
288- else :
289- raise NotImplementedError ('This weights cannot be converted.' )
290279 return list_name
291280
292281
293282def decode_key_name (key_name ):
294283 if tlx .BACKEND == 'tensorflow' :
295284 key_name = key_name + str (':0' )
296- if tlx .BACKEND == 'torch' :
297- if key_name .split ('/' )[- 1 ] in ['weights' , 'filters' ]:
298- key_name = key_name [:- 8 ] + str ('.W' )
299- elif key_name .split ('/' )[- 1 ] == 'biases' :
300- key_name = key_name [:- 7 ] + str ('.b' )
301- else :
302- raise NotImplementedError ('This weights cannot be converted.' )
303285 return key_name
304286
305287
@@ -347,11 +329,30 @@ def save_standard_npz_dict(save_list=None, name='model.npz'):
347329 logging .info ("[*] Model saved in npz_dict %s" % name )
348330
349331
350- def _load_standard_weights_dict (net , file_path , skip = False , reshape = False , format = 'npz_dict' ):
351- if format == 'npz_dict' :
352- load_and_assign_standard_npz_dict (net , file_path , skip , reshape )
353- elif format == 'npz' :
354- load_and_assign_standard_npz (file_path , net , reshape )
332+ def _load_standard_weights_dict (net , file_path , skip = False , weights_from = 'tensorflow' , weights_to = 'tensorflow' ):
333+ """
334+
335+ Parameters
336+ ----------
337+ file_path : str
338+ Name of the saved file.
339+ skip : boolean
340+ If 'skip' == True, loaded layer whose name is not found in 'layers' will be skipped. If 'skip' is False,
341+ error will be raised when mismatch is found. Default False.
342+ weights_from : string
343+ The weights file is saved by which framework training. It has to be one of tensorflow,mindspore,paddle or torch.
344+ weights_to : string
345+ Which framework the weights file imports.It has to be one of tensorflow,mindspore,paddle or torch.
346+ """
347+ if weights_from == weights_to :
348+ reshape = False
349+ if weights_from == 'tensorflow' and weights_to != 'tensorflow' :
350+ reshape = True
351+ if weights_from != 'tensorflow' and weights_to == 'tensorflow' :
352+ reshape = True
353+ if weights_from != 'tensorflow' and weights_to != 'tensorflow' :
354+ reshape = False
355+ load_and_assign_standard_npz_dict (net , file_path , skip , reshape )
355356
356357
357358def load_and_assign_standard_npz_dict (net , file_path , skip = False , reshape = False ):
@@ -382,101 +383,96 @@ def load_and_assign_standard_npz_dict(net, file_path, skip=False, reshape=False)
382383 else :
383384 if tlx .BACKEND == 'tensorflow' :
384385 reshape_weights = weight_reshape (weights [key ], reshape )
385- check_reshape (reshape_weights , net .all_weights [net_weights_name .index (de_key )])
386+ # check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
386387 utils .assign_tf_variable (net .all_weights [net_weights_name .index (de_key )], reshape_weights )
387388 elif tlx .BACKEND == 'mindspore' :
388389 reshape_weights = weight_reshape (weights [key ], reshape )
389- import mindspore as ms
390390 assign_param = ms .Tensor (reshape_weights , dtype = ms .float32 )
391- check_reshape (assign_param , net .all_weights [net_weights_name .index (de_key )])
391+ # check_reshape(assign_param, net.all_weights[net_weights_name.index(de_key)])
392392 utils .assign_ms_variable (net .all_weights [net_weights_name .index (de_key )], assign_param )
393393 elif tlx .BACKEND == 'paddle' :
394394 reshape_weights = weight_reshape (weights [key ], reshape )
395- check_reshape (reshape_weights , net .all_weights [net_weights_name .index (de_key )])
395+ # check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
396396 utils .assign_pd_variable (net .all_weights [net_weights_name .index (de_key )], reshape_weights )
397397 elif tlx .BACKEND == 'torch' :
398398 reshape_weights = weight_reshape (weights [key ], reshape )
399- check_reshape (reshape_weights , net .all_weights [net_weights_name .index (de_key )])
399+ # check_reshape(reshape_weights, net.all_weights[net_weights_name.index(de_key)])
400400 utils .assign_th_variable (torch_weights_dict [de_key ], reshape_weights )
401401 else :
402402 raise NotImplementedError ('Not implemented' )
403403
404404 logging .info ("[*] Model restored from npz_dict %s" % file_path )
405405
406406
407- def load_and_assign_standard_npz (file_path = None , network = None , reshape = False ):
408- if network is None :
409- raise ValueError ("network is None." )
410-
411- if not os .path .exists (file_path ):
412- logging .error ("file {} doesn't exist." .format (file_path ))
413- return False
414- else :
415- weights = utils .load_npz (name = file_path )
416- ops = []
417- if tlx .BACKEND == 'tensorflow' :
418- for idx , param in enumerate (weights ):
419- param = weight_reshape (param , reshape )
420- check_reshape (param , network .all_weights [idx ])
421- ops .append (network .all_weights [idx ].assign (param ))
422-
423- elif tlx .BACKEND == 'mindspore' :
424-
425- class Assign_net (Cell ):
426-
427- def __init__ (self , y ):
428- super (Assign_net , self ).__init__ ()
429- self .y = y
430-
431- def construct (self , x ):
432- Assign ()(self .y , x )
433-
434- for idx , param in enumerate (weights ):
435- assign_param = Tensor (param , dtype = ms .float32 )
436- assign_param = weight_reshape (assign_param , reshape )
437- check_reshape (assign_param , network .all_weights [idx ])
438- Assign ()(network .all_weights [idx ], assign_param )
439-
440- elif tlx .BACKEND == 'paddle' :
441- for idx , param in enumerate (weights ):
442- param = weight_reshape (param , reshape )
443- check_reshape (param , network .all_weights [idx ])
444- utils .assign_pd_variable (network .all_weights [idx ], param )
445-
446- elif tlx .BACKEND == 'torch' :
447- for idx , param in enumerate (weights ):
448- param = weight_reshape (param , reshape )
449- check_reshape (param , network .all_weights [idx ])
450- utils .assign_th_variable (network .all_weights [idx ], param )
451- else :
452- raise NotImplementedError ("This backend is not supported" )
453- return ops
454-
455- logging .info ("[*] Load {} SUCCESS!" .format (file_path ))
456-
457-
458- def check_reshape (weight , shape_weights ):
459- if len (weight .shape ) >= 4 and weight .shape [::- 1 ] == tuple (shape_weights .shape ):
460- if tlx .BACKEND == 'tensorflow' :
461-
462- raise Warning (
463- 'Set reshape to True only when importing weights from MindSpore/PyTorch/PaddlePaddle to TensorFlow.'
464- )
465- if tlx .BACKEND == 'torch' :
466- raise Warning ('Set reshape to True only when importing weights from TensorFlow to PyTorch.' )
467- if tlx .BACKEND == 'paddle' :
468- raise Warning ('Set reshape to True only when importing weights from TensorFlow to PaddlePaddle.' )
469- if tlx .BACKEND == 'mindspore' :
470- raise Warning ('Set reshape to True only when importing weights from TensorFlow to MindSpore.' )
407+ # def load_and_assign_standard_npz(file_path=None, network=None, reshape=False):
408+ # if network is None:
409+ # raise ValueError("network is None.")
410+ #
411+ # if not os.path.exists(file_path):
412+ # logging.error("file {} doesn't exist.".format(file_path))
413+ # return False
414+ # else:
415+ # weights = utils.load_npz(name=file_path)
416+ # ops = []
417+ # if tlx.BACKEND == 'tensorflow':
418+ # for idx, param in enumerate(weights):
419+ # param = weight_reshape(param, reshape)
420+ # check_reshape(param, network.all_weights[idx])
421+ # ops.append(network.all_weights[idx].assign(param))
422+ #
423+ # elif tlx.BACKEND == 'mindspore':
424+ # for idx, param in enumerate(weights):
425+ # assign_param = Tensor(param, dtype=ms.float32)
426+ # assign_param = weight_reshape(assign_param, reshape)
427+ # check_reshape(assign_param, network.all_weights[idx])
428+ # utils.assign_ms_variable(network.all_weights[idx], assign_param)
429+ #
430+ # elif tlx.BACKEND == 'paddle':
431+ # for idx, param in enumerate(weights):
432+ # param = weight_reshape(param, reshape)
433+ # check_reshape(param, network.all_weights[idx])
434+ # utils.assign_pd_variable(network.all_weights[idx], param)
435+ #
436+ # elif tlx.BACKEND == 'torch':
437+ # for idx, param in enumerate(weights):
438+ # param = weight_reshape(param, reshape)
439+ # check_reshape(param, network.all_weights[idx])
440+ # utils.assign_th_variable(network.all_weights[idx], param)
441+ # else:
442+ # raise NotImplementedError("This backend is not supported")
443+ # return ops
444+ #
445+ # logging.info("[*] Load {} SUCCESS!".format(file_path))
446+
447+
448+ # def check_reshape(weight, shape_weights):
449+ # if len(weight.shape) >= 4 and weight.shape[::-1] == tuple(shape_weights.shape):
450+ # if tlx.BACKEND == 'tensorflow':
451+ #
452+ # raise Warning(
453+ # 'Set reshape to True only when importing weights from MindSpore/PyTorch/PaddlePaddle to TensorFlow.'
454+ # )
455+ # if tlx.BACKEND == 'torch':
456+ # raise Warning('Set reshape to True only when importing weights from TensorFlow to PyTorch.')
457+ # if tlx.BACKEND == 'paddle':
458+ # raise Warning('Set reshape to True only when importing weights from TensorFlow to PaddlePaddle.')
459+ # if tlx.BACKEND == 'mindspore':
460+ # raise Warning('Set reshape to True only when importing weights from TensorFlow to MindSpore.')
471461
472462
473463def weight_reshape (weight , reshape = False ):
474464 # TODO In this case only 2D convolution is considered. 3D convolution tests need to be supplemented.
475465 if reshape :
476466 if len (weight .shape ) == 4 :
477- weight = np .moveaxis (weight , (2 , 3 ), (1 , 0 ))
467+ if tlx .BACKEND == 'tensorflow' :
468+ weight = np .moveaxis (weight , (1 , 0 ), (2 , 3 ))
469+ else :
470+ weight = np .moveaxis (weight , (2 , 3 ), (1 , 0 ))
478471 if len (weight .shape ) == 5 :
479- weight = np .moveaxis (weight , (3 , 4 ), (1 , 0 ))
472+ if tlx .BACKEND == 'tensorflow' :
473+ weight = np .moveaxis (weight , (1 , 0 ), (3 , 4 ))
474+ else :
475+ weight = np .moveaxis (weight , (3 , 4 ), (1 , 0 ))
480476 return weight
481477
482478def tolist (tensors ):
0 commit comments