Skip to content

Commit 815505d

Browse files
committed
Do not assume only CompositeBasis represents multiple subsystems
1 parent 774662e commit 815505d

File tree

10 files changed

+74
-83
lines changed

10 files changed

+74
-83
lines changed

src/metrics.jl

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ T(ρ) = Tr\\{\\sqrt{ρ^† ρ}\\} = \\sum_i |λ_i|
3232
3333
where ``λ_i`` are the eigenvalues of `rho`.
3434
"""
35-
function tracenorm_h(rho::DenseOpType{B,B}) where B
35+
function tracenorm_h(rho::DenseOpType)
36+
check_multiplicable(rho,rho)
3637
s = eigvals(Hermitian(rho.data))
3738
sum(abs.(s))
3839
end
@@ -77,7 +78,7 @@ T(ρ,σ) = \\frac{1}{2} Tr\\{\\sqrt{(ρ - σ)^† (ρ - σ)}\\}.
7778
It calls [`tracenorm`](@ref) which in turn either uses [`tracenorm_h`](@ref)
7879
or [`tracenorm_nh`](@ref) depending if ``ρ-σ`` is hermitian or not.
7980
"""
80-
tracedistance(rho::DenseOpType{B,B}, sigma::DenseOpType{B,B}) where {B} = 0.5*tracenorm(rho - sigma)
81+
tracedistance(rho::DenseOpType, sigma::DenseOpType) = (check_addible(rho,sigma); check_multiplicable(rho,rho); 0.5*tracenorm(rho - sigma))
8182
function tracedistance(rho::AbstractOperator, sigma::AbstractOperator)
8283
throw(ArgumentError("tracedistance not implemented for $(typeof(rho)) and $(typeof(sigma)). Use dense operators instead."))
8384
end
@@ -95,7 +96,7 @@ T(ρ,σ) = \\frac{1}{2} Tr\\{\\sqrt{(ρ - σ)^† (ρ - σ)}\\} = \\frac{1}{2} \
9596
9697
where ``λ_i`` are the eigenvalues of `rho` - `sigma`.
9798
"""
98-
tracedistance_h(rho::DenseOpType{B,B}, sigma::DenseOpType{B,B}) where {B}= 0.5*tracenorm_h(rho - sigma)
99+
tracedistance_h(rho::DenseOpType, sigma::DenseOpType) = (check_addible(rho,sigma); check_multiplicable(rho,rho); 0.5*tracenorm_h(rho - sigma))
99100
function tracedistance_h(rho::AbstractOperator, sigma::AbstractOperator)
100101
throw(ArgumentError("tracedistance_h not implemented for $(typeof(rho)) and $(typeof(sigma)). Use dense operators instead."))
101102
end
@@ -117,12 +118,11 @@ It uses the identity
117118
118119
where ``σ_i`` are the singular values of `rho` - `sigma`.
119120
"""
120-
tracedistance_nh(rho::DenseOpType{B1,B2}, sigma::DenseOpType{B1,B2}) where {B1,B2} = 0.5*tracenorm_nh(rho - sigma)
121+
tracedistance_nh(rho::DenseOpType, sigma::DenseOpType) = (check_addible(rho, sigma); 0.5*tracenorm_nh(rho - sigma))
121122
function tracedistance_nh(rho::AbstractOperator, sigma::AbstractOperator)
122123
throw(ArgumentError("tracedistance_nh not implemented for $(typeof(rho)) and $(typeof(sigma)). Use dense operators instead."))
123124
end
124125

