File tree Expand file tree Collapse file tree 2 files changed +13
-0
lines changed
spinoffs/oryx/oryx/core/interpreters/inverse Expand file tree Collapse file tree 2 files changed +13
-0
lines changed Original file line number Diff line number Diff line change @@ -353,6 +353,18 @@ def f(x, y):
353
353
onp .testing .assert_allclose (
354
354
f_y_ildj (3. ), np .log (np .abs (jax .grad (f_y_inv )(3. ))))
355
355
356
+ def test_sqrt_inverse (self ):
357
+ def f (x ):
358
+ return np .sqrt (x )
359
+ f_inv = core .inverse (f )
360
+ onp .testing .assert_allclose (f_inv (2. ), 4. )
361
+
362
+ def test_sqrt_ildj (self ):
363
+ def f (x ):
364
+ return np .sqrt (x )
365
+ f_ildj = core .ildj (f )
366
+ onp .testing .assert_allclose (f_ildj (3. ), np .log (2. ) + np .log (3. ))
367
+
356
368
357
369
if __name__ == '__main__' :
358
370
os .environ ['XLA_FLAGS' ] = '--xla_force_host_platform_device_count=2'
Original file line number Diff line number Diff line change 48
48
register_elementwise (lax .expm1_p )(np .log1p )
49
49
register_elementwise (lax .log1p_p )(np .expm1 )
50
50
register_elementwise (lax .neg_p )(lambda x : - x )
51
+ register_elementwise (lax .sqrt_p )(np .square )
51
52
52
53
53
54
@register_elementwise (lax .integer_pow_p )
You can’t perform that action at this time.
0 commit comments