
Commit 8999ba5

Make assume_pure able to work with functions that depend on randomness (#9460)
1 parent 1ccfede commit 8999ba5

File tree

2 files changed: +125 -2 lines changed


test/test_assume_pure.py

Lines changed: 73 additions & 0 deletions
@@ -1,4 +1,5 @@
 from copy import deepcopy
+import functools
 import glob
 import os
 from absl.testing import absltest
@@ -178,6 +179,36 @@ def simple_torch_function(a, b):
     # Check that we only lower to HLO twice (once for forward, once for backward).
     self.assertEqual(ending_lowerings - starting_lowerings, 2)

+  def test_assume_pure_avoid_retracing_avoid_rejit_rand(self):
+    """Tests that we avoid retracing and re-jitting when using assume_pure."""
+
+    # Arrange: first clear the cache to prevent contamination from other tests.
+    xb._JAX_TO_XLA_COMPUTATION_CACHE.clear()
+    starting_lowerings = xb._jax_to_xla_computation_cache_elements()
+    trace_counter = 0
+
+    @functools.partial(assume_pure, add_rng_seed_argument=True)
+    def simple_torch_function(a, b):
+      nonlocal trace_counter
+      trace_counter += 1
+      return torch.sin(a @ b)
+
+    # Act: simulate a training loop.
+    for i in range(5):
+      a = torch.ones((3, 3), device='xla', requires_grad=True)
+      o = simple_torch_function(a, a, rng_seed=i)
+      o.sum().backward()
+      torch_xla.sync()
+
+    # Assert
+    ending_lowerings = xb._jax_to_xla_computation_cache_elements()
+
+    # Check that we only trace once.
+    self.assertEqual(trace_counter, 1)
+
+    # Check that we only lower to HLO twice (once for forward, once for backward).
+    self.assertEqual(ending_lowerings - starting_lowerings, 2)
+
   def test_assume_pure_matmul_grads(self):
     """Tests matmul with all inputs requiring gradients."""

@@ -445,6 +476,48 @@ def torch_func(a, b):
     self.assertTrue(MAGIC_STRING in proto_str,
                     f'Expected "{MAGIC_STRING}" trace in: {path}')

+  def test_assume_pure_with_rng(self):
+
+    def add_randn(a):
+      return a + torch.rand_like(a)
+
+    add_randn_p = assume_pure(add_randn, add_rng_seed_argument=True)
+
+    a = torch.randn((2, 2), device='xla')
+    with self.assertRaises(AssertionError):
+      # did not pass rng key
+      add_randn_p(a)
+
+    res1 = add_randn_p(a, rng_seed=0)
+    res2 = add_randn_p(a, rng_seed=1)
+    # different keys yield different result
+    self.assertFalse(torch.allclose(res1, res2))
+
+    res1_again = add_randn_p(a, rng_seed=0)
+    # same key yields same result
+    self.assertTrue(torch.allclose(res1, res1_again))
+
+  def test_assume_pure_with_many_random(self):
+
+    def many_rand(a):
+      a = torch.rand_like(a)
+      b = torch.rand_like(a)
+      c = torch.rand_like(a)
+      return c
+
+    randn_p = assume_pure(many_rand, add_rng_seed_argument=True)
+
+    a = torch.randn((2, 2), device='xla')
+
+    res1 = randn_p(a, rng_seed=0)
+    res2 = randn_p(a, rng_seed=1)
+    # different keys yield different result
+    self.assertFalse(torch.allclose(res1, res2))
+
+    res1_again = randn_p(a, rng_seed=0)
+    # same key yields same result
+    self.assertTrue(torch.allclose(res1, res1_again))
+

 FLAGS = flags.FLAGS
 flags.DEFINE_integer(

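For context, here is a minimal, self-contained sketch of the usage pattern the new tests exercise. The names are illustrative, and the import path assumes `assume_pure` is exposed from `torch_xla.experimental.assume_pure`, the module changed below:

```python
import torch
import torch_xla
from torch_xla.experimental.assume_pure import assume_pure


def add_noise(a):
  # Not pure on its own: the output depends on the global RNG state.
  return a + torch.rand_like(a)


# With add_rng_seed_argument=True, the wrapped function requires an explicit
# integer `rng_seed` keyword argument, which makes the randomness reproducible.
add_noise_p = assume_pure(add_noise, add_rng_seed_argument=True)

a = torch.zeros((2, 2), device='xla')
out_a = add_noise_p(a, rng_seed=0)        # fixed seed -> fixed noise
out_a_again = add_noise_p(a, rng_seed=0)  # same seed -> same result
out_b = add_noise_p(a, rng_seed=1)        # different seed -> different result
torch_xla.sync()

assert torch.allclose(out_a, out_a_again)
assert not torch.allclose(out_a, out_b)
```

Calling the wrapped function without `rng_seed` raises an AssertionError, and without `add_rng_seed_argument=True` the RNG key captured at tracing time would be hardcoded into the cached computation.
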
torch_xla/experimental/assume_pure.py

Lines changed: 52 additions & 2 deletions
@@ -12,7 +12,7 @@


 @requires_jax
-def assume_pure(fn):
+def assume_pure(fn, *, add_rng_seed_argument=False):
   """Decorates a pure PyTorch/XLA function to skip expensive re-tracing.

   Returns a new function that will only be traced once for each unique
@@ -30,9 +30,59 @@ def assume_pure(fn):

   - Other custom PyTorch/XLA operations such as `flash_attention` are not
     supported. This limitation may be lifted in the future.
+
+  Args:
+    fn: Callable, the function that is assumed to be pure.
+      A pure function means that if the inputs are fixed, the output is also
+      fixed, i.e. a mathematical function. NOTE: functions that generate random
+      numbers are NOT pure by this definition.
+
+    add_rng_seed_argument: bool, if True, the returned function will take an
+      extra 'rng_seed' keyword argument. Different rng_seed values can produce
+      different results, so the lifted function becomes pure. rng_seed must be
+      an int.
+
+  Example:
+
+    ```
+    def add_randn(a):
+      return a + torch.randn_like(a)
+    ```
+
+    add_randn is not a pure function, but assume_pure(add_randn) assumes it is
+    pure and hardcodes the rng key at tracing time, making add_randn behave
+    differently (and thus incorrectly).
+
+    If instead we use add_randn_p = assume_pure(add_randn, add_rng_seed_argument=True),
+    then we can call add_randn_p(a, rng_seed=0) to get one result and
+    add_randn_p(a, rng_seed=1) to get a different result.
   """
   from torchax.interop import jax_view
-  return j2t_autograd(jax_view(fn))
+  import torchax
+  if add_rng_seed_argument:
+
+    def new_fn(*args, **kwargs):
+      env = torchax.default_env()
+      rng_seed = args[0]
+      args = args[1:]
+      env.manual_seed(rng_seed._elem)
+      return fn(*args, **kwargs)
+
+    jitted = j2t_autograd(jax_view(new_fn))
+
+    def func_to_return(*args, **kwargs):
+      rng_seed = kwargs.get('rng_seed')
+      assert rng_seed is not None, 'Missing keyword argument rng_seed.'
+      kwargs.pop('rng_seed')
+      if isinstance(rng_seed, int):
+        rng_seed = torch.tensor(rng_seed, dtype=torch.uint32, device='xla')
+      args = (rng_seed, *args)
+      result = jitted(*args, **kwargs)
+      return result
+
+    return func_to_return
+  else:
+    return j2t_autograd(jax_view(fn))


 @requires_jax
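
A note on the design above: func_to_return converts the integer seed into a uint32 tensor on the XLA device and prepends it to the positional arguments, while new_fn pops it back off and seeds the torchax environment before calling the original function. Because the seed enters the jitted computation as a tensor input rather than a Python constant, different seeds reuse the same trace and HLO, which is what test_assume_pure_avoid_retracing_avoid_rejit_rand verifies (one trace, two lowerings across five seeds). A minimal JAX-only sketch of the same effect; the names here are illustrative and not part of this commit:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def add_noise(x, seed):
  # `seed` arrives as a traced uint32 value, so different seeds reuse the
  # same trace and compiled executable.
  global trace_count
  trace_count += 1
  key = jax.random.PRNGKey(seed)
  return x + jax.random.uniform(key, x.shape)


x = jnp.zeros((2, 2))
for i in range(5):
  add_noise(x, jnp.uint32(i))

# The Python body ran only once: jax.jit cached the computation and fed the
# five different seeds in as data, analogous to the rng_seed tensor above.
assert trace_count == 1
```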
