Skip to content

Commit e63cb19

Browse files
committed
new test about jit and grad
1 parent 827eb67 commit e63cb19

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed

tests/test_timeevol.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,142 @@ def objective_function(params):
203203
print(objective_function(tc.backend.ones(4)))
204204

205205

206+
def test_ode_evol_jit_grad(highp, jaxb):
207+
try:
208+
import diffrax # pylint: disable=unused-import
209+
except ImportError:
210+
pytest.skip("diffrax not installed, skipping test")
211+
212+
zz_ham = tc.quantum.PauliStringSum2COO([[3, 3, 0, 0], [0, 3, 3, 0]], [1, 1])
213+
x_ham = tc.quantum.PauliStringSum2COO([[1, 0, 0, 0], [0, 1, 0, 0]], [1, 1])
214+
215+
c = tc.Circuit(4)
216+
c.x([1, 3])
217+
psi0 = c.state()
218+
219+
# Example with parameterized Hamiltonian and optimization
220+
def parametrized_hamiltonian(t, *params):
221+
# params = [J0, J1, h0, h1] - parameters to optimize
222+
J_t = params[0] + params[1] * tc.backend.sin(2.0 * t)
223+
h_t = params[2] + params[3] * tc.backend.cos(1.5 * t)
224+
225+
return J_t * zz_ham + h_t * x_ham
226+
227+
def zz_correlation(state):
228+
n = int(np.log2(state.shape[0]))
229+
circuit = tc.Circuit(n, inputs=state)
230+
return circuit.expectation_ps(z=[0, 1])
231+
232+
@tc.backend.jit
233+
@tc.backend.value_and_grad
234+
def kv_ode_solver_(params):
235+
states = tc.timeevol.ode_evol_global(
236+
parametrized_hamiltonian,
237+
psi0,
238+
tc.backend.convert_to_tensor([0, 10.0]),
239+
None,
240+
*params,
241+
atol=1.0e-15,
242+
rtol=1.0e-15,
243+
solver="Kvaerno5",
244+
ode_backend="diffrax",
245+
)
246+
return tc.backend.real(zz_correlation(states[-1]))
247+
248+
@tc.backend.jit
249+
@tc.backend.value_and_grad
250+
def ts_ode_solver_(params):
251+
states = tc.timeevol.ode_evol_global(
252+
parametrized_hamiltonian,
253+
psi0,
254+
tc.backend.convert_to_tensor([0, 10.0]),
255+
None,
256+
*params,
257+
ode_backend="diffrax",
258+
atol=1.0e-13,
259+
rtol=1.0e-13,
260+
dt0=0.005,
261+
)
262+
return tc.backend.real(zz_correlation(states[-1]))
263+
264+
@tc.backend.jit
265+
@tc.backend.value_and_grad
266+
def do5_ode_solver_(params):
267+
states = tc.timeevol.ode_evol_global(
268+
parametrized_hamiltonian,
269+
psi0,
270+
tc.backend.convert_to_tensor([0, 10.0]),
271+
None,
272+
*params,
273+
)
274+
return tc.backend.real(zz_correlation(states[-1]))
275+
276+
paras = np.random.rand(4)
277+
s1 = kv_ode_solver_(paras)
278+
s2 = ts_ode_solver_(paras)
279+
s3 = do5_ode_solver_(paras)
280+
281+
v1, g1 = s1
282+
v2, g2 = s2
283+
v3, g3 = s3
284+
285+
assert (np.linalg.norm(v1 - v2) < 1e-8) & (np.linalg.norm(v1 - v3) < 1e-8)
286+
assert (np.linalg.norm(g1 - g2) < 1e-8) & (np.linalg.norm(g1 - g3) < 1e-8)
287+
288+
######################################################################
289+
290+
def local_hamiltonian(t, Omega, phi):
291+
angle = phi * t
292+
coeff = Omega * tc.backend.cos(2.0 * t) # Amplitude modulation
293+
294+
# Single-qubit Rabi Hamiltonian (2x2 matrix)
295+
hx = coeff * tc.backend.cos(angle) * tc.gates.x().tensor
296+
hy = coeff * tc.backend.sin(angle) * tc.gates.y().tensor
297+
return hx + hy
298+
299+
# Initial state: GHZ state |0000⟩ + |1111⟩
300+
c = tc.Circuit(4)
301+
c.h(0)
302+
for i in range(3):
303+
c.cnot(i, i + 1)
304+
psi0 = c.state()
305+
306+
# Evolve with local Hamiltonian acting on qubit 1
307+
@tc.backend.jit
308+
@tc.backend.value_and_grad
309+
def ts_ode_solver_local(paras):
310+
states = tc.timeevol.ode_evol_local(
311+
local_hamiltonian,
312+
psi0,
313+
tc.backend.convert_to_tensor([0, 10.0]),
314+
[2], # Apply to qubit 1
315+
None,
316+
*paras, # Omega=1.0, phi=2.0
317+
ode_backend="diffrax",
318+
)
319+
return tc.backend.real(zz_correlation(states[-1]))
320+
321+
@tc.backend.jit
322+
@tc.backend.value_and_grad
323+
def do5_ode_solver_local(paras):
324+
states = tc.timeevol.ode_evol_local(
325+
local_hamiltonian,
326+
psi0,
327+
tc.backend.convert_to_tensor([0, 10.0]),
328+
[2], # Apply to qubit 1
329+
None,
330+
*paras, # Omega=1.0, phi=2.0
331+
)
332+
return tc.backend.real(zz_correlation(states[-1]))
333+
334+
paras = np.random.rand(2)
335+
s1 = ts_ode_solver_local(paras)
336+
s2 = do5_ode_solver_local(paras)
337+
v1, g1 = s1
338+
v2, g2 = s2
339+
assert (np.linalg.norm(v1 - v2) < 1e-8) & (np.linalg.norm(g1 - g2) < 1e-8)
340+
341+
206342
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
207343
def test_ed_evol(backend):
208344
n = 4

0 commit comments

Comments
 (0)