@@ -203,6 +203,142 @@ def objective_function(params):
203203 print (objective_function (tc .backend .ones (4 )))
204204
205205
206+ def test_ode_evol_jit_grad (highp , jaxb ):
207+ try :
208+ import diffrax # pylint: disable=unused-import
209+ except ImportError :
210+ pytest .skip ("diffrax not installed, skipping test" )
211+
212+ zz_ham = tc .quantum .PauliStringSum2COO ([[3 , 3 , 0 , 0 ], [0 , 3 , 3 , 0 ]], [1 , 1 ])
213+ x_ham = tc .quantum .PauliStringSum2COO ([[1 , 0 , 0 , 0 ], [0 , 1 , 0 , 0 ]], [1 , 1 ])
214+
215+ c = tc .Circuit (4 )
216+ c .x ([1 , 3 ])
217+ psi0 = c .state ()
218+
219+ # Example with parameterized Hamiltonian and optimization
220+ def parametrized_hamiltonian (t , * params ):
221+ # params = [J0, J1, h0, h1] - parameters to optimize
222+ J_t = params [0 ] + params [1 ] * tc .backend .sin (2.0 * t )
223+ h_t = params [2 ] + params [3 ] * tc .backend .cos (1.5 * t )
224+
225+ return J_t * zz_ham + h_t * x_ham
226+
227+ def zz_correlation (state ):
228+ n = int (np .log2 (state .shape [0 ]))
229+ circuit = tc .Circuit (n , inputs = state )
230+ return circuit .expectation_ps (z = [0 , 1 ])
231+
232+ @tc .backend .jit
233+ @tc .backend .value_and_grad
234+ def kv_ode_solver_ (params ):
235+ states = tc .timeevol .ode_evol_global (
236+ parametrized_hamiltonian ,
237+ psi0 ,
238+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
239+ None ,
240+ * params ,
241+ atol = 1.0e-15 ,
242+ rtol = 1.0e-15 ,
243+ solver = "Kvaerno5" ,
244+ ode_backend = "diffrax" ,
245+ )
246+ return tc .backend .real (zz_correlation (states [- 1 ]))
247+
248+ @tc .backend .jit
249+ @tc .backend .value_and_grad
250+ def ts_ode_solver_ (params ):
251+ states = tc .timeevol .ode_evol_global (
252+ parametrized_hamiltonian ,
253+ psi0 ,
254+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
255+ None ,
256+ * params ,
257+ ode_backend = "diffrax" ,
258+ atol = 1.0e-13 ,
259+ rtol = 1.0e-13 ,
260+ dt0 = 0.005 ,
261+ )
262+ return tc .backend .real (zz_correlation (states [- 1 ]))
263+
264+ @tc .backend .jit
265+ @tc .backend .value_and_grad
266+ def do5_ode_solver_ (params ):
267+ states = tc .timeevol .ode_evol_global (
268+ parametrized_hamiltonian ,
269+ psi0 ,
270+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
271+ None ,
272+ * params ,
273+ )
274+ return tc .backend .real (zz_correlation (states [- 1 ]))
275+
276+ paras = np .random .rand (4 )
277+ s1 = kv_ode_solver_ (paras )
278+ s2 = ts_ode_solver_ (paras )
279+ s3 = do5_ode_solver_ (paras )
280+
281+ v1 , g1 = s1
282+ v2 , g2 = s2
283+ v3 , g3 = s3
284+
285+ assert (np .linalg .norm (v1 - v2 ) < 1e-8 ) & (np .linalg .norm (v1 - v3 ) < 1e-8 )
286+ assert (np .linalg .norm (g1 - g2 ) < 1e-8 ) & (np .linalg .norm (g1 - g3 ) < 1e-8 )
287+
288+ ######################################################################
289+
290+ def local_hamiltonian (t , Omega , phi ):
291+ angle = phi * t
292+ coeff = Omega * tc .backend .cos (2.0 * t ) # Amplitude modulation
293+
294+ # Single-qubit Rabi Hamiltonian (2x2 matrix)
295+ hx = coeff * tc .backend .cos (angle ) * tc .gates .x ().tensor
296+ hy = coeff * tc .backend .sin (angle ) * tc .gates .y ().tensor
297+ return hx + hy
298+
299+ # Initial state: GHZ state |0000⟩ + |1111⟩
300+ c = tc .Circuit (4 )
301+ c .h (0 )
302+ for i in range (3 ):
303+ c .cnot (i , i + 1 )
304+ psi0 = c .state ()
305+
306+ # Evolve with local Hamiltonian acting on qubit 1
307+ @tc .backend .jit
308+ @tc .backend .value_and_grad
309+ def ts_ode_solver_local (paras ):
310+ states = tc .timeevol .ode_evol_local (
311+ local_hamiltonian ,
312+ psi0 ,
313+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
314+ [2 ], # Apply to qubit 1
315+ None ,
316+ * paras , # Omega=1.0, phi=2.0
317+ ode_backend = "diffrax" ,
318+ )
319+ return tc .backend .real (zz_correlation (states [- 1 ]))
320+
321+ @tc .backend .jit
322+ @tc .backend .value_and_grad
323+ def do5_ode_solver_local (paras ):
324+ states = tc .timeevol .ode_evol_local (
325+ local_hamiltonian ,
326+ psi0 ,
327+ tc .backend .convert_to_tensor ([0 , 10.0 ]),
328+ [2 ], # Apply to qubit 1
329+ None ,
330+ * paras , # Omega=1.0, phi=2.0
331+ )
332+ return tc .backend .real (zz_correlation (states [- 1 ]))
333+
334+ paras = np .random .rand (2 )
335+ s1 = ts_ode_solver_local (paras )
336+ s2 = do5_ode_solver_local (paras )
337+ v1 , g1 = s1
338+ v2 , g2 = s2
339+ assert (np .linalg .norm (v1 - v2 ) < 1e-8 ) & (np .linalg .norm (g1 - g2 ) < 1e-8 )
340+
341+
206342@pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" )])
207343def test_ed_evol (backend ):
208344 n = 4
0 commit comments