We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5043b0f commit 241c094Copy full SHA for 241c094
tensorlayerx/model/core.py
@@ -448,7 +448,7 @@ def th_train(
448
train_acc += metrics.result()
449
metrics.reset()
450
else:
451
- train_acc += (output.argmax(1) == y_batch).type(torch.float).sum().item()
+ train_acc += (output.argmax(1) == y_batch).type(torch.float).mean().item()
452
n_iter += 1
453
454
if print_train_batch:
@@ -474,7 +474,7 @@ def th_train(
474
val_acc += metrics.result()
475
476
477
- val_acc += (_logits.argmax(1) == y_batch).type(torch.float).sum().item()
+ val_acc += (_logits.argmax(1) == y_batch).type(torch.float).mean().item()
478
479
print(" val loss: {}".format(val_loss / n_iter))
480
print(" val acc: {}".format(val_acc / n_iter))
0 commit comments