Commit 193f704

Add tolerance for numeric test of checkpointing (#9404)
There was initially a concern that the numerics should be exact between runs with and without activation rematerialization. The rematerialized activations themselves should be exact; however, the XLA compiler may reorder the ops, so the final weight update may deviate slightly, and the final loss of the model can vary even more than that.
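
This is a general floating-point property rather than anything specific to checkpointing: reordering a reduction changes which rounding errors accumulate. A minimal, self-contained sketch (an illustration, not code from this commit) of the effect:

import torch

torch.manual_seed(0)
x = torch.randn(100000)

# The same values, reduced in two different orders.
forward = x.sum()
backward = x.flip(0).sum()

# Floating-point addition is not associative, so the two sums typically
# differ in the low-order bits. Op reordering by the XLA compiler can
# introduce the same kind of drift, which is why the test below compares
# with a tolerance instead of exact equality.
print((forward - backward).abs().item())  # small, but usually nonzero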
1 parent: 7e3efc5

File tree: 1 file changed (+2, -2 lines)
test/spmd/test_train_spmd_linear_model.py

Lines changed: 2 additions & 2 deletions

@@ -50,9 +50,9 @@ def test_basic(self):
     with extended_argv(['--use_gradient_checkpointing']):
       checkpointing_losses, checkpointing_result = train_and_evaluate()
       # Verify that the runs match with and without checkpointing.
-      assert torch.allclose(baseline_result, checkpointing_result)
+      assert torch.allclose(baseline_result, checkpointing_result, atol=0.005)
       assert all(
-          torch.allclose(baseline_loss, checkpointing_loss)
+          torch.allclose(baseline_loss, checkpointing_loss, atol=0.00002)
           for baseline_loss, checkpointing_loss in zip(
               baseline_losses, checkpointing_losses))
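
For reference, torch.allclose(input, other, rtol=1e-05, atol=1e-08) passes when |input - other| <= atol + rtol * |other| elementwise, so the new atol values bound the absolute drift the test tolerates: 0.005 on the final result and 0.00002 on each per-step loss. A small sketch with hypothetical stand-in values (not numbers from the test):

import torch

baseline_result = torch.tensor([1.0000, 2.0000])
checkpointing_result = torch.tensor([1.0030, 1.9980])  # hypothetical drift

# A default-tolerance comparison flags the drift, while the commit's
# atol=0.005 accepts it.
assert not torch.allclose(baseline_result, checkpointing_result)
assert torch.allclose(baseline_result, checkpointing_result, atol=0.005)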
