@@ -135,9 +135,9 @@ def __init__(self, net_with_loss, optimizer, train_weights):
135135 self .optimizer = optimizer
136136 self .train_weights = train_weights
137137
def __call__(self, data, label, *args, **kwargs):
    """Run a single training step with TensorFlow eager autodiff.

    Parameters
    ----------
    data, label :
        The batch inputs forwarded to ``self.net_with_loss``.
    *args, **kwargs :
        Any extra arguments, passed through to ``self.net_with_loss``.

    Returns
    -------
    The step loss converted to a NumPy value.
    """
    # Record the forward pass so the tape can differentiate the loss
    # with respect to the trainable weights.
    with tf.GradientTape() as tape:
        step_loss = self.net_with_loss(data, label, *args, **kwargs)
    gradients = tape.gradient(step_loss, self.train_weights)
    self.optimizer.apply_gradients(zip(gradients, self.train_weights))
    return step_loss.numpy()
@@ -152,8 +152,8 @@ def __init__(self, net_with_loss, optimizer, train_weights):
152152 self .net_with_loss = net_with_loss
153153 self .train_network = GradWrap (net_with_loss , train_weights )
154154
155- def __call__ (self , data , label ):
156- loss = self .net_with_loss (data , label )
155+ def __call__ (self , data , label , * args , ** kwargs ):
156+ loss = self .net_with_loss (data , label , * args , ** kwargs )
157157 grads = self .train_network (data , label )
158158 self .optimizer .apply_gradients (zip (grads , self .train_weights ))
159159 loss = loss .asnumpy ()
@@ -167,8 +167,8 @@ def __init__(self, net_with_loss, optimizer, train_weights):
167167 self .optimizer = optimizer
168168 self .train_weights = train_weights
169169
def __call__(self, data, label, *args, **kwargs):
    """Run a single training step: forward pass, gradients, update.

    Parameters
    ----------
    data, label :
        The batch inputs forwarded to ``self.net_with_loss``.
    *args, **kwargs :
        Any extra arguments, passed through to ``self.net_with_loss``.

    Returns
    -------
    The step loss converted to a NumPy value.
    """
    step_loss = self.net_with_loss(data, label, *args, **kwargs)
    # NOTE(review): ``optimizer.gradient``/``apply_gradients`` are
    # presumably a framework wrapper's API (not a stock optimizer) —
    # confirm against the optimizer class used here.
    gradients = self.optimizer.gradient(step_loss, self.train_weights)
    self.optimizer.apply_gradients(zip(gradients, self.train_weights))
    return step_loss.numpy()
@@ -183,15 +183,15 @@ def __init__(self, net_with_loss, optimizer, train_weights):
183183 self .optimizer = optimizer
184184 self .train_weights = train_weights
185185
def __call__(self, data, label, *args, **kwargs):
    """Run a single training step on the PyTorch backend.

    Parameters
    ----------
    data, label :
        The batch inputs forwarded to ``self.net_with_loss``.
    *args, **kwargs :
        Any extra arguments, passed through to ``self.net_with_loss``.

    Returns
    -------
    The step loss detached, moved to CPU, and converted to NumPy.
    """
    # Device-transfer logic kept for reference but currently disabled:
    # if isinstance(data, dict):
    #     for k, v in data.items():
    #         if isinstance(v, torch.Tensor):
    #             data[k] = v.to(self.device)
    # elif isinstance(data, torch.Tensor):
    #     data = data.to(self.device)
    # label = label.to(self.device)
    step_loss = self.net_with_loss(data, label, *args, **kwargs)
    # NOTE(review): ``optimizer.gradient``/``apply_gradients`` look like
    # a framework wrapper's API rather than a stock torch optimizer —
    # confirm against the optimizer class used here.
    gradients = self.optimizer.gradient(step_loss, self.train_weights)
    self.optimizer.apply_gradients(zip(gradients, self.train_weights))
    # Detach from the graph and move to host memory before NumPy conversion.
    return step_loss.cpu().detach().numpy()
0 commit comments