Skip to content

Commit cc67716

Browse files
committed
new
1 parent d9f7a84 commit cc67716

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

tensorcircuit/timeevol.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,22 +438,22 @@ def _solve_ode(
438438
rtol = solver_kws.get("rtol", 1e-12)
439439
atol = solver_kws.get("atol", 1e-12)
440440
ode_backend = solver_kws.get("ode_backend", "jaxode")
441+
max_steps = solver_kws.get("max_steps", 10000)
441442

442443
ts = backend.convert_to_tensor(times)
443444
ts = backend.cast(ts, dtype=rdtypestr)
444445

445446
if ode_backend == "jaxode":
446447
from jax.experimental.ode import odeint
447448

448-
s1 = odeint(f, s, ts, rtol=rtol, atol=atol, *args)
449+
s1 = odeint(f, s, ts, rtol=rtol, atol=atol, mxstep=max_steps, *args)
449450
return s1
450451

451452
import diffrax
452453

453454
# Ignore complex warning
454455
warnings.simplefilter("ignore", category=UserWarning, append=True)
455456

456-
max_steps = solver_kws.get("max_steps", 10000)
457457
solver = solver_kws.get("solver", "Tsit5")
458458
dt0 = solver_kws.get("dt0", 0.01)
459459
all_solvers = {

tests/test_timeevol.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import tensorcircuit as tc
1212

1313

14-
def test_circuit_ode_evol(jaxb):
14+
def test_circuit_ode_evol(highp, jaxb):
1515
def h_square(t, b):
1616
return (tc.backend.sign(t - 1.0) + 1) / 2 * b * tc.gates.x().tensor
1717

@@ -33,7 +33,11 @@ def h_square_sparse(t, b):
3333
c.cx(0, 1)
3434
c.h(2)
3535
c = tc.timeevol.evol_global(
36-
c, h_square_sparse, 2.0, tc.backend.convert_to_tensor(0.2)
36+
c,
37+
h_square_sparse,
38+
2.0,
39+
tc.backend.convert_to_tensor(0.2),
40+
ode_backend="diffrax",
3741
)
3842
c.rx(1, theta=np.pi - 0.4)
3943
np.testing.assert_allclose(c.expectation_ps(z=[1]), 1.0, atol=1e-5)

0 commit comments

Comments
 (0)