Skip to content

Commit 497fc1e

Browse files
sharadmvtensorflower-gardener
authored andcommitted
[Oryx] Add inverse rule for sqrt
PiperOrigin-RevId: 376972734
1 parent 76b768a commit 497fc1e

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,18 @@ def f(x, y):
353353
onp.testing.assert_allclose(
354354
f_y_ildj(3.), np.log(np.abs(jax.grad(f_y_inv)(3.))))
355355

356+
def test_sqrt_inverse(self):
357+
def f(x):
358+
return np.sqrt(x)
359+
f_inv = core.inverse(f)
360+
onp.testing.assert_allclose(f_inv(2.), 4.)
361+
362+
def test_sqrt_ildj(self):
363+
def f(x):
364+
return np.sqrt(x)
365+
f_ildj = core.ildj(f)
366+
onp.testing.assert_allclose(f_ildj(3.), np.log(2.) + np.log(3.))
367+
356368

357369
if __name__ == '__main__':
358370
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'

spinoffs/oryx/oryx/core/interpreters/inverse/rules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
register_elementwise(lax.expm1_p)(np.log1p)
4949
register_elementwise(lax.log1p_p)(np.expm1)
5050
register_elementwise(lax.neg_p)(lambda x: -x)
51+
register_elementwise(lax.sqrt_p)(np.square)
5152

5253

5354
@register_elementwise(lax.integer_pow_p)

0 commit comments

Comments
 (0)