Skip to content

Commit f3106b2

Browse files
sharadmvtensorflower-gardener
authored andcommitted
[Oryx] Add inverse rules for integer_pow and pow
PiperOrigin-RevId: 376048209
1 parent 909135a commit f3106b2

File tree

3 files changed

+103
-12
lines changed

3 files changed

+103
-12
lines changed

spinoffs/oryx/oryx/core/interpreters/inverse/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ def ildj_rule(incells, outcells, **params):
257257
incell, = incells
258258
if not incell.top() and outcell.top():
259259
val = outcell.val
260-
f_sum = lambda x: f(x).sum()
261-
ildj_ = outcell.ildj + np.log(jax.grad(f_sum)(val))
262-
ndslice = NDSlice.new(f(val), ildj_)
260+
f_sum = lambda x: f(x, **params).sum()
261+
ildj_ = outcell.ildj + np.log(np.abs(jax.grad(f_sum)(val)))
262+
ndslice = NDSlice.new(f(val, **params), ildj_)
263263
incells = [InverseAndILDJ(outcell.aval, [ndslice])]
264264
elif not outcell.top() and incell.top():
265265
outcells = [InverseAndILDJ.new(prim.bind(incell.val, **params))]

spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from absl.testing import absltest
1919
import jax
20+
from jax import lax
2021
import jax.numpy as np
2122
import numpy as onp
2223

@@ -296,6 +297,62 @@ def naive_logit(x):
296297
tfb.Sigmoid().forward_log_det_jacobian(
297298
-100., 0))
298299

300+
def test_integer_pow_inverse(self):
301+
def f(x):
302+
return lax.integer_pow(x, 2)
303+
f_inv = core.inverse(f)
304+
onp.testing.assert_allclose(f_inv(2.), np.sqrt(2.))
305+
def f2(x):
306+
return lax.integer_pow(x, 3)
307+
f2_inv = core.inverse(f2)
308+
onp.testing.assert_allclose(f2_inv(2.), onp.cbrt(2.))
309+
310+
def test_integer_pow_ildj(self):
311+
def f(x):
312+
return lax.integer_pow(x, 2)
313+
f_ildj = core.ildj(f)
314+
onp.testing.assert_allclose(
315+
f_ildj(2.), tfb.Power(2.).inverse_log_det_jacobian(2.))
316+
def f2(x):
317+
return lax.integer_pow(x, 3)
318+
f2_ildj = core.ildj(f2)
319+
onp.testing.assert_allclose(
320+
f2_ildj(2.), tfb.Power(3.).inverse_log_det_jacobian(2.))
321+
322+
def test_reciprocal_inverse(self):
323+
def f(x):
324+
return np.reciprocal(x)
325+
f_inv = core.inverse(f)
326+
onp.testing.assert_allclose(f_inv(2.), 0.5)
327+
328+
def test_reciprocal_ildj(self):
329+
def f(x):
330+
return np.reciprocal(x)
331+
f_ildj = core.ildj(f)
332+
onp.testing.assert_allclose(f_ildj(2.), onp.log(1 / 4.))
333+
334+
def test_pow_inverse(self):
335+
def f(x, y):
336+
return lax.pow(x, y)
337+
f_x_inv = core.inverse(lambda x: f(x, 2.))
338+
onp.testing.assert_allclose(f_x_inv(2.), np.sqrt(2.))
339+
f_y_inv = core.inverse(lambda y: f(2., y))
340+
onp.testing.assert_allclose(f_y_inv(3.), np.log(3.) / np.log(2.))
341+
342+
def test_pow_ildj(self):
343+
def f(x, y):
344+
return lax.pow(x, y)
345+
f_x_ildj = core.ildj(lambda x: f(x, 2.))
346+
onp.testing.assert_allclose(
347+
f_x_ildj(3.), tfb.Power(2.).inverse_log_det_jacobian(3.))
348+
f_y_ildj = core.ildj(lambda y: f(2., y))
349+
f_y_inv = core.inverse(lambda y: f(2., y))
350+
y = f_y_inv(3.)
351+
onp.testing.assert_allclose(
352+
f_y_ildj(3.), -np.log(np.abs(jax.grad(lambda y: f(2., y))(y))))
353+
onp.testing.assert_allclose(
354+
f_y_ildj(3.), np.log(np.abs(jax.grad(f_y_inv)(3.))))
355+
299356

