Skip to content

Commit 78147ab

Browse files
committed
Fixed topology build only in validation mode
1 parent c5f0b3a commit 78147ab

File tree

4 files changed

+3
-16
lines changed

4 files changed

+3
-16
lines changed

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ def build_graph(self, *inputs, **kwargs):
300300
for layer_name, layer in layers:
301301
if isinstance(layer, Module):
302302
layer._build_graph = True
303+
self.set_eval()
303304

304305
outputs = self.forward(*inputs, **kwargs)
305306
self.inputs = inputs

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ def build_graph(self, *inputs, **kwargs):
588588
for layer_name, layer in layers:
589589
if isinstance(layer, Module):
590590
layer._build_graph = True
591+
self.set_eval()
591592

592593
outputs = self.forward(*inputs, **kwargs)
593594
self.inputs = inputs

tensorlayerx/nn/core/core_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def build_graph(self, *inputs, **kwargs):
186186
for name, layer in self.named_modules():
187187
if isinstance(layer, Module):
188188
layer._build_graph = True
189+
self.set_eval()
189190

190191
outputs = self.forward(*inputs, **kwargs)
191192
self.inputs = inputs

tensorlayerx/nn/layers/lambda_layers.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#! /usr/bin/python
22
# -*- coding: utf-8 -*-
33

4-
import tensorflow as tf
54
from tensorlayerx import logging
6-
from tensorlayerx.files import utils
75
from tensorlayerx.nn.core import Module
86

97
__all__ = [
@@ -150,19 +148,6 @@ def forward(self, inputs, **kwargs):
150148
self._nodes_fixed = True
151149
return outputs
152150

153-
def get_args(self):
154-
init_args = {}
155-
if isinstance(self.fn, tf.keras.layers.Layer) or isinstance(self.fn, tf.keras.Model):
156-
init_args.update({"layer_type": "keraslayer"})
157-
init_args["fn"] = utils.save_keras_model(self.fn)
158-
init_args["fn_weights"] = None
159-
if len(self._nodes) == 0:
160-
init_args["keras_input_shape"] = []
161-
else:
162-
init_args["keras_input_shape"] = self._nodes[0].in_tensors[0].get_shape().as_list()
163-
else:
164-
init_args = {"layer_type": "normal"}
165-
return init_args
166151

167152

168153
class ElementwiseLambda(Module):
@@ -264,7 +249,6 @@ def build(self, inputs_shape=None):
264249
# the weights of the function are provided when the Lambda layer is constructed
265250
pass
266251

267-
# @tf.function
268252
def forward(self, inputs, **kwargs):
269253

270254
if not isinstance(inputs, list):

0 commit comments

Comments
 (0)