Skip to content

Commit e8b2de4

Browse files
committed
Fix model save
1 parent 5006403 commit e8b2de4

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

examples/basic_tutorials/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ TensorLayerX provides Sequential and Subclass define a model. Sequential model a
2525
## Others
2626
- [Interoperability of Different Back-end Models](https://github.com/tensorlayer/TensorLayerX/blob/main/examples/basic_tutorials/tutorial_tensorlayer_model_load.py)
2727
- [TensorFlow Model Save to pb](https://github.com/tensorlayer/TensorLayerX/blob/main/examples/basic_tutorials/tensorflow_model_save_to_pb.py)
28-
- [Gradient Clip](https://github.com/tensorlayer/TensorLayerX/blob/main/examples/basic_tutorials/gradient_clip_mixed_tensorflow.py)
28+
- [Gradient Clip](https://github.com/tensorlayer/TensorLayerX/blob/main/examples/basic_tutorials/gradient_clip_mixed_tensorflow.py)
29+
- [Using tensorboard](https://github.com/tensorlayer/TensorLayerX/blob/main/examples/basic_tutorials/tutorial_using_tensorboradX.py)

tensorlayerx/model/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def save_standard_weights(self, file_path):
147147
148148
"""
149149

150-
_save_standard_weights_dict(self, file_path)
150+
_save_standard_weights_dict(self.network, file_path)
151151

152152
def load_standard_weights(self, file_path, skip=False, reshape=False, format='npz_dict'):
153153
"""
@@ -166,7 +166,7 @@ def load_standard_weights(self, file_path, skip=False, reshape=False, format='np
166166
167167
"""
168168

169-
_load_standard_weights_dict(self, file_path, skip, reshape, format)
169+
_load_standard_weights_dict(self.network, file_path, skip, reshape, format)
170170

171171
def save_weights(self, file_path, format=None):
172172
"""Input file_path, save model weights into a file of given format.
@@ -206,7 +206,7 @@ def save_weights(self, file_path, format=None):
206206
207207
"""
208208

209-
_save_weights(net=self, file_path=file_path, format=format)
209+
_save_weights(net=self.network, file_path=file_path, format=format)
210210

211211
def load_weights(self, file_path, format=None, in_order=True, skip=False):
212212
"""Load model weights from a given file, which should be previously saved by self.save_weights().
@@ -259,7 +259,7 @@ def load_weights(self, file_path, format=None, in_order=True, skip=False):
259259
260260
"""
261261

262-
_load_weights(net=self, file_path=file_path, format=format, in_order=in_order, skip=skip)
262+
_load_weights(net=self.network, file_path=file_path, format=format, in_order=in_order, skip=skip)
263263

264264
def tf_train(
265265
self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,

tensorlayerx/nn/layers/convolution/simplified_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def build(self, inputs_shape):
265265
raise Exception("data_format should be either channels_last or channels_first")
266266

267267
#TODO channels first filter shape [out_channel, in_channel, filter_h, filter_w]
268-
self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, self.out_channels )
268+
self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, self.out_channels)
269269
self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
270270

271271
self.b_init_flag = False

0 commit comments

Comments
 (0)