Skip to content

Commit 72da21d

Browse files
committed
Fix save load npz
1 parent b088cb3 commit 72da21d

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

tensorlayerx/files/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,9 +2003,14 @@ def save_npz(save_list=None, name='model.npz'):
20032003
save_list_var = th_variables_to_numpy(save_list)
20042004
else:
20052005
raise NotImplementedError("This backend is not supported")
2006-
np.savez(name, params=save_list_var)
2006+
# Number by length
2007+
save_list_names = [str(i) for i in range(len(save_list_var))]
2008+
save_var_dict = {save_list_names[idx]: val for idx, val in enumerate(save_list_var)}
2009+
np.savez(name, **save_var_dict)
20072010
save_list_var = None
2011+
save_var_dict = None
20082012
del save_list_var
2013+
del save_var_dict
20092014
logging.info("[*] Saved")
20102015

20112016

@@ -2034,7 +2039,7 @@ def load_npz(path='', name='model.npz'):
20342039
20352040
"""
20362041
d = np.load(os.path.join(path, name), allow_pickle=True)
2037-
return d['params']
2042+
return [d[str(i)] for i in range(len(d))]
20382043

20392044

20402045
def assign_params(**kwargs):

tensorlayerx/nn/layers/convolution/depthwise_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __repr__(self):
103103
actstr = self.act.__class__.__name__ if self.act is not None else 'No Activation'
104104
s = (
105105
'{classname}(in_channels={in_channels}, out_channels={n_filter}, kernel_size={kernel_size}'
106-
', strides={strides}, padding={padding}'
106+
', stride={stride}, padding={padding}'
107107
)
108108
if self.dilation != (1, ) * len(self.dilation):
109109
s += ', dilation={dilation}'

0 commit comments

Comments
 (0)