Skip to content

Commit 5fc4aac

Browse files
Merge branch 'add-function' of github.com:Huang-Xu-Yang/tensorcircuit-ng into ngpr46
2 parents 004d55e + 9087786 commit 5fc4aac

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

tensorcircuit/timeevol.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -435,10 +435,10 @@ def _solve_ode(
435435
args: Any,
436436
solver_kws: Dict[str, Any],
437437
) -> Tensor:
438-
rtol = solver_kws.get("rtol", 1e-12)
439-
atol = solver_kws.get("atol", 1e-12)
438+
rtol = solver_kws.get("rtol", 1e-8)
439+
atol = solver_kws.get("atol", 1e-8)
440440
ode_backend = solver_kws.get("ode_backend", "jaxode")
441-
max_steps = solver_kws.get("max_steps", 10000)
441+
max_steps = solver_kws.get("max_steps", 4096)
442442

443443
ts = backend.convert_to_tensor(times)
444444
ts = backend.cast(ts, dtype=rdtypestr)
@@ -513,15 +513,21 @@ def ode_evol_local(
513513
:type callback: Optional[Callable[..., Tensor]]
514514
:param args: Additional arguments to pass to the Hamiltonian function.
515515
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
516-
- ode_backend='jaxode'(default) uses ``jax.experimental.ode.odeint``; ode_backend='diffrax'
517-
uses ``diffrax.diffeqsolve``.
518-
- rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
519-
like the numerical approximation to your equation.
520-
- The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
521-
and only works when ode_backend='diffrax'.
522-
- dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
523-
- max_steps (default: 10000) The maximum number of steps to take before quitting the computation
524-
unconditionally and only works when ode_backend='diffrax'.
516+
517+
- ``ode_backend='jaxode'`` (default) uses ``jax.experimental.ode.odeint``; ``ode_backend='diffrax'``
518+
uses ``diffrax.diffeqsolve``.
519+
520+
- ``rtol`` (default: 1e-8) and ``atol`` (default: 1e-8) are used to determine how accurately you would
521+
like the numerical approximation to your equation.
522+
523+
- The ``solver`` parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
524+
and only works when ``ode_backend='diffrax'``.
525+
526+
- ``t0`` (default: 0.01) specifies the initial step size and only works when ``ode_backend='diffrax'``.
527+
528+
- ``max_steps`` (default: 4096) The maximum number of steps to take before quitting the computation
529+
unconditionally and only works when ``ode_backend='diffrax'``.
530+
:type solver_kws: dict
525531
526532
:return: Evolved quantum states at the specified time points. If callback is provided,
527533
returns the callback results; otherwise returns the state vectors.
@@ -585,17 +591,22 @@ def ode_evol_global(
585591
:param args: Additional arguments to pass to the Hamiltonian function.
586592
:type args: tuple | list
587593
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
588-
- ode_backend='jaxode'(default) uses ``jax.experimental.ode.odeint``; ode_backend='diffrax'
589-
uses ``diffrax.diffeqsolve``.
590-
- rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
591-
like the numerical approximation to your equation.
592-
- The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
593-
and only works when ode_backend='diffrax'.
594-
- dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
595-
- max_steps (default: 10000) The maximum number of steps to take before quitting the computation
596-
unconditionally and only works when ode_backend='diffrax'.
597594
595+
- ``ode_backend='jaxode'`` (default) uses ``jax.experimental.ode.odeint``; ``ode_backend='diffrax'``
596+
uses ``diffrax.diffeqsolve``.
597+
598+
- ``rtol`` (default: 1e-8) and ``atol`` (default: 1e-8) are used to determine how accurately you would
599+
like the numerical approximation to your equation.
600+
601+
- The ``solver`` parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
602+
and only works when ``ode_backend='diffrax'``.
603+
604+
- ``t0`` (default: 0.01) specifies the initial step size and only works when ``ode_backend='diffrax'``.
605+
606+
- ``max_steps`` (default: 4096) The maximum number of steps to take before quitting the computation
607+
unconditionally and only works when ``ode_backend='diffrax'``.
598608
:type solver_kws: dict
609+
599610
:return: Evolved quantum states at the specified time points. If callback is provided,
600611
returns the callback results; otherwise returns the state vectors.
601612
:rtype: Tensor

tests/test_timeevol.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,14 @@ def local_hamiltonian(t, Omega, phi):
105105
1.0,
106106
2.0, # Omega=1.0, phi=2.0
107107
solver="Dopri8",
108-
atol=1.0e-13,
109-
rtol=1.0e-13,
108+
atol=1.0e-11,
109+
rtol=1.0e-11,
110110
ode_backend="diffrax",
111111
dt0=0.005,
112112
)
113113

114-
np.testing.assert_allclose(states2, states1, atol=1e-10, rtol=0.0)
115-
np.testing.assert_allclose(states0, states1, atol=1e-10, rtol=0.0)
114+
np.testing.assert_allclose(states2, states1, atol=1e-5, rtol=0.0)
115+
np.testing.assert_allclose(states0, states1, atol=1e-5, rtol=0.0)
116116

117117

118118
def test_ode_evol_global(highp, jaxb):
@@ -270,6 +270,8 @@ def do5_ode_solver_(params):
270270
tc.backend.convert_to_tensor([0, 10.0]),
271271
None,
272272
*params,
273+
atol=1.0e-13,
274+
rtol=1.0e-13,
273275
)
274276
return tc.backend.real(zz_correlation(states[-1]))
275277

@@ -339,8 +341,8 @@ def do5_ode_solver_local(paras):
339341
v1, g1 = s1
340342
v2, g2 = s2
341343

342-
np.testing.assert_allclose(g1, g2, atol=1e-8, rtol=0)
343-
np.testing.assert_allclose(v1, v2, atol=1e-8, rtol=0)
344+
np.testing.assert_allclose(g1, g2, atol=1e-5, rtol=0)
345+
np.testing.assert_allclose(v1, v2, atol=1e-5, rtol=0)
344346

345347

346348
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])

0 commit comments

Comments
 (0)