Skip to content

Commit 241c094

Browse files
authored
Fix torch backend acc (#21)
1 parent 5043b0f commit 241c094

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorlayerx/model/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def th_train(
448448
train_acc += metrics.result()
449449
metrics.reset()
450450
else:
451-
train_acc += (output.argmax(1) == y_batch).type(torch.float).sum().item()
451+
train_acc += (output.argmax(1) == y_batch).type(torch.float).mean().item()
452452
n_iter += 1
453453

454454
if print_train_batch:
@@ -474,7 +474,7 @@ def th_train(
474474
val_acc += metrics.result()
475475
metrics.reset()
476476
else:
477-
val_acc += (_logits.argmax(1) == y_batch).type(torch.float).sum().item()
477+
val_acc += (_logits.argmax(1) == y_batch).type(torch.float).mean().item()
478478
n_iter += 1
479479
print(" val loss: {}".format(val_loss / n_iter))
480480
print(" val acc: {}".format(val_acc / n_iter))

0 commit comments

Comments
 (0)