300357
if __name__ == '__main__':
301358
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

spinoffs/oryx/oryx/core/interpreters/inverse/rules.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,38 @@
5050
register_elementwise(lax.neg_p)(lambda x: -x)
5151

5252

53+
@register_elementwise(lax.integer_pow_p)
54+
def integer_pow_inverse(z, *, y):
55+
"""Inverse for `integer_pow_p` primitive."""
56+
if y == 0:
57+
raise ValueError('Cannot invert raising to a value to the 0-th power.')
58+
elif y == 1:
59+
return z
60+
elif y == -1:
61+
return np.reciprocal(z)
62+
elif y == 2:
63+
return np.sqrt(z)
64+
return lax.pow(z, 1. / y)
65+
66+
67+
def pow_left(x, z, ildj_):
68+
# x ** y = z
69+
# y = f^-1(z) = log(z) / log(x)
70+
# grad(f^-1)(z) = 1 / (z log(x))
71+
# log(grad(f^-1)(z)) = log(1 / (z log(x))) = -log(z) - log(log(x))
72+
return np.log(z) / np.log(x), ildj_ - np.log(z) - np.log(np.log(x))
73+
74+
75+
def pow_right(y, z, ildj_):
76+
# x ** y = z
77+
# x = f^-1(z) = z ** (1 / y)
78+
# grad(f^-1)(z) = 1 / y * z ** (1 / y - 1)
79+
# log(grad(f^-1)(z)) = (1 / y - 1)log(z) - log(y)
80+
y_inv = np.reciprocal(y)
81+
return lax.pow(z, y_inv), ildj_ + (y_inv - 1.) * np.log(z) - np.log(y)
82+
register_binary(lax.pow_p)(pow_left, pow_right)
83+
84+
5385
def add_left(left_val, out_val, ildj_):
5486
return out_val - left_val, ildj_
5587

@@ -182,15 +214,6 @@ def expit_ildj(y):
182214
def logit_ildj(y):
183215
return -jax.nn.softplus(-y) - jax.nn.softplus(y)
184216

185-
# Monkey patching JAX so we can define custom, more numerically stable inverses.
186-
jax.scipy.special.expit = custom_inverse(jax.scipy.special.expit)
187-
jax.scipy.special.logit = custom_inverse(jax.scipy.special.logit)
188-
jax.nn.sigmoid = jax.scipy.special.expit
189-
jax.scipy.special.expit.def_inverse_unary(f_inv=jax.scipy.special.logit,
190-
f_ildj=expit_ildj)
191-
jax.scipy.special.logit.def_inverse_unary(f_inv=jax.scipy.special.expit,
192-
f_ildj=logit_ildj)
193-
194217

195218
def convert_element_type_ildj(incells, outcells, *, new_dtype, **params):
196219
"""InverseAndILDJ rule for convert_element_type primitive."""
@@ -206,3 +229,14 @@ def convert_element_type_ildj(incells, outcells, *, new_dtype, **params):
206229
val, new_dtype=incell.aval.dtype, **params))]
207230
return incells, outcells, None
208231
ildj_registry[lax.convert_element_type_p] = convert_element_type_ildj
232+
233+
234+
# Monkey patching JAX so we can define custom, more numerically stable inverses.
235+
jax.scipy.special.expit = custom_inverse(jax.scipy.special.expit)
236+
jax.scipy.special.logit = custom_inverse(jax.scipy.special.logit)
237+
jax.nn.sigmoid = jax.scipy.special.expit
238+
jax.scipy.special.expit.def_inverse_unary(f_inv=jax.scipy.special.logit,
239+
f_ildj=expit_ildj)
240+
jax.scipy.special.logit.def_inverse_unary(f_inv=jax.scipy.special.expit,
241+
f_ildj=logit_ildj)
242+

0 commit comments

Comments
 (0)