Skip to content

Commit 0c23d08

Browse files
committed
Add script to test finite differences
1 parent 6552bf3 commit 0c23d08

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/test_numerical_jacobian.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from functools import partial
2+
import jax
3+
4+
jax.config.update("jax_enable_x64", True) # double precision
5+
from jax import Array, random
6+
from jax import numpy as jnp
7+
from jsrm.utils.numerical_jacobian import approx_derivative
8+
9+
10+
def test_finite_differences(method = "2-point"):
11+
def fun(x: Array):
12+
return jnp.stack([x[0] * jnp.sin(x[1]), x[0] * jnp.cos(x[1])])
13+
14+
jac_autodiff_fn = jax.jacfwd(fun)
15+
jac_numdiff_fn = partial(approx_derivative, fun, method=method)
16+
17+
rng = random.PRNGKey(0)
18+
for i in range(100):
19+
rng, subrng = random.split(rng)
20+
x = random.uniform(subrng, (2,), minval=-1.0, maxval=1.0)
21+
22+
jac_autodiff = jac_autodiff_fn(x)
23+
jac_numdiff = jac_numdiff_fn(x)
24+
print("x = ", x, "\njac_autodiff = \n", jac_autodiff, "\njac_numdiff = \n", jac_numdiff)
25+
26+
error_jac = jnp.linalg.norm(jac_autodiff - jac_numdiff)
27+
print("error_jac = ", error_jac)
28+
29+
if not jnp.allclose(jac_autodiff, jac_numdiff, atol=1e-6):
30+
raise ValueError("Jacobian mismatch!")
31+
32+
33+
if __name__ == "__main__":
34+
test_finite_differences(method="2-point")
35+
test_finite_differences(method="3-point")

0 commit comments

Comments
 (0)