8585
8686# for dense matrices
8787function _partial_transpose (ρ:: QuantumObject{Operator} , mask:: Vector{Bool} )
88- mask2 = [1 + Int (i) for i in mask]
88+ nsys = length (mask)
89+ mask2 = reverse ([mask[s] ? 2 : 1 for s in 1 : nsys])
8990 # mask2 has elements with values equal to 1 or 2
90- # 1 - the subsystem don't need to be transposed
91- # 2 - the subsystem need be transposed
91+ # 1 - the subsystem (in reversed order) don't need to be transposed
92+ # 2 - the subsystem (in reversed order) need to be transposed
9293
93- nsys = length (mask2)
94- dims = dimensions_to_dims (get_dimensions_to (ρ))
94+ dims_rev = reverse (dimensions_to_dims (get_dimensions_to (ρ)))
9595 pt_dims = reshape (Vector (1 : (2 * nsys)), (nsys, 2 ))
9696 pt_idx = [
9797 [pt_dims[n, mask2[n]] for n in 1 : nsys] # origin value in mask2
9898 [pt_dims[n, 3 - mask2[n]] for n in 1 : nsys] # opposite value in mask2 (1 -> 2, and 2 -> 1)
9999 ]
100100 return QuantumObject (
101- reshape (permutedims (reshape (ρ. data, (dims ... , dims ... )), pt_idx), size (ρ)),
101+ reshape (permutedims (reshape (ρ. data, (dims_rev ... , dims_rev ... )), pt_idx), size (ρ)),
102102 Operator (),
103103 Dimensions (ρ. dimensions. to),
104104 )
@@ -110,7 +110,8 @@ function _partial_transpose(
110110 mask:: Vector{Bool} ,
111111) where {DimsType<: AbstractDimensions }
112112 M, N = size (ρ)
113- dimsTuple = Tuple (dimensions_to_dims (get_dimensions_to (ρ)))
113+ dims_rev = reverse (Tuple (dimensions_to_dims (get_dimensions_to (ρ))))
114+ mask_rev = reverse (mask)
114115 colptr = ρ. data. colptr
115116 rowval = ρ. data. rowval
116117 nzval = ρ. data. nzval
@@ -130,13 +131,13 @@ function _partial_transpose(
130131 I_pt[n] = i
131132 J_pt[n] = j
132133 else
133- ket_pt = [Base. _ind2sub (dimsTuple , i)... ]
134- bra_pt = [Base. _ind2sub (dimsTuple , j)... ]
135- for sys in findall (m -> m, mask )
134+ ket_pt = [Base. _ind2sub (dims_rev , i)... ]
135+ bra_pt = [Base. _ind2sub (dims_rev , j)... ]
136+ for sys in findall (m -> m, mask_rev )
136137 @inbounds ket_pt[sys], bra_pt[sys] = bra_pt[sys], ket_pt[sys]
137138 end
138- I_pt[n] = Base. _sub2ind (dimsTuple , ket_pt... )
139- J_pt[n] = Base. _sub2ind (dimsTuple , bra_pt... )
139+ I_pt[n] = Base. _sub2ind (dims_rev , ket_pt... )
140+ J_pt[n] = Base. _sub2ind (dims_rev , bra_pt... )
140141 end
141142 V_pt[n] = nzval[p]
142143 end
0 commit comments