Skip to content

Commit 5043b0f

Browse files
committed
Fix TrainOneStep
1 parent f09bb58 commit 5043b0f

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

tensorlayerx/model/core.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -560,8 +560,8 @@ def __init__(self, net_with_loss, optimizer, train_weights):
560560
else:
561561
raise NotImplementedError("This backend is not supported")
562562

563-
def __call__(self, data, label):
564-
loss = self.net_with_train(data, label)
563+
def __call__(self, data, label, *args, **kwargs):
564+
loss = self.net_with_train(data, label, *args, **kwargs)
565565
return loss
566566

567567

tensorlayerx/model/utils.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -135,9 +135,9 @@ def __init__(self, net_with_loss, optimizer, train_weights):
135135
self.optimizer = optimizer
136136
self.train_weights = train_weights
137137

138-
def __call__(self, data, label):
138+
def __call__(self, data, label, *args, **kwargs):
139139
with tf.GradientTape() as tape:
140-
loss = self.net_with_loss(data, label)
140+
loss = self.net_with_loss(data, label, *args, **kwargs)
141141
grad = tape.gradient(loss, self.train_weights)
142142
self.optimizer.apply_gradients(zip(grad, self.train_weights))
143143
return loss.numpy()
@@ -152,8 +152,8 @@ def __init__(self, net_with_loss, optimizer, train_weights):
152152
self.net_with_loss = net_with_loss
153153
self.train_network = GradWrap(net_with_loss, train_weights)
154154

155-
def __call__(self, data, label):
156-
loss = self.net_with_loss(data, label)
155+
def __call__(self, data, label, *args, **kwargs):
156+
loss = self.net_with_loss(data, label, *args, **kwargs)
157157
grads = self.train_network(data, label)
158158
self.optimizer.apply_gradients(zip(grads, self.train_weights))
159159
loss = loss.asnumpy()
@@ -167,8 +167,8 @@ def __init__(self, net_with_loss, optimizer, train_weights):
167167
self.optimizer = optimizer
168168
self.train_weights = train_weights
169169

170-
def __call__(self, data, label):
171-
loss = self.net_with_loss(data, label)
170+
def __call__(self, data, label, *args, **kwargs):
171+
loss = self.net_with_loss(data, label, *args, **kwargs)
172172
grads = self.optimizer.gradient(loss, self.train_weights)
173173
self.optimizer.apply_gradients(zip(grads, self.train_weights))
174174
return loss.numpy()
@@ -183,15 +183,15 @@ def __init__(self, net_with_loss, optimizer, train_weights):
183183
self.optimizer = optimizer
184184
self.train_weights = train_weights
185185

186-
def __call__(self, data, label):
186+
def __call__(self, data, label, *args, **kwargs):
187187
# if isinstance(data, dict):
188188
# for k, v in data.items():
189189
# if isinstance(v, torch.Tensor):
190190
# data[k] = v.to(self.device)
191191
# elif isinstance(data, torch.Tensor):
192192
# data = data.to(self.device)
193193
# label = label.to(self.device)
194-
loss = self.net_with_loss(data, label)
194+
loss = self.net_with_loss(data, label, *args, **kwargs)
195195
grads = self.optimizer.gradient(loss, self.train_weights)
196196
self.optimizer.apply_gradients(zip(grads, self.train_weights))
197197
return loss.cpu().detach().numpy()

0 commit comments

Comments (0)