@@ -203,6 +203,146 @@ def objective_function(params):
203203 print (objective_function (tc .backend .ones (4 )))
204204
205205
206+ def test_ode_evol_jit_grad (highp , jaxb ):
207+ try :
208+ import diffrax # pylint: disable=unused-import
209+ except ImportError :
210+ pytest .skip ("diffrax not installed, skipping test" )
211+
212+ zz_ham = tc .quantum .PauliStringSum2COO ([[3 , 3 , 0 , 0 ], [0 , 3 , 3 , 0 ]], [1 , 1 ])
213+ x_ham = tc .quantum .PauliStringSum2COO ([[1 , 0 , 0 , 0 ], [0 , 1 , 0 , 0 ]], [1 , 1 ])
214+
215+ c = tc .Circuit (4 )
216+ c .x ([1 , 3 ])
217+ psi0 = c .state ()
218+
219+ # Example with parameterized Hamiltonian and optimization
220+ def parametrized_hamiltonian (t , * params ):
221+ # params = [J0, J1, h0, h1] - parameters to optimize
222+ J_t = params [0 ] + params [1 ] * tc .backend .sin (2.0 * t )
223+ h_t = params [2 ] + params [3 ] * tc .backend .cos (1.5 * t )
224+
225+ return J_t * zz_ham + h_t * x_ham
226+
227+ def zz_correlation (state ):
228+ n = int (np .log2 (state .shape [0 ]))
229+ circuit = tc .Circuit (n , inputs = state )
230+ return circuit .expectation_ps (z = [0 , 1 ])
231+
232+ @tc .backend .jit
233+ @tc .backend .value_and_grad
234+ def kv_ode_solver_ (params ):
235+ states = tc .timeevol .ode_evol_global (
236+ parametrized_hamiltonian ,
237+ psi0 ,
238+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
239+ None ,
240+ * params ,
241+ atol = 1.0e-15 ,
242+ rtol = 1.0e-15 ,
243+ solver = "Kvaerno5" ,
244+ ode_backend = "diffrax" ,
245+ )
246+ return tc .backend .real (zz_correlation (states [- 1 ]))
247+
248+ @tc .backend .jit
249+ @tc .backend .value_and_grad
250+ def ts_ode_solver_ (params ):
251+ states = tc .timeevol .ode_evol_global (
252+ parametrized_hamiltonian ,
253+ psi0 ,
254+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
255+ None ,
256+ * params ,
257+ ode_backend = "diffrax" ,
258+ atol = 1.0e-13 ,
259+ rtol = 1.0e-13 ,
260+ dt0 = 0.005 ,
261+ )
262+ return tc .backend .real (zz_correlation (states [- 1 ]))
263+
264+ @tc .backend .jit
265+ @tc .backend .value_and_grad
266+ def do5_ode_solver_ (params ):
267+ states = tc .timeevol .ode_evol_global (
268+ parametrized_hamiltonian ,
269+ psi0 ,
270+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
271+ None ,
272+ * params ,
273+ )
274+ return tc .backend .real (zz_correlation (states [- 1 ]))
275+
276+ paras = np .random .rand (4 )
277+ s1 = kv_ode_solver_ (paras )
278+ s2 = ts_ode_solver_ (paras )
279+ s3 = do5_ode_solver_ (paras )
280+
281+ v1 , g1 = s1
282+ v2 , g2 = s2
283+ v3 , g3 = s3
284+
285+ np .testing .assert_allclose (g1 , g3 , atol = 1e-8 , rtol = 0 )
286+ np .testing .assert_allclose (g1 , g2 , atol = 1e-8 , rtol = 0 )
287+ np .testing .assert_allclose (v1 , v3 , atol = 1e-8 , rtol = 0 )
288+ np .testing .assert_allclose (v1 , v2 , atol = 1e-8 , rtol = 0 )
289+
290+ ######################################################################
291+
292+ def local_hamiltonian (t , Omega , phi ):
293+ angle = phi * t
294+ coeff = Omega * tc .backend .cos (2.0 * t ) # Amplitude modulation
295+
296+ # Single-qubit Rabi Hamiltonian (2x2 matrix)
297+ hx = coeff * tc .backend .cos (angle ) * tc .gates .x ().tensor
298+ hy = coeff * tc .backend .sin (angle ) * tc .gates .y ().tensor
299+ return hx + hy
300+
301+ # Initial state: GHZ state |0000⟩ + |1111⟩
302+ c = tc .Circuit (4 )
303+ c .h (0 )
304+ for i in range (3 ):
305+ c .cnot (i , i + 1 )
306+ psi0 = c .state ()
307+
308+ # Evolve with local Hamiltonian acting on qubit 1
309+ @tc .backend .jit
310+ @tc .backend .value_and_grad
311+ def ts_ode_solver_local (paras ):
312+ states = tc .timeevol .ode_evol_local (
313+ local_hamiltonian ,
314+ psi0 ,
315+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
316+ [2 ], # Apply to qubit 1
317+ None ,
318+ * paras , # Omega=1.0, phi=2.0
319+ ode_backend = "diffrax" ,
320+ )
321+ return tc .backend .real (zz_correlation (states [- 1 ]))
322+
323+ @tc .backend .jit
324+ @tc .backend .value_and_grad
325+ def do5_ode_solver_local (paras ):
326+ states = tc .timeevol .ode_evol_local (
327+ local_hamiltonian ,
328+ psi0 ,
329+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
330+ [2 ], # Apply to qubit 1
331+ None ,
332+ * paras , # Omega=1.0, phi=2.0
333+ )
334+ return tc .backend .real (zz_correlation (states [- 1 ]))
335+
336+ paras = np .random .rand (2 )
337+ s1 = ts_ode_solver_local (paras )
338+ s2 = do5_ode_solver_local (paras )
339+ v1 , g1 = s1
340+ v2 , g2 = s2
341+
342+ np .testing .assert_allclose (g1 , g2 , atol = 1e-8 , rtol = 0 )
343+ np .testing .assert_allclose (v1 , v2 , atol = 1e-8 , rtol = 0 )
344+
345+
206346@pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" )])
207347def test_ed_evol (backend ):
208348 n = 4
0 commit comments