|
@@ -1,4 +1,5 @@
 from copy import deepcopy
+import functools
 import glob
 import os
 from absl.testing import absltest
@@ -178,6 +179,36 @@ def simple_torch_function(a, b):
     # Check that we only lower to HLO twice (once for forward, once for backward).
     self.assertEqual(ending_lowerings - starting_lowerings, 2)

+  def test_assume_pure_avoid_retracing_avoid_rejit_rand(self):
+    """Tests that we avoid retracing and re-jitting when using assume_pure."""
+
+    # Arrange: first clear the cache to prevent contamination from other tests.
+    xb._JAX_TO_XLA_COMPUTATION_CACHE.clear()
+    starting_lowerings = xb._jax_to_xla_computation_cache_elements()
+    trace_counter = 0
+
+    @functools.partial(assume_pure, add_rng_seed_argument=True)
+    def simple_torch_function(a, b):
+      nonlocal trace_counter
+      trace_counter += 1
+      return torch.sin(a @ b)
+
+    # Act: simulate a training loop.
+    for i in range(5):
+      a = torch.ones((3, 3), device='xla', requires_grad=True)
+      o = simple_torch_function(a, a, rng_seed=i)
+      o.sum().backward()
+      torch_xla.sync()
+
+    # Assert
+    ending_lowerings = xb._jax_to_xla_computation_cache_elements()
+
+    # Check that we only trace once.
+    self.assertEqual(trace_counter, 1)
+
+    # Check that we only lower to HLO twice (once for forward, once for backward).
+    self.assertEqual(ending_lowerings - starting_lowerings, 2)
+
   def test_assume_pure_matmul_grads(self):
     """Tests matmul with all inputs requiring gradients."""

@@ -445,6 +476,48 @@ def torch_func(a, b):
     self.assertTrue(MAGIC_STRING in proto_str,
                     f'Expected "{MAGIC_STRING}" trace in: {path}')

+  def test_assume_pure_with_rng(self):
+
+    def add_randn(a):
+      return a + torch.rand_like(a)
+
+    add_randn_p = assume_pure(add_randn, add_rng_seed_argument=True)
+
+    a = torch.randn((2, 2), device='xla')
+    with self.assertRaises(AssertionError):
+      # Did not pass an rng_seed.
+      add_randn_p(a)
+
+    res1 = add_randn_p(a, rng_seed=0)
+    res2 = add_randn_p(a, rng_seed=1)
+    # Different seeds yield different results.
+    self.assertFalse(torch.allclose(res1, res2))
+
+    res1_again = add_randn_p(a, rng_seed=0)
+    # The same seed yields the same result.
+    self.assertTrue(torch.allclose(res1, res1_again))
+
+  def test_assume_pure_with_many_random(self):
+
+    def many_rand(a):
+      a = torch.rand_like(a)
+      b = torch.rand_like(a)
+      c = torch.rand_like(a)
+      return c
+
+    randn_p = assume_pure(many_rand, add_rng_seed_argument=True)
+
+    a = torch.randn((2, 2), device='xla')
+
+    res1 = randn_p(a, rng_seed=0)
+    res2 = randn_p(a, rng_seed=1)
+    # Different seeds yield different results.
+    self.assertFalse(torch.allclose(res1, res2))
+
+    res1_again = randn_p(a, rng_seed=0)
+    # The same seed yields the same result.
+    self.assertTrue(torch.allclose(res1, res1_again))
+

 FLAGS = flags.FLAGS
 flags.DEFINE_integer(
|
|