Skip to content

Commit 3422581

Browse files
authored
Fix ptrace (#225)
1 parent 6edd193 commit 3422581

File tree

4 files changed

+121
-37
lines changed

4 files changed

+121
-37
lines changed

src/qobj/arithmetic_and_attributes.jl

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -475,8 +475,9 @@ proj(ψ::QuantumObject{<:AbstractArray{T},BraQuantumObject}) where {T} = ψ' *
475475
@doc raw"""
476476
ptrace(QO::QuantumObject, sel)
477477
478-
[Partial trace](https://en.wikipedia.org/wiki/Partial_trace) of a quantum state `QO` leaving only the dimensions
479-
with the indices present in the `sel` vector.
478+
[Partial trace](https://en.wikipedia.org/wiki/Partial_trace) of a quantum state `QO` leaving only the dimensions with the indices present in the `sel` vector.
479+
480+
Note that this function will always return [`Operator`](@ref). No matter the input [`QuantumObject`](@ref) is a [`Ket`](@ref), [`Bra`](@ref), or [`Operator`](@ref).
480481
481482
# Examples
482483
Two qubits in the state ``\ket{\psi} = \ket{e,g}``:
@@ -514,18 +515,46 @@ Quantum Object: type=Operator dims=[2] size=(2, 2) ishermitian=true
514515
```
515516
"""
516517
function ptrace(QO::QuantumObject{<:AbstractArray,KetQuantumObject}, sel::Union{AbstractVector{Int},Tuple})
517-
length(QO.dims) == 1 && return QO
518+
_non_static_array_warning("sel", sel)
519+
520+
ns = length(sel)
521+
if ns == 0 # return full trace for empty sel
522+
return tr(ket2dm(QO))
523+
else
524+
nd = length(QO.dims)
525+
526+
(any(>(nd), sel) || any(<(1), sel)) && throw(
527+
ArgumentError("Invalid indices in `sel`: $(sel), the given QuantumObject only have $(nd) sub-systems"),
528+
)
529+
(ns != length(unique(sel))) && throw(ArgumentError("Duplicate selection indices in `sel`: $(sel)"))
530+
(nd == 1) && return ket2dm(QO) # ptrace should always return Operator
531+
end
518532

519-
ρtr, dkeep = _ptrace_ket(QO.data, QO.dims, SVector(sel))
533+
_sort_sel = sort(SVector{length(sel),Int}(sel))
534+
ρtr, dkeep = _ptrace_ket(QO.data, QO.dims, _sort_sel)
520535
return QuantumObject(ρtr, type = Operator, dims = dkeep)
521536
end
522537

523538
ptrace(QO::QuantumObject{<:AbstractArray,BraQuantumObject}, sel::Union{AbstractVector{Int},Tuple}) = ptrace(QO', sel)
524539

525540
function ptrace(QO::QuantumObject{<:AbstractArray,OperatorQuantumObject}, sel::Union{AbstractVector{Int},Tuple})
526-
length(QO.dims) == 1 && return QO
541+
_non_static_array_warning("sel", sel)
542+
543+
ns = length(sel)
544+
if ns == 0 # return full trace for empty sel
545+
return tr(QO)
546+
else
547+
nd = length(QO.dims)
548+
549+
(any(>(nd), sel) || any(<(1), sel)) && throw(
550+
ArgumentError("Invalid indices in `sel`: $(sel), the given QuantumObject only have $(nd) sub-systems"),
551+
)
552+
(ns != length(unique(sel))) && throw(ArgumentError("Duplicate selection indices in `sel`: $(sel)"))
553+
(nd == 1) && return QO
554+
end
527555

528-
ρtr, dkeep = _ptrace_oper(QO.data, QO.dims, SVector(sel))
556+
_sort_sel = sort(SVector{length(sel),Int}(sel))
557+
ρtr, dkeep = _ptrace_oper(QO.data, QO.dims, _sort_sel)
529558
return QuantumObject(ρtr, type = Operator, dims = dkeep)
530559
end
531560
ptrace(QO::QuantumObject, sel::Int) = ptrace(QO, SVector(sel))
@@ -538,17 +567,20 @@ function _ptrace_ket(QO::AbstractArray, dims::Union{SVector,MVector}, sel)
538567
qtrace = filter(i -> i sel, 1:nd)
539568
dkeep = dims[sel]
540569
dtrace = dims[qtrace]
541-
# Concatenate sel and qtrace without loosing the length information
542-
sel_qtrace = ntuple(Val(nd)) do i
543-
if i <= length(sel)
544-
@inbounds sel[i]
570+
nt = length(dtrace)
571+
572+
# Concatenate qtrace and sel without losing the length information
573+
# Tuple(qtrace..., sel...)
574+
qtrace_sel = ntuple(Val(nd)) do i
575+
if i <= nt
576+
@inbounds qtrace[i]
545577
else
546-
@inbounds qtrace[i-length(sel)]
578+
@inbounds sel[i-nt]
547579
end
548580
end
549581

550582
vmat = reshape(QO, reverse(dims)...)
551-
topermute = nd + 1 .- sel_qtrace
583+
topermute = reverse(nd + 1 .- qtrace_sel)
552584
vmat = permutedims(vmat, topermute) # TODO: use PermutedDimsArray when Julia v1.11.0 is released
553585
vmat = reshape(vmat, prod(dkeep), prod(dtrace))
554586

@@ -563,27 +595,27 @@ function _ptrace_oper(QO::AbstractArray, dims::Union{SVector,MVector}, sel)
563595
qtrace = filter(i -> i sel, 1:nd)
564596
dkeep = dims[sel]
565597
dtrace = dims[qtrace]
566-
# Concatenate sel and qtrace without loosing the length information
598+
nk = length(dkeep)
599+
nt = length(dtrace)
600+
_2_nt = 2 * nt
601+
602+
# Concatenate qtrace and sel without losing the length information
603+
# Tuple(qtrace..., sel...)
567604
qtrace_sel = ntuple(Val(2 * nd)) do i
568-
if i <= length(qtrace)
605+
if i <= nt
569606
@inbounds qtrace[i]
570-
elseif i <= 2 * length(qtrace)
571-
@inbounds qtrace[i-length(qtrace)] + nd
572-
elseif i <= 2 * length(qtrace) + length(sel)
573-
@inbounds sel[i-length(qtrace)-length(sel)]
607+
elseif i <= _2_nt
608+
@inbounds qtrace[i-nt] + nd
609+
elseif i <= _2_nt + nk
610+
@inbounds sel[i-_2_nt]
574611
else
575-
@inbounds sel[i-length(qtrace)-2*length(sel)] + nd
612+
@inbounds sel[i-_2_nt-nk] + nd
576613
end
577614
end
578615

579616
ρmat = reshape(QO, reverse(vcat(dims, dims))...)
580-
topermute = 2 * nd + 1 .- qtrace_sel
581-
ρmat = permutedims(ρmat, reverse(topermute)) # TODO: use PermutedDimsArray when Julia v1.11.0 is released
582-
583-
## TODO: Check if it works always
584-
585-
# ρmat = row_major_reshape(ρmat, prod(dtrace), prod(dtrace), prod(dkeep), prod(dkeep))
586-
# res = dropdims(mapslices(tr, ρmat, dims=(1,2)), dims=(1,2))
617+
topermute = reverse(2 * nd + 1 .- qtrace_sel)
618+
ρmat = permutedims(ρmat, topermute) # TODO: use PermutedDimsArray when Julia v1.11.0 is released
587619
ρmat = reshape(ρmat, prod(dkeep), prod(dkeep), prod(dtrace), prod(dtrace))
588620
res = map(tr, eachslice(ρmat, dims = (1, 2)))
589621

src/qobj/quantum_object.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,17 +215,6 @@ function QuantumObject(
215215
throw(DomainError(size(A), "The size of the array is not compatible with vector or matrix."))
216216
end
217217

218-
_get_size(A::AbstractMatrix) = size(A)
219-
_get_size(A::AbstractVector) = (length(A), 1)
220-
221-
_non_static_array_warning(argname, arg::Tuple{}) =
222-
throw(ArgumentError("The argument $argname must be a Tuple or a StaticVector of non-zero length."))
223-
_non_static_array_warning(argname, arg::Union{SVector{N,T},MVector{N,T},NTuple{N,T}}) where {N,T} = nothing
224-
_non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
225-
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
226-
join(arg, ", ") *
227-
")` instead of `$argname = $arg`." maxlog = 1
228-
229218
function _check_dims(dims::Union{AbstractVector{T},NTuple{N,T}}) where {T<:Integer,N}
230219
_non_static_array_warning("dims", dims)
231220
return (all(>(0), dims) && length(dims) > 0) ||

src/utilities.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,14 @@ makeVal(x::Val{T}) where {T} = x
5454
makeVal(x) = Val(x)
5555

5656
getVal(x::Val{T}) where {T} = T
57+
58+
_get_size(A::AbstractMatrix) = size(A)
59+
_get_size(A::AbstractVector) = (length(A), 1)
60+
61+
_non_static_array_warning(argname, arg::Tuple{}) =
62+
throw(ArgumentError("The argument $argname must be a Tuple or a StaticVector of non-zero length."))
63+
_non_static_array_warning(argname, arg::Union{SVector{N,T},MVector{N,T},NTuple{N,T}}) where {N,T} = nothing
64+
_non_static_array_warning(argname, arg::AbstractVector{T}) where {T} =
65+
@warn "The argument $argname should be a Tuple or a StaticVector for better performance. Try to use `$argname = $(Tuple(arg))` or `$argname = SVector(" *
66+
join(arg, ", ") *
67+
")` instead of `$argname = $arg`." maxlog = 1

test/quantum_objects.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,58 @@
596596
@test ρ1.data ρ1_ptr.data atol = 1e-10
597597
@test ρ2.data ρ2_ptr.data atol = 1e-10
598598

599+
ψlist = [rand_ket(2), rand_ket(3), rand_ket(4), rand_ket(5)]
600+
ρlist = [rand_dm(2), rand_dm(3), rand_dm(4), rand_dm(5)]
601+
ψtotal = tensor(ψlist...)
602+
ρtotal = tensor(ρlist...)
603+
sel_tests = [
604+
SVector{0,Int}(),
605+
1,
606+
2,
607+
3,
608+
4,
609+
(1, 2),
610+
(1, 3),
611+
(1, 4),
612+
(2, 3),
613+
(2, 4),
614+
(3, 4),
615+
(1, 2, 3),
616+
(1, 2, 4),
617+
(1, 3, 4),
618+
(2, 3, 4),
619+
(1, 2, 3, 4),
620+
]
621+
for sel in sel_tests
622+
if length(sel) == 0
623+
@test ptrace(ψtotal, sel) 1.0
624+
@test ptrace(ρtotal, sel) 1.0
625+
else
626+
@test ptrace(ψtotal, sel) tensor([ket2dm(ψlist[i]) for i in sel]...)
627+
@test ptrace(ρtotal, sel) tensor([ρlist[i] for i in sel]...)
628+
end
629+
end
630+
@test ptrace(ψtotal, (1, 3, 4)) ptrace(ψtotal, (4, 3, 1)) # check sort of sel
631+
@test ptrace(ρtotal, (1, 3, 4)) ptrace(ρtotal, (3, 1, 4)) # check sort of sel
632+
@test_logs (
633+
:warn,
634+
"The argument sel should be a Tuple or a StaticVector for better performance. Try to use `sel = (1, 2)` or `sel = SVector(1, 2)` instead of `sel = [1, 2]`.",
635+
) ptrace(ψtotal, [1, 2])
636+
@test_logs (
637+
:warn,
638+
"The argument sel should be a Tuple or a StaticVector for better performance. Try to use `sel = (1, 2)` or `sel = SVector(1, 2)` instead of `sel = [1, 2]`.",
639+
) ptrace(ρtotal, [1, 2])
640+
@test_throws ArgumentError ptrace(ψtotal, 0)
641+
@test_throws ArgumentError ptrace(ψtotal, 5)
642+
@test_throws ArgumentError ptrace(ψtotal, (0, 2))
643+
@test_throws ArgumentError ptrace(ψtotal, (2, 5))
644+
@test_throws ArgumentError ptrace(ψtotal, (2, 2, 3))
645+
@test_throws ArgumentError ptrace(ρtotal, 0)
646+
@test_throws ArgumentError ptrace(ρtotal, 5)
647+
@test_throws ArgumentError ptrace(ρtotal, (0, 2))
648+
@test_throws ArgumentError ptrace(ρtotal, (2, 5))
649+
@test_throws ArgumentError ptrace(ρtotal, (2, 2, 3))
650+
599651
@testset "Type Inference (ptrace)" begin
600652
@inferred ptrace(ρ, 1)
601653
@inferred ptrace(ρ, 2)

0 commit comments

Comments
 (0)