@@ -307,7 +307,7 @@ def tf_train(
307307 val_loss , val_acc , n_iter = 0 , 0 , 0
308308 for X_batch , y_batch in test_dataset :
309309 _logits = network (X_batch ) # is_train=False, disable dropout
310- val_loss += loss_fn (_logits , y_batch , name = 'eval_loss' )
310+ val_loss += loss_fn (_logits , y_batch )
311311 if metrics :
312312 metrics .update (_logits , y_batch )
313313 val_acc += metrics .result ()
@@ -360,7 +360,7 @@ def ms_train(
360360 val_loss , val_acc , n_iter = 0 , 0 , 0
361361 for X_batch , y_batch in test_dataset :
362362 _logits = network (X_batch )
363- val_loss += loss_fn (_logits , y_batch , name = 'eval_loss' )
363+ val_loss += loss_fn (_logits , y_batch )
364364 if metrics :
365365 metrics .update (_logits , y_batch )
366366 val_acc += metrics .result ()
@@ -414,7 +414,7 @@ def pd_train(
414414 val_loss , val_acc , n_iter = 0 , 0 , 0
415415 for X_batch , y_batch in test_dataset :
416416 _logits = network (X_batch ) # is_train=False, disable dropout
417- val_loss += loss_fn (_logits , y_batch , name = 'eval_loss' )
417+ val_loss += loss_fn (_logits , y_batch )
418418 if metrics :
419419 metrics .update (_logits , y_batch )
420420 val_acc += metrics .result ()
@@ -468,7 +468,7 @@ def th_train(
468468 val_loss , val_acc , n_iter = 0 , 0 , 0
469469 for X_batch , y_batch in test_dataset :
470470 _logits = network (X_batch ) # is_train=False, disable dropout
471- val_loss += loss_fn (_logits , y_batch , name = 'eval_loss' )
471+ val_loss += loss_fn (_logits , y_batch )
472472 if metrics :
473473 metrics .update (_logits , y_batch )
474474 val_acc += metrics .result ()
0 commit comments