125-
126126
"""
127127
entropy_vn(rho)
128128
@@ -141,7 +141,8 @@ natural logarithm and ``0\\log(0) ≡ 0``.
141141
* `rho`: Density operator of which to calculate Von Neumann entropy.
142142
* `tol=1e-15`: Tolerance for rounding errors in the computed eigenvalues.
143143
"""
144-
function entropy_vn(rho::DenseOpType{B,B}; tol=1e-15) where B
144+
function entropy_vn(rho::DenseOpType; tol=1e-15)
145+
check_multiplicable(rho, rho)
145146
evals::Vector{ComplexF64} = eigvals(rho.data)
146147
entr = zero(eltype(rho))
147148
for d evals
@@ -163,9 +164,10 @@ The Renyi α-entropy of a density operator is defined as
163164
S_α(ρ) = 1/(1-α) \\log(Tr(ρ^α))
164165
```
165166
"""
166-
function entropy_renyi(rho::DenseOpType{B,B}, α::Integer=2) where B
167+
function entropy_renyi(rho::Operator, α::Integer=2)
167168
α < 0 && throw(ArgumentError("α-Renyi entropy is defined for α≥0, α≂̸1"))
168169
α == 1 && throw(ArgumentError("α-Renyi entropy is defined for α≥0, α≂̸1"))
170+
check_multiplicable(rho,rho)
169171

170172
return 1/(1-α) * log(tr(rho^α))
171173
end
@@ -185,7 +187,12 @@ F(ρ, σ) = Tr\\left(\\sqrt{\\sqrt{ρ}σ\\sqrt{ρ}}\\right),
185187
186188
where ``\\sqrt{ρ}=\\sum_n\\sqrt{λ_n}|ψ⟩⟨ψ|``.
187189
"""
188-
fidelity(rho::DenseOpType{B,B}, sigma::DenseOpType{B,B}) where {B} = tr(sqrt(sqrt(rho.data)*sigma.data*sqrt(rho.data)))
190+
function fidelity(rho::DenseOpType, sigma::DenseOpType)
191+
check_multiplicable(rho,rho)
192+
check_multiplicable(sigma,sigma)
193+
check_multiplicable(rho,sigma)
194+
tr(sqrt(sqrt(rho.data)*sigma.data*sqrt(rho.data)))
195+
end
189196

190197

191198
"""
@@ -195,7 +202,9 @@ Partial transpose of rho with respect to subsystem specified by indices.
195202
196203
The `indices` argument can be a single integer or a collection of integers.
197204
"""
198-
function ptranspose(rho::DenseOpType{B,B}, indices=1) where B<:CompositeBasis
205+
function ptranspose(rho::DenseOpType, indices=1)
206+
length(basis_l(rho)) == length(basis_r(rho)) || throw(ArgumentError())
207+
length(basis_l(rho)) > 1 || throw(ArgumentError())
199208
# adapted from qutip.partial_transpose (https://qutip.org/docs/4.0.2/modules/qutip/partial_transpose.html)
200209
# works as long as QuantumOptics.jl doesn't change the implementation of `tensor`, i.e. tensor(a,b).data = kron(b.data,a.data)
201210
nsys = length(basis_l(rho))
@@ -217,7 +226,7 @@ end
217226
218227
Peres-Horodecki criterion of partial transpose.
219228
"""
220-
PPT(rho::DenseOpType{B,B}, index) where B<:CompositeBasis = all(real.(eigvals(ptranspose(rho, index).data)) .>= 0.0)
229+
PPT(rho::DenseOpType, index) = all(real.(eigvals(ptranspose(rho, index).data)) .>= 0.0)
221230

222231

223232
"""
@@ -232,7 +241,7 @@ N(ρ) = \\frac{\\|ρᵀ\\|-1}{2},
232241
```
233242
where `ρᵀ` is the partial transpose.
234243
"""
235-
negativity(rho::DenseOpType{B,B}, index) where B<:CompositeBasis = 0.5*(tracenorm(ptranspose(rho, index)) - 1.0)
244+
negativity(rho::DenseOpType, index) = 0.5*(tracenorm(ptranspose(rho, index)) - 1.0)
236245

237246

238247
"""
@@ -245,7 +254,7 @@ N(ρ) = \\log₂\\|ρᵀ\\|,
245254
```
246255
where `ρᵀ` is the partial transpose.
247256
"""
248-
logarithmic_negativity(rho::DenseOpType{B,B}, index) where B<:CompositeBasis = log(2, tracenorm(ptranspose(rho, index)))
257+
logarithmic_negativity(rho::DenseOpType, index) = log(2, tracenorm(ptranspose(rho, index)))
249258

250259

251260
"""
@@ -277,35 +286,16 @@ entanglement_entropy(dm(ket)) = 2 * entanglement_entropy(ket)
277286
By default the computed entropy is the Von-Neumann entropy, but a different
278287
function can be provided (for example to compute the entanglement-renyi entropy).
279288
"""
280-
function entanglement_entropy(psi::Ket{B}, partition, entropy_fun=entropy_vn) where B<:CompositeBasis
281-
# check that sites are within the range
282-
@assert all(partition .<= length(psi.basis))
283-
284-
rho = ptrace(psi, partition)
285-
return entropy_fun(rho)
286-
end
289+
entanglement_entropy(psi::Ket, partition, entropy_fun=entropy_vn) = entropy_fun(ptrace(psi, partition))
287290

288-
function entanglement_entropy(rho::DenseOpType{B,B}, partition, args...) where {B<:CompositeBasis}
289-
# check that sites is within the range
290-
hilb = rho.basis_l
291-
N = length(hilb)
292-
all(partition .<= N) || throw(ArgumentError("Indices in partition must be within the bounds of the composite basis."))
293-
length(partition) <= N || throw(ArgumentError("Partition cannot include the whole system."))
291+
function entanglement_entropy(rho::DenseOpType, partition, args...)
292+
check_multiplicable(rho,rho)
293+
N = length(basis_l(rho))
294294

295295
# build the doubled hilbert space for the vectorised dm, normalized like a Ket.
296-
b_doubled = hilb^2
297-
rho_vec = normalize!(Ket(b_doubled, vec(rho.data)))
298-
299-
if partition isa Tuple
300-
partition_ = tuple(partition..., (partition.+N)...)
301-
else
302-
partition_ = vcat(partition, partition.+N)
303-
end
304-
305-
return entanglement_entropy(rho_vec,partition_,args...)
296+
rho_vec = normalize!(Ket(basis_l(rho)^2, vec(rho.data)))
297+
entanglement_entropy(rho_vec, [partition..., (partition.+N)...], args...)
306298
end
307299

308-
entanglement_entropy(state::Ket{B}, partition::Number, args...) where B<:CompositeBasis =
309-
entanglement_entropy(state, [partition], args...)
310-
entanglement_entropy(state::DenseOpType{B,B}, partition::Number, args...) where B<:CompositeBasis =
311-
entanglement_entropy(state, [partition], args...)
300+
entanglement_entropy(state::Ket, partition::Integer, args...) = entanglement_entropy(state, [partition], args...)
301+
entanglement_entropy(state::DenseOpType, partition::Integer, args...) = entanglement_entropy(state, [partition], args...)

src/operators.jl

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ using QuantumInterface: arithmetic_binary_error, arithmetic_unary_error, addnumb
2121
2222
Embed operator acting on a joint Hilbert space where missing indices are filled up with identity operators.
2323
"""
24-
function embed(bl::CompositeBasis, br::CompositeBasis,
25-
indices, op::T) where T<:DataOperator
24+
function embed(bl::Basis, br::Basis, indices, op::T) where T<:DataOperator
2625
(length(bl) == length(br)) || throw(ArgumentError("Must have length(bl) == length(br) in embed"))
2726
N = length(bl)
2827

@@ -70,15 +69,13 @@ function embed(bl::CompositeBasis, br::CompositeBasis,
7069
return unpermuted_op
7170
end
7271

73-
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
74-
index::Integer, op::T) where T<:DataOperator
75-
76-
N = length(basis_l)
72+
function embed(bl::Basis, br::Basis, index::Integer, op::DataOperator)
73+
N = length(bl)
7774

7875
# Check stuff
79-
@assert length(basis_r) == N
80-
basis_l[index] == op.basis_l || throw(IncompatibleBases())
81-
basis_r[index] == op.basis_r || throw(IncompatibleBases())
76+
@assert length(br) == N
77+
bl[index] == basis_l(op) || throw(IncompatibleBases())
78+
br[index] == basis_r(op) || throw(IncompatibleBases())
8279
check_indices(N, index)
8380

8481
# Build data
@@ -89,17 +86,14 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
8986
while i > 0
9087
if i == index
9188
data = kron(data, op.data)
92-
i -= length(index)
9389
else
94-
bl = basis_l[i]
95-
br = basis_r[i]
96-
id = SparseMatrixCSC{Tnum}(I, dimension(bl), dimension(br))
90+
id = SparseMatrixCSC{Tnum}(I, dimension(bl[i]), dimension(br[i]))
9791
data = kron(data, id)
98-
i -= 1
9992
end
93+
i -= 1
10094
end
10195

102-
return Operator(basis_l, basis_r, data)
96+
return Operator(bl, br, data)
10397
end
10498

10599
"""
@@ -117,13 +111,13 @@ end
117111
# TODO upstream this one
118112
# expect(op::AbstractOperator{B,B}, state::AbstractKet{B}) where B = norm(op * state) ^ 2
119113

120-
function expect(indices, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis}
114+
function expect(indices, op::AbstractOperator, state::Ket)
121115
N = length(basis(state))
122116
indices_ = complement(N, indices)
123117
expect(op, ptrace(state, indices_))
124118
end
125119

126-
expect(index::Integer, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis} = expect([index], op, state)
120+
expect(index::Integer, op::AbstractOperator, state::Ket) = expect([index], op, state)
127121

128122
"""
129123
variance(op, state)
@@ -138,23 +132,23 @@ function variance(op::AbstractOperator, state::Ket)
138132
state.data'*(op*x).data - (state.data'*x.data)^2
139133
end
140134

141-
function variance(indices, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis}
135+
function variance(indices, op::AbstractOperator, state::Ket)
142136
N = length(basis(state))
143137
indices_ = complement(N, indices)
144138
variance(op, ptrace(state, indices_))
145139
end
146140

147-
variance(index::Integer, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis} = variance([index], op, state)
141+
variance(index::Integer, op::AbstractOperator, state::Ket) = variance([index], op, state)
148142

149143
# Helper functions to check validity of arguments
150144
function check_ptrace_arguments(a::AbstractOperator, indices)
151-
if !isa(a.basis_l, CompositeBasis) || !isa(a.basis_r, CompositeBasis)
152-
throw(ArgumentError("Partial trace can only be applied onto operators with composite bases."))
153-
end
154145
rank = length(basis_l(a))
155146
if rank != length(basis_r(a))
156147
throw(ArgumentError("Partial trace can only be applied onto operators wich have the same number of subsystems in the left basis and right basis."))
157148
end
149+
if rank < 2
150+
throw(ArgumentError("Partial trace can only be applied to operators over at least two subsystems."))
151+
end
158152
if rank == length(indices)
159153
throw(ArgumentError("Partial trace can't be used to trace out all subsystems - use tr() instead."))
160154
end

src/operators_dense.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,15 @@ function exp(op::T) where {B,T<:DenseOpType{B,B}}
246246
return DenseOperator(op.basis_l, op.basis_r, exp(op.data))
247247
end
248248

249-
function permutesystems(a::Operator{B1,B2}, perm) where {B1<:CompositeBasis,B2<:CompositeBasis}
249+
function permutesystems(a::Operator, perm)
250250
@assert length(a.basis_l) == length(a.basis_r) == length(perm)
251251
@assert isperm(perm)
252252
data = Base.ReshapedArray(a.data, (a.basis_l.shape..., a.basis_r.shape...), ())
253253
data = PermutedDimsArray(data, [perm; perm .+ length(perm)])
254254
data = Base.ReshapedArray(data, (length(a.basis_l), length(a.basis_r)), ())
255255
return Operator(permutesystems(a.basis_l, perm), permutesystems(a.basis_r, perm), copy(data))
256256
end
257-
permutesystems(a::AdjointOperator{B1,B2}, perm) where {B1<:CompositeBasis,B2<:CompositeBasis} = dagger(permutesystems(dagger(a),perm))
257+
permutesystems(a::AdjointOperator, perm) = dagger(permutesystems(dagger(a),perm))
258258

259259
identityoperator(::Type{S}, ::Type{T}, b1::Basis, b2::Basis) where {S<:DenseOpType,T<:Number} =
260260
Operator(b1, b2, Matrix{T}(I, dimension(b1), dimension(b2)))

src/pauli.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,12 @@ function choi_chi(Nl, Nr)
3636
end
3737

3838
# It's possible to get better asympotic speedups using, e.g. methods from
39+
# https://quantum-journal.org/papers/q-2024-09-05-1461/ (see appendices)
3940
# https://iopscience.iop.org/article/10.1088/1402-4896/ad6499
4041
# https://arxiv.org/abs/2411.00526
41-
# https://quantum-journal.org/papers/q-2024-09-05-1461/ (see appendices)
42+
# https://arxiv.org/abs/2408.06206
43+
# https://quantumcomputing.stackexchange.com/questions/31788/how-to-write-the-iswap-unitary-as-a-linear-combination-of-tensor-products-betw/31790#31790
44+
# So probably using https://github.com/JuliaMath/Hadamard.jl would be best
4245
function _pauli_comp_convert(op, rev)
4346
Nl, Nr = length(basis_l(basis_l(op))), length(basis_l(basis_r(op)))
4447
Vl = pauli_comp(Nl)

src/spinors.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ function embed(bl::SumBasis, br::SumBasis, indices, op::LazyDirectSum)
215215
end
216216
# TODO: embed for multiple LazyDirectums?
217217

218+
embed(bl::SumBasis, br::SumBasis, index::Integer, op::LazyDirectSum) = embed(bl,br,[index],op)
219+
218220
# Fast in-place multiplication
219221
function mul!(result::Ket{B1},M::LazyDirectSum{B1,B2},b::Ket{B2},alpha_,beta_) where {B1,B2}
220222
alpha = convert(ComplexF64, alpha_)

src/superoperators.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,12 @@ sprepost(A::Operator, B::Operator) = Operator(KetBraBasis(basis_l(A), basis_r(B)
6464
# return Operator(a1⊗b1, a2⊗b2, data)
6565
#end
6666

67+
# https://discourse.julialang.org/t/permuteddimsarray-slower-than-permutedims/46401
6768
function super_tensor(A, B)
6869
all, alr = basis_l(basis_l(A)), basis_r(basis_l(A))
6970
arl, arr = basis_l(basis_r(A)), basis_r(basis_r(A))
70-
bll, blr = basis_l(basis_l(A)), basis_r(basis_l(A))
71-
brl, brr = basis_l(basis_r(A)), basis_r(basis_r(A))
71+
bll, blr = basis_l(basis_l(B)), basis_r(basis_l(B))
72+
brl, brr = basis_l(basis_r(B)), basis_r(basis_r(B))
7273
data = kron(B.data, A.data)
7374
data = reshape(data, map(dimension, (all, bll, alr, blr, arl, brl, arr, brr)))
7475
data = PermutedDimsArray(data, (1, 3, 2, 4, 5, 6, 7, 8))

src/time_dependent_operator.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,12 @@ function dagger(op::TimeDependentSum)
201201
end
202202
adjoint(op::TimeDependentSum) = dagger(op)
203203

204-
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, i::Integer, o::TimeDependentSum)
205-
TimeDependentSum(coefficients(o), embed(basis_l, basis_r, i, static_operator(o)), o.current_time)
204+
function embed(bl::Basis, br::Basis, i::Integer, o::TimeDependentSum)
205+
TimeDependentSum(coefficients(o), embed(bl, br, i, static_operator(o)), o.current_time)
206206
end
207207

208-
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, indices, o::TimeDependentSum)
209-
TimeDependentSum(coefficients(o), embed(basis_l, basis_r, indices, static_operator(o)), o.current_time)
208+
function embed(bl::Basis, br::Basis, indices, o::TimeDependentSum)
209+
TimeDependentSum(coefficients(o), embed(bl, br, indices, static_operator(o)), o.current_time)
210210
end
211211

212212
function +(A::TimeDependentSum, B::TimeDependentSum)

test/test_metrics.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test
22
using QuantumOpticsBase
3+
import QuantumInterface: IncompatibleBases
34
using SparseArrays, LinearAlgebra
45

56
@testset "metrics" begin
@@ -49,8 +50,8 @@ sigma = tensor(psi2, dagger(psi2))
4950
@test tracedistance(sigma, sigma) 0.
5051

5152
rho = spinup(b1) dagger(coherentstate(b2, 0.1))
52-
@test_throws ArgumentError tracedistance(rho, rho)
53-
@test_throws ArgumentError tracedistance_h(rho, rho)
53+
@test_throws IncompatibleBases tracedistance(rho, rho)
54+
@test_throws IncompatibleBases tracedistance_h(rho, rho)
5455

5556
@test tracedistance_nh(rho, rho) 0.
5657

test/test_particle.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ psi_x_fft2 = tensor((dagger.(psi0_p).*Tpx_sub)...)
277277
difference = (dense(Txp) - identityoperator(DenseOpType, Txp.basis_l)*Txp).data
278278
@test isapprox(difference, zero(difference); atol=1e-12)
279279
@test_throws AssertionError transform(tensor(basis_position...), tensor(basis_position...))
280-
@test_throws QuantumOpticsBase.IncompatibleBases transform(SpinBasis(1//2)^2, SpinBasis(1//2)^2)
280+
@test_throws MethodError transform(SpinBasis(1//2)^2, SpinBasis(1//2)^2)
281281

282282
@test dense(Txp) == dense(Txp_sub[1] Txp_sub[2])
283283

0 commit comments

Comments
 (0)