Skip to content

Commit 8dbdcc6

Browse files
committed
Do not assume only CompositeBasis represents multiple subsystems
1 parent e0ff5f1 commit 8dbdcc6

File tree

12 files changed

+76
-84
lines changed

12 files changed

+76
-84
lines changed

src/metrics.jl

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ T(ρ) = Tr\\{\\sqrt{ρ^† ρ}\\} = \\sum_i |λ_i|
3232
3333
where ``λ_i`` are the eigenvalues of `rho`.
3434
"""
35-
function tracenorm_h(rho::DenseOpType{B,B}) where B
35+
function tracenorm_h(rho::DenseOpType)
36+
check_multiplicable(rho,rho)
3637
s = eigvals(Hermitian(rho.data))
3738
sum(abs.(s))
3839
end
@@ -77,7 +78,7 @@ T(ρ,σ) = \\frac{1}{2} Tr\\{\\sqrt{(ρ - σ)^† (ρ - σ)}\\}.
7778
It calls [`tracenorm`](@ref) which in turn either uses [`tracenorm_h`](@ref)
7879
or [`tracenorm_nh`](@ref) depending if ``ρ-σ`` is hermitian or not.
7980
"""
80-
tracedistance(rho::DenseOpType{B,B}, sigma::DenseOpType{B,B}) where {B} = 0.5*tracenorm(rho - sigma)
81+
tracedistance(rho::DenseOpType, sigma::DenseOpType) = (check_addible(rho,sigma); check_multiplicable(rho,rho); 0.5*tracenorm(rho - sigma))
8182
function tracedistance(rho::AbstractOperator, sigma::AbstractOperator)
8283
throw(ArgumentError("tracedistance not implemented for $(typeof(rho)) and $(typeof(sigma)). Use dense operators instead."))
8384
end
@@ -95,7 +96,7 @@ T(ρ,σ) = \\frac{1}{2} Tr\\{\\sqrt{(ρ - σ)^† (ρ - σ)}\\} = \\frac{1}{2} \
9596
9697
where ``λ_i`` are the eigenvalues of `rho` - `sigma`.
9798
"""
98-
tracedistance_h(rho::DenseOpType{B,B}, sigma::DenseOpType{B,B}) where {B}= 0.5*tracenorm_h(rho - sigma)
99+
tracedistance_h(rho::DenseOpType, sigma::DenseOpType) = (check_addible(rho,sigma); check_multiplicable(rho,rho); 0.5*tracenorm_h(rho - sigma))
99100
function tracedistance_h(rho::AbstractOperator, sigma::AbstractOperator)
100101
throw(ArgumentError("tracedistance_h not implemented for $(typeof(rho)) and $(typeof(sigma)). Use dense operators instead."))
101102
end
@@ -117,12 +118,11 @@ It uses the identity
117118
118119
where ``σ_i`` are the singular values of `rho` - `sigma`.
119120
"""
120-
tracedistance_nh(rho::DenseOpType{B1,B2}, sigma::DenseOpType{B1,B2}) where {B1,B2} = 0.5*tracenorm_nh(rho - sigma)
121+
tracedistance_nh(rho::DenseOpType, sigma::DenseOpType) = (check_addible(rho, sigma); 0.5*tracenorm_nh(rho - sigma))
121122
function tracedistance_nh(rho::AbstractOperator, sigma::AbstractOperator)
122123
throw(ArgumentError("tracedistance_nh not implemented for $(typeof(rho)) and $(typeof(sigma)). Use dense operators instead."))
123124
end
124125

