@@ -447,25 +447,27 @@ def measure_jit(
447447 sample .append (sign_complex )
448448 p = p * (pu * (- 1 ) ** sign + sign )
449449 else :
450- zero_r = backend .cast (backend .convert_to_tensor (0.0 ), rdtypestr )
451- pu = backend .clip (backend .real (backend .diagonal (rho )), zero_r , one_r )
450+ pu = backend .clip (
451+ backend .real (backend .diagonal (rho )),
452+ backend .convert_to_tensor (0.0 ),
453+ backend .convert_to_tensor (1.0 ),
454+ )
452455 pu = pu / backend .sum (pu )
453456 if status is None :
454- k_out = backend .implicit_randc (
457+ ind = backend .implicit_randc (
455458 a = backend .arange (self ._d ),
456459 shape = 1 ,
457460 p = backend .cast (pu , rdtypestr ),
458- )[0 ]
459- k_out = backend .cast (k_out , "int32" )
461+ )
460462 else :
461- r = backend .real (backend .cast (status [k ], rdtypestr ))
462- cdf = backend .cumsum (pu )
463- k_out = backend .sum (backend .cast (r >= cdf , "int32" ))
464- k_out = backend .clip (
465- k_out ,
466- backend .cast (backend .convert_to_tensor (0 ), "int32" ),
467- backend .cast (backend .convert_to_tensor (self ._d - 1 ), "int32" ),
463+ one_r = backend .cast (backend .convert_to_tensor (1.0 ), rdtypestr )
464+ st = backend .cast (status [k : k + 1 ], rdtypestr )
465+ ind = backend .probability_sample (
466+ shots = 1 ,
467+ p = backend .cast (pu , rdtypestr ),
468+ status = one_r - st ,
468469 )
470+ k_out = backend .cast (ind [0 ], "int32" )
469471 sample .append (backend .cast (k_out , rdtypestr ))
470472 p = p * backend .cast (pu [k_out ], rdtypestr )
471473 sample = backend .real (backend .stack (sample ))
0 commit comments