2626 sample2all ,
2727 _infer_num_sites ,
2828 _decode_basis_label ,
29+ onehot_d_tensor ,
2930)
3031from .abstractcircuit import AbstractCircuit
3132from .cons import npdtype , backend , dtypestr , contractor , rdtypestr
@@ -410,8 +411,7 @@ def measure_jit(
410411 np .array ([1 , 0 ])
411412 ) + sample [i ] * gates .array_to_tensor (np .array ([0 , 1 ]))
412413 else :
413- vec = backend .one_hot (backend .cast (sample [i ], "int32" ), self ._d )
414- m = backend .cast (vec , dtypestr )
414+ m = onehot_d_tensor (sample [i ], d = self ._d )
415415 g1 = Gate (m )
416416 g1 .id = id (g1 )
417417 g1 .is_dagger = False
@@ -507,29 +507,21 @@ def amplitude_before(self, l: Union[str, Tensor]) -> List[Gate]:
507507 :rtype: List[Gate]
508508 """
509509
510- def _basis_nod (_k : int ) -> Tensor :
511- _vec = np .zeros ((self ._d ,), dtype = npdtype )
512- _vec [_k ] = 1.0
513- return _vec
514-
515510 no , d_edges = self ._copy ()
516511 ms = []
517512 if self .is_dm :
518513 msconj = []
519514 if isinstance (l , str ):
520515 symbols = _decode_basis_label (l , n = self ._nqubits , dim = self ._d )
521516 for k in symbols :
522- n = _basis_nod ( k )
517+ n = onehot_d_tensor ( k , d = self . _d )
523518 ms .append (tn .Node (n ))
524519 if self .is_dm :
525520 msconj .append (tn .Node (n ))
526521 else :
527522 l = backend .cast (l , dtype = dtypestr )
528523 for i in range (self ._nqubits ):
529- endn = backend .cast (
530- backend .one_hot (backend .cast (l [i ], "int32" ), self ._d ),
531- dtype = dtypestr ,
532- )
524+ endn = onehot_d_tensor (l [i ], d = self ._d )
533525 ms .append (tn .Node (endn ))
534526 if self .is_dm :
535527 msconj .append (tn .Node (endn ))
@@ -1040,18 +1032,13 @@ def projected_subsystem(self, traceout: Tensor, left: Tuple[int, ...]) -> Tensor
10401032 :rtype: Tensor
10411033 """
10421034
1043- def _basis_gate (k_tensor : Any ) -> Gate :
1044- vec = backend .one_hot (backend .cast (k_tensor , "int32" ), self ._d )
1045- vec = backend .cast (vec , dtypestr )
1046- return Gate (vec )
1047-
10481035 traceout = backend .cast (traceout , dtypestr )
10491036 nodes , front = self ._copy ()
10501037 L = self ._nqubits
10511038 edges = []
10521039 for i in range (len (traceout )):
10531040 if i not in left :
1054- n = _basis_gate ( traceout [i ])
1041+ n = Gate ( onehot_d_tensor ( traceout [i ], d = self . _d ) )
10551042 nodes .append (n )
10561043 front [i ] ^ n [0 ]
10571044 else :
@@ -1060,7 +1047,7 @@ def _basis_gate(k_tensor: Any) -> Gate:
10601047 if self .is_dm :
10611048 for i in range (len (traceout )):
10621049 if i not in left :
1063- n = _basis_gate ( traceout [i ])
1050+ n = Gate ( onehot_d_tensor ( traceout [i ], d = self . _d ) )
10641051 nodes .append (n )
10651052 front [i + L ] ^ n [0 ]
10661053 else :
0 commit comments