125-
126126
"""
127127
entropy_vn(rho)
128128
@@ -141,7 +141,8 @@ natural logarithm and ``0\\log(0) ≡ 0``.
141141
* `rho`: Density operator of which to calculate Von Neumann entropy.
142142
* `tol=1e-15`: Tolerance for rounding errors in the computed eigenvalues.
143143
"""
144-
function entropy_vn(rho::DenseOpType{B,B}; tol=1e-15) where B
144+
function entropy_vn(rho::DenseOpType; tol=1e-15)
145+
check_multiplicable(rho, rho)
145146
evals::Vector{ComplexF64} = eigvals(rho.data)
146147
entr = zero(eltype(rho))
147148
for d evals
@@ -164,9 +165,10 @@ The Renyi α-entropy of a density operator is defined as
164165
S_α(ρ) = 1/(1-α) \\log(Tr(ρ^α))
165166
```
166167
"""
167-
function entropy_renyi(rho::DenseOpType{B,B}, α::Integer=2) where B
168+
function entropy_renyi(rho::Operator, α::Integer=2)
168169
α < 0 && throw(ArgumentError("α-Renyi entropy is defined for α≥0, α≂̸1"))
169170
α == 1 && throw(ArgumentError("α-Renyi entropy is defined for α≥0, α≂̸1"))
171+
check_multiplicable(rho,rho)
170172

171173
return 1/(1-α) * log(tr(rho^α))
172174
end
@@ -186,7 +188,12 @@ F(ρ, σ) = Tr\\left(\\sqrt{\\sqrt{ρ}σ\\sqrt{ρ}}\\right),
186188
187189
where ``\\sqrt{ρ}=\\sum_n\\sqrt{λ_n}|ψ⟩⟨ψ|``.
188190
"""
189-
fidelity(rho::DenseOpType{B,B}, sigma::DenseOpType{B,B}) where {B} = tr(sqrt(sqrt(rho.data)*sigma.data*sqrt(rho.data)))
191+
function fidelity(rho::DenseOpType, sigma::DenseOpType)
192+
check_multiplicable(rho,rho)
193+
check_multiplicable(sigma,sigma)
194+
check_multiplicable(rho,sigma)
195+
tr(sqrt(sqrt(rho.data)*sigma.data*sqrt(rho.data)))
196+
end
190197

191198

192199
"""
@@ -196,7 +203,9 @@ Partial transpose of rho with respect to subsystem specified by indices.
196203
197204
The `indices` argument can be a single integer or a collection of integers.
198205
"""
199-
function ptranspose(rho::DenseOpType{B,B}, indices=1) where B<:CompositeBasis
206+
function ptranspose(rho::DenseOpType, indices=1)
207+
length(basis_l(rho)) == length(basis_r(rho)) || throw(ArgumentError())
208+
length(basis_l(rho)) > 1 || throw(ArgumentError())
200209
# adapted from qutip.partial_transpose (https://qutip.org/docs/4.0.2/modules/qutip/partial_transpose.html)
201210
# works as long as QuantumOptics.jl doesn't change the implementation of `tensor`, i.e. tensor(a,b).data = kron(b.data,a.data)
202211
nsys = length(basis_l(rho))
@@ -218,7 +227,7 @@ end
218227
219228
Peres-Horodecki criterion of partial transpose.
220229
"""
221-
PPT(rho::DenseOpType{B,B}, index) where B<:CompositeBasis = all(real.(eigvals(ptranspose(rho, index).data)) .>= 0.0)
230+
PPT(rho::DenseOpType, index) = all(real.(eigvals(ptranspose(rho, index).data)) .>= 0.0)
222231

223232

224233
"""
@@ -233,7 +242,7 @@ N(ρ) = \\frac{\\|ρᵀ\\|-1}{2},
233242
```
234243
where `ρᵀ` is the partial transpose.
235244
"""
236-
negativity(rho::DenseOpType{B,B}, index) where B<:CompositeBasis = 0.5*(tracenorm(ptranspose(rho, index)) - 1.0)
245+
negativity(rho::DenseOpType, index) = 0.5*(tracenorm(ptranspose(rho, index)) - 1.0)
237246

238247

239248
"""
@@ -246,7 +255,7 @@ N(ρ) = \\log₂\\|ρᵀ\\|,
246255
```
247256
where `ρᵀ` is the partial transpose.
248257
"""
249-
logarithmic_negativity(rho::DenseOpType{B,B}, index) where B<:CompositeBasis = log(2, tracenorm(ptranspose(rho, index)))
258+
logarithmic_negativity(rho::DenseOpType, index) = log(2, tracenorm(ptranspose(rho, index)))
250259

251260

252261
"""
@@ -278,35 +287,16 @@ entanglement_entropy(dm(ket)) = 2 * entanglement_entropy(ket)
278287
By default the computed entropy is the Von-Neumann entropy, but a different
279288
function can be provided (for example to compute the entanglement-renyi entropy).
280289
"""
281-
function entanglement_entropy(psi::Ket{B}, partition, entropy_fun=entropy_vn) where B<:CompositeBasis
282-
# check that sites are within the range
283-
@assert all(partition .<= length(psi.basis))
284-
285-
rho = ptrace(psi, partition)
286-
return entropy_fun(rho)
287-
end
290+
entanglement_entropy(psi::Ket, partition, entropy_fun=entropy_vn) = entropy_fun(ptrace(psi, partition))
288291

289-
function entanglement_entropy(rho::DenseOpType{B,B}, partition, args...) where {B<:CompositeBasis}
290-
# check that sites is within the range
291-
hilb = rho.basis_l
292-
N = length(hilb)
293-
all(partition .<= N) || throw(ArgumentError("Indices in partition must be within the bounds of the composite basis."))
294-
length(partition) <= N || throw(ArgumentError("Partition cannot include the whole system."))
292+
function entanglement_entropy(rho::DenseOpType, partition, args...)
293+
check_multiplicable(rho,rho)
294+
N = length(basis_l(rho))
295295

296296
# build the doubled hilbert space for the vectorised dm, normalized like a Ket.
297-
b_doubled = hilb^2
298-
rho_vec = normalize!(Ket(b_doubled, vec(rho.data)))
299-
300-
if partition isa Tuple
301-
partition_ = tuple(partition..., (partition.+N)...)
302-
else
303-
partition_ = vcat(partition, partition.+N)
304-
end
305-
306-
return entanglement_entropy(rho_vec,partition_,args...)
297+
rho_vec = normalize!(Ket(basis_l(rho)^2, vec(rho.data)))
298+
entanglement_entropy(rho_vec, [partition..., (partition.+N)...], args...)
307299
end
308300

309-
entanglement_entropy(state::Ket{B}, partition::Number, args...) where B<:CompositeBasis =
310-
entanglement_entropy(state, [partition], args...)
311-
entanglement_entropy(state::DenseOpType{B,B}, partition::Number, args...) where B<:CompositeBasis =
312-
entanglement_entropy(state, [partition], args...)
301+
entanglement_entropy(state::Ket, partition::Integer, args...) = entanglement_entropy(state, [partition], args...)
302+
entanglement_entropy(state::DenseOpType, partition::Integer, args...) = entanglement_entropy(state, [partition], args...)

src/operators.jl

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ using QuantumInterface: arithmetic_binary_error, arithmetic_unary_error, addnumb
2121
2222
Embed operator acting on a joint Hilbert space where missing indices are filled up with identity operators.
2323
"""
24-
function embed(bl::CompositeBasis, br::CompositeBasis,
25-
indices, op::T) where T<:DataOperator
24+
function embed(bl::Basis, br::Basis, indices, op::T) where T<:DataOperator
2625
(length(bl) == length(br)) || throw(ArgumentError("Must have length(bl) == length(br) in embed"))
2726
N = length(bl)
2827

@@ -70,15 +69,13 @@ function embed(bl::CompositeBasis, br::CompositeBasis,
7069
return unpermuted_op
7170
end
7271

73-
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
74-
index::Integer, op::T) where T<:DataOperator
75-
76-
N = length(basis_l)
72+
function embed(bl::Basis, br::Basis, index::Integer, op::DataOperator)
73+
N = length(bl)
7774

7875
# Check stuff
79-
@assert length(basis_r) == N
80-
basis_l[index] == op.basis_l || throw(IncompatibleBases())
81-
basis_r[index] == op.basis_r || throw(IncompatibleBases())
76+
@assert length(br) == N
77+
bl[index] == basis_l(op) || throw(IncompatibleBases())
78+
br[index] == basis_r(op) || throw(IncompatibleBases())
8279
check_indices(N, index)
8380

8481
# Build data
@@ -89,17 +86,14 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
8986
while i > 0
9087
if i == index
9188
data = kron(data, op.data)
92-
i -= length(index)
9389
else
94-
bl = basis_l[i]
95-
br = basis_r[i]
96-
id = SparseMatrixCSC{Tnum}(I, dimension(bl), dimension(br))
90+
id = SparseMatrixCSC{Tnum}(I, dimension(bl[i]), dimension(br[i]))
9791
data = kron(data, id)
98-
i -= 1
9992
end
93+
i -= 1
10094
end
10195

102-
return Operator(basis_l, basis_r, data)
96+
return Operator(bl, br, data)
10397
end
10498

10599
"""
@@ -117,13 +111,13 @@ end
117111
# TODO upstream this one
118112
# expect(op::AbstractOperator{B,B}, state::AbstractKet{B}) where B = norm(op * state) ^ 2
119113

120-
function expect(indices, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis}
114+
function expect(indices, op::AbstractOperator, state::Ket)
121115
N = length(basis(state))
122116
indices_ = complement(N, indices)
123117
expect(op, ptrace(state, indices_))
124118
end
125119

126-
expect(index::Integer, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis} = expect([index], op, state)
120+
expect(index::Integer, op::AbstractOperator, state::Ket) = expect([index], op, state)
127121

128122
"""
129123
variance(op, state)
@@ -138,23 +132,23 @@ function variance(op::AbstractOperator, state::Ket)
138132
state.data'*(op*x).data - (state.data'*x.data)^2
139133
end
140134

141-
function variance(indices, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis}
135+
function variance(indices, op::AbstractOperator, state::Ket)
142136
N = length(basis(state))
143137
indices_ = complement(N, indices)
144138
variance(op, ptrace(state, indices_))
145139
end
146140

147-
variance(index::Integer, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis} = variance([index], op, state)
141+
variance(index::Integer, op::AbstractOperator, state::Ket) = variance([index], op, state)
148142

149143
# Helper functions to check validity of arguments
150144
function check_ptrace_arguments(a::AbstractOperator, indices)
151-
if !isa(a.basis_l, CompositeBasis) || !isa(a.basis_r, CompositeBasis)
152-
throw(ArgumentError("Partial trace can only be applied onto operators with composite bases."))
153-
end
154145
rank = length(basis_l(a))
155146
if rank != length(basis_r(a))
156147
throw(ArgumentError("Partial trace can only be applied onto operators wich have the same number of subsystems in the left basis and right basis."))
157148
end
149+
if rank < 2
150+
throw(ArgumentError("Partial trace can only be applied to operators over at least two subsystems."))
151+
end
158152
if rank == length(indices)
159153
throw(ArgumentError("Partial trace can't be used to trace out all subsystems - use tr() instead."))
160154
end

src/operators_dense.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,15 @@ function exp(op::T) where {B,T<:DenseOpType{B,B}}
246246
return DenseOperator(op.basis_l, op.basis_r, exp(op.data))
247247
end
248248

249-
function permutesystems(a::Operator{B1,B2}, perm) where {B1<:CompositeBasis,B2<:CompositeBasis}
249+
function permutesystems(a::Operator, perm)
250250
@assert length(a.basis_l) == length(a.basis_r) == length(perm)
251251
@assert isperm(perm)
252252
data = reshape(a.data, [shape(a.basis_l); shape(a.basis_r)]...)
253253
data = permutedims(data, [perm; perm .+ length(perm)])
254254
data = reshape(data, dimension(a.basis_l), dimension(a.basis_r))
255255
return Operator(permutesystems(a.basis_l, perm), permutesystems(a.basis_r, perm), data)
256256
end
257-
permutesystems(a::AdjointOperator{B1,B2}, perm) where {B1<:CompositeBasis,B2<:CompositeBasis} = dagger(permutesystems(dagger(a),perm))
257+
permutesystems(a::AdjointOperator, perm) = dagger(permutesystems(dagger(a),perm))
258258

259259
identityoperator(::Type{S}, ::Type{T}, b1::Basis, b2::Basis) where {S<:DenseOpType,T<:Number} =
260260
Operator(b1, b2, Matrix{T}(I, dimension(b1), dimension(b2)))

src/operators_sparse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function exp(op::T; opts...) where {B,T<:SparseOpType{B,B}}
6464
end
6565
end
6666

67-
function permutesystems(rho::SparseOpPureType{B1,B2}, perm) where {B1<:CompositeBasis,B2<:CompositeBasis}
67+
function permutesystems(rho::SparseOpPureType, perm)
6868
@assert length(rho.basis_l) == length(rho.basis_r) == length(perm)
6969
@assert isperm(perm)
7070
shape_ = [shape(rho.basis_l); shape(rho.basis_r)]

src/pauli.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,12 @@ function choi_chi(Nl, Nr)
3636
end
3737

3838
# It's possible to get better asympotic speedups using, e.g. methods from
39+
# https://quantum-journal.org/papers/q-2024-09-05-1461/ (see appendices)
3940
# https://iopscience.iop.org/article/10.1088/1402-4896/ad6499
4041
# https://arxiv.org/abs/2411.00526
41-
# https://quantum-journal.org/papers/q-2024-09-05-1461/ (see appendices)
42+
# https://arxiv.org/abs/2408.06206
43+
# https://quantumcomputing.stackexchange.com/questions/31788/how-to-write-the-iswap-unitary-as-a-linear-combination-of-tensor-products-betw/31790#31790
44+
# So probably using https://github.com/JuliaMath/Hadamard.jl would be best
4245
function _pauli_comp_convert(op, rev)
4346
Nl, Nr = length(basis_l(basis_l(op))), length(basis_l(basis_r(op)))
4447
Vl = pauli_comp(Nl)

src/sparsematrix.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Base: permutedims
2+
import LinearAlgebra: eigvals, svdvals
23

34
function gemm_sp_dense_small(alpha, M::SparseMatrixCSC, B::AbstractMatrix, result::AbstractMatrix)
45
if isone(alpha)

src/spinors.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ function embed(bl::SumBasis, br::SumBasis, indices, op::LazyDirectSum)
215215
end
216216
# TODO: embed for multiple LazyDirectums?
217217

218+
embed(bl::SumBasis, br::SumBasis, index::Integer, op::LazyDirectSum) = embed(bl,br,[index],op)
219+
218220
# Fast in-place multiplication
219221
function mul!(result::Ket{B1},M::LazyDirectSum{B1,B2},b::Ket{B2},alpha_,beta_) where {B1,B2}
220222
alpha = convert(ComplexF64, alpha_)

src/superoperators.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,12 @@ sprepost(A::Operator, B::Operator) = Operator(KetBraBasis(basis_l(A), basis_r(B)
6464
# return Operator(a1⊗b1, a2⊗b2, data)
6565
#end
6666

67+
# https://discourse.julialang.org/t/permuteddimsarray-slower-than-permutedims/46401
6768
function super_tensor(A, B)
6869
all, alr = basis_l(basis_l(A)), basis_r(basis_l(A))
6970
arl, arr = basis_l(basis_r(A)), basis_r(basis_r(A))
70-
bll, blr = basis_l(basis_l(A)), basis_r(basis_l(A))
71-
brl, brr = basis_l(basis_r(A)), basis_r(basis_r(A))
71+
bll, blr = basis_l(basis_l(B)), basis_r(basis_l(B))
72+
brl, brr = basis_l(basis_r(B)), basis_r(basis_r(B))
7273
data = kron(B.data, A.data)
7374
data = reshape(data, map(dimension, (all, bll, alr, blr, arl, brl, arr, brr)))
7475
data = PermutedDimsArray(data, (1, 3, 2, 4, 5, 6, 7, 8))

src/time_dependent_operator.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,12 @@ function dagger(op::TimeDependentSum)
201201
end
202202
adjoint(op::TimeDependentSum) = dagger(op)
203203

204-
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, i::Integer, o::TimeDependentSum)
205-
TimeDependentSum(coefficients(o), embed(basis_l, basis_r, i, static_operator(o)), o.current_time)
204+
function embed(bl::Basis, br::Basis, i::Integer, o::TimeDependentSum)
205+
TimeDependentSum(coefficients(o), embed(bl, br, i, static_operator(o)), o.current_time)
206206
end
207207

208-
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, indices, o::TimeDependentSum)
209-
TimeDependentSum(coefficients(o), embed(basis_l, basis_r, indices, static_operator(o)), o.current_time)
208+
function embed(bl::Basis, br::Basis, indices, o::TimeDependentSum)
209+
TimeDependentSum(coefficients(o), embed(bl, br, indices, static_operator(o)), o.current_time)
210210
end
211211

212212
function +(A::TimeDependentSum, B::TimeDependentSum)

test/test_metrics.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test
22
using QuantumOpticsBase
3+
import QuantumInterface: IncompatibleBases
34
using SparseArrays, LinearAlgebra
45

56
@testset "metrics" begin
@@ -49,8 +50,8 @@ sigma = tensor(psi2, dagger(psi2))
4950
@test tracedistance(sigma, sigma) 0.
5051

5152
rho = spinup(b1) dagger(coherentstate(b2, 0.1))
52-
@test_throws ArgumentError tracedistance(rho, rho)
53-
@test_throws ArgumentError tracedistance_h(rho, rho)
53+
@test_throws IncompatibleBases tracedistance(rho, rho)
54+
@test_throws IncompatibleBases tracedistance_h(rho, rho)
5455

5556
@test tracedistance_nh(rho, rho) 0.
5657

0 commit comments

Comments
 (0)