@@ -302,3 +302,162 @@ def exp(theta):
302302 return c .expectation_ps (z = [- 3 ], reuse = False )
303303
304304 assert len (exp (0.3 )) == 9
305+
306+
307+ @pytest .mark .parametrize ("backend" , [lf ("jaxb" ), lf ("tfb" )])
308+ def test_parameter_shift_grad (backend ):
309+ def f (params ):
310+ c = tc .Circuit (2 )
311+ c .rx (0 , theta = params [0 ])
312+ c .ry (1 , theta = params [1 ])
313+ c .cnot (0 , 1 )
314+ return tc .backend .real (c .expectation_ps (z = [0 , 1 ]))
315+
316+ params = tc .array_to_tensor (np .array ([0.1 , 0.2 ]))
317+
318+ # Standard AD gradient
319+ g_ad = tc .backend .grad (f )(params )
320+
321+ # Parameter shift gradient
322+ g_ps = experimental .parameter_shift_grad (f )(params )
323+
324+ np .testing .assert_allclose (g_ad , g_ps , atol = 1e-5 )
325+
326+
327+ @pytest .mark .parametrize ("backend" , [lf ("jaxb" )])
328+ def test_parameter_shift_grad_v2 (backend ):
329+ # v2 is mainly for jax and supports randomness
330+ def f (params ):
331+ c = tc .Circuit (2 )
332+ c .rx (0 , theta = params [0 ])
333+ c .ry (1 , theta = params [1 ])
334+ return tc .backend .real (c .expectation_ps (z = [0 ]))
335+
336+ params = tc .array_to_tensor (np .array ([0.5 , 0.5 ]))
337+ g_ps = experimental .parameter_shift_grad_v2 (f )(params )
338+ g_ad = tc .backend .grad (f )(params )
339+ np .testing .assert_allclose (g_ps , g_ad , atol = 1e-5 )
340+
341+
342+ def test_broadcast_py_object_single_process (jaxb ):
343+ # In a single process environment, broadcast should just return the object
344+ # though it uses jax.experimental.multihost_utils.broadcast_one_to_all
345+ obj = {"a" : 1 , "b" : [1 , 2 , 3 ]}
346+ res = experimental .broadcast_py_object (obj )
347+ assert res == obj
348+
349+
350+ @pytest .mark .parametrize ("backend" , [lf ("jaxb" )])
351+ def test_jax_jitted_function_save_load_v2 (backend , tmp_path ):
352+ K = tc .backend
353+
354+ @K .jit
355+ def f (x ):
356+ return x ** 2 + 1.0
357+
358+ x = K .ones ([2 ])
359+ path = os .path .join (tmp_path , "f.bin" )
360+ experimental .jax_jitted_function_save (path , f , x )
361+
362+ f_load = experimental .jax_jitted_function_load (path )
363+ np .testing .assert_allclose (f_load (x ), f (x ), atol = 1e-5 )
364+
365+
366+ @pytest .mark .parametrize ("backend" , [lf ("jaxb" ), lf ("tfb" )])
367+ def test_qng_options (backend ):
368+ def f (params ):
369+ c = tc .Circuit (1 )
370+ c .rx (0 , theta = params [0 ])
371+ return c .state ()
372+
373+ params = tc .backend .ones ([1 ])
374+ # test different options in qng to hit more lines
375+ qng_fn = experimental .qng (f , mode = "fwd" )
376+ qng_fn (params )
377+
378+ qng_fn2 = experimental .qng (f , mode = "rev" )
379+ qng_fn2 (params )
380+
381+ qng_fn3 = experimental .qng (f , kernel = "dynamics" , postprocess = None )
382+ qng_fn3 (params )
383+
384+
385+ @pytest .mark .parametrize ("backend" , [lf ("jaxb" ), lf ("tfb" )])
386+ def test_qng2_options (backend ):
387+ def f (params ):
388+ c = tc .Circuit (1 )
389+ c .rx (0 , theta = params [0 ])
390+ return c .state ()
391+
392+ params = tc .backend .ones ([1 ])
393+ qng_fn = experimental .qng2 (f , mode = "fwd" )
394+ qng_fn (params )
395+
396+ qng_fn2 = experimental .qng2 (f , mode = "rev" )
397+ qng_fn2 (params )
398+
399+ qng_fn3 = experimental .qng2 (f , kernel = "dynamics" , postprocess = None )
400+ qng_fn3 (params )
401+
402+
403+ def test_vis_extra ():
404+ c = tc .Circuit (2 )
405+ c .h (0 )
406+ c .cx (0 , 1 )
407+ tex = tc .vis .qir2tex (c .to_qir (), 2 )
408+ assert "\\ qw" in tex
409+
410+ assert tc .vis .gate_name_trans ("ccnot" ) == (2 , "not" )
411+ assert tc .vis .gate_name_trans ("h" ) == (0 , "h" )
412+
413+
414+ def test_cons_extra (jaxb ):
415+ # set_function_backend
416+ @tc .cons .set_function_backend ("jax" )
417+ def f ():
418+ return tc .backend .name
419+
420+ # set_function_dtype
421+ @tc .cons .set_function_dtype ("complex128" )
422+ def g ():
423+ return tc .dtypestr
424+
425+
426+ def test_ascii_art ():
427+ # hit some lines in asciiart.py
428+
429+ try :
430+ tc .set_ascii ("wrong" )
431+ except AttributeError :
432+ pass
433+
434+ # lucky() is only available after set_ascii
435+ assert not hasattr (tc , "lucky" )
436+
437+
438+ def test_utils_extra ():
439+ from tensorcircuit import utils
440+
441+ # return_partial
442+ f = lambda x : [x , x ** 2 , x ** 3 ]
443+ f1 = utils .return_partial (f , return_argnums = 1 )
444+ assert f1 (2 ) == 4
445+ f2 = utils .return_partial (f , return_argnums = [0 , 2 ])
446+ assert f2 (2 ) == (2 , 8 )
447+
448+ # append
449+ f3 = utils .append (lambda x : x ** 2 , lambda x : x + 1 )
450+ assert f3 (2 ) == 5
451+
452+ # is_m1mac
453+ utils .is_m1mac ()
454+
455+ # is_sequence, is_number
456+ assert utils .is_sequence ([1 ])
457+ assert utils .is_number (1.0 )
458+
459+ # benchmark
460+ def h (x ):
461+ return x + 1
462+
463+ utils .benchmark (h , 1.0 , tries = 2 )
0 commit comments