Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 32 additions & 21 deletions tensorcircuit/timeevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,10 @@ def _solve_ode(
args: Any,
solver_kws: Dict[str, Any],
) -> Tensor:
rtol = solver_kws.get("rtol", 1e-12)
atol = solver_kws.get("atol", 1e-12)
rtol = solver_kws.get("rtol", 1e-8)
atol = solver_kws.get("atol", 1e-8)
ode_backend = solver_kws.get("ode_backend", "jaxode")
max_steps = solver_kws.get("max_steps", 10000)
max_steps = solver_kws.get("max_steps", 4096)

ts = backend.convert_to_tensor(times)
ts = backend.cast(ts, dtype=rdtypestr)
Expand Down Expand Up @@ -513,15 +513,21 @@ def ode_evol_local(
:type callback: Optional[Callable[..., Tensor]]
:param args: Additional arguments to pass to the Hamiltonian function.
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
- ode_backend='jaxode'(default) uses ``jax.experimental.ode.odeint``; ode_backend='diffrax'
uses ``diffrax.diffeqsolve``.
- rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
like the numerical approximation to your equation.
- The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ode_backend='diffrax'.
- dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
- max_steps (default: 10000) The maximum number of steps to take before quitting the computation
unconditionally and only works when ode_backend='diffrax'.

- ``ode_backend='jaxode'`` (default) uses ``jax.experimental.ode.odeint``; ``ode_backend='diffrax'``
uses ``diffrax.diffeqsolve``.

- ``rtol`` (default: 1e-8) and ``atol`` (default: 1e-8) are used to determine how accurately you would
like the numerical approximation to your equation.

- The ``solver`` parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ``ode_backend='diffrax'``.

- ``t0`` (default: 0.01) specifies the initial step size and only works when ``ode_backend='diffrax'``.

- ``max_steps`` (default: 4096) The maximum number of steps to take before quitting the computation
unconditionally and only works when ``ode_backend='diffrax'``.
:type solver_kws: dict

:return: Evolved quantum states at the specified time points. If callback is provided,
returns the callback results; otherwise returns the state vectors.
Expand Down Expand Up @@ -585,17 +591,22 @@ def ode_evol_global(
:param args: Additional arguments to pass to the Hamiltonian function.
:type args: tuple | list
:param solver_kws: Additional keyword arguments to pass to the ODE solver.
- ode_backend='jaxode'(default) uses ``jax.experimental.ode.odeint``; ode_backend='diffrax'
uses ``diffrax.diffeqsolve``.
- rtol (default: 1e-12) and atol (default: 1e-12) are used to determine how accurately you would
like the numerical approximation to your equation.
- The solver parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ode_backend='diffrax'.
- dt0 (default: 0.01) specifies the initial step size and only works when ode_backend='diffrax'.
- max_steps (default: 10000) The maximum number of steps to take before quitting the computation
unconditionally and only works when ode_backend='diffrax'.

- ``ode_backend='jaxode'`` (default) uses ``jax.experimental.ode.odeint``; ``ode_backend='diffrax'``
uses ``diffrax.diffeqsolve``.

- ``rtol`` (default: 1e-8) and ``atol`` (default: 1e-8) are used to determine how accurately you would
like the numerical approximation to your equation.

- The ``solver`` parameter accepts one of {'Tsit5' (default), 'Dopri5', 'Dopri8', 'Kvaerno5'}
and only works when ``ode_backend='diffrax'``.

- ``t0`` (default: 0.01) specifies the initial step size and only works when ``ode_backend='diffrax'``.

- ``max_steps`` (default: 4096) The maximum number of steps to take before quitting the computation
unconditionally and only works when ``ode_backend='diffrax'``.
:type solver_kws: dict

:return: Evolved quantum states at the specified time points. If callback is provided,
returns the callback results; otherwise returns the state vectors.
:rtype: Tensor
Expand Down
14 changes: 8 additions & 6 deletions tests/test_timeevol.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ def local_hamiltonian(t, Omega, phi):
1.0,
2.0, # Omega=1.0, phi=2.0
solver="Dopri8",
atol=1.0e-13,
rtol=1.0e-13,
atol=1.0e-11,
rtol=1.0e-11,
ode_backend="diffrax",
dt0=0.005,
)

np.testing.assert_allclose(states2, states1, atol=1e-10, rtol=0.0)
np.testing.assert_allclose(states0, states1, atol=1e-10, rtol=0.0)
np.testing.assert_allclose(states2, states1, atol=1e-5, rtol=0.0)
np.testing.assert_allclose(states0, states1, atol=1e-5, rtol=0.0)


def test_ode_evol_global(highp, jaxb):
Expand Down Expand Up @@ -270,6 +270,8 @@ def do5_ode_solver_(params):
tc.backend.convert_to_tensor([0, 10.0]),
None,
*params,
atol=1.0e-13,
rtol=1.0e-13,
)
return tc.backend.real(zz_correlation(states[-1]))

Expand Down Expand Up @@ -339,8 +341,8 @@ def do5_ode_solver_local(paras):
v1, g1 = s1
v2, g2 = s2

np.testing.assert_allclose(g1, g2, atol=1e-8, rtol=0)
np.testing.assert_allclose(v1, v2, atol=1e-8, rtol=0)
np.testing.assert_allclose(g1, g2, atol=1e-5, rtol=0)
np.testing.assert_allclose(v1, v2, atol=1e-5, rtol=0)


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
Expand Down