Skip to content

Commit f153af8

Browse files
committed
Implement new basis interface
1 parent 6ea405b commit f153af8

37 files changed

+330
-328
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33
## v1.0.0
44

5-
- First release, implementing necessary changes from reworking abstract types in `QuantumInterface`.
5+
- First release, implementing necessary changes to be compatible with the breaking release of `QuantumInterface` 0.4.0.
66

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "QuantumOpticsBase"
22
uuid = "4f57444f-1401-5e15-980d-4471b28d5678"
3-
version = "0.6.0"
3+
version = "1.0.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ QuantumOpticsBase.check_samebases
141141
```
142142

143143
```@docs
144-
@samebases
144+
@compatiblebases
145145
```
146146

147147
```@docs
@@ -640,4 +640,4 @@ lazytensor_cachesize
640640

641641
```@docs
642642
lazytensor_clear_cache
643-
```
643+
```

src/QuantumOpticsBase.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import QuantumInterface: dagger, directsum, ⊕, dm, embed, nsubsystems, expect,
1010
# index helpers
1111
import QuantumInterface: complement, remove, shiftremove, reducedindices!, check_indices, check_sortedindices, check_embed_indices
1212

13-
export Basis, GenericBasis, CompositeBasis, basis,
14-
tensor, , permutesystems, @samebases,
13+
export Basis, GenericBasis, CompositeBasis, basis, basis_l, basis_r,
14+
tensor, , permutesystems, @compatiblebases,
1515
#states
1616
StateVector, Bra, Ket, basisstate, sparsebasisstate, norm,
1717
dagger, normalize, normalize!,

src/bases.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
import QuantumInterface: Basis, basis, GenericBasis, CompositeBasis,
2-
equal_shape, IncompatibleBases, @samebases, samebases, check_samebases,
3-
multiplicable, check_multiplicable, reduced, ptrace, permutesystems
2+
equal_shape, IncompatibleBases, @compatiblebases, samebases, check_samebases,
3+
addible, check_addible, multiplicable, check_multiplicable, reduced, ptrace, permutesystems

src/charge.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ struct ChargeBasis{T} <: Basis
2828
end
2929

3030
Base.:(==)(b1::ChargeBasis, b2::ChargeBasis) = (b1.ncut == b2.ncut)
31+
Base.length(b::ChargeBasis) = b.dim
3132

3233
"""
3334
ShiftedChargeBasis(nmin, nmax) <: Basis
@@ -50,6 +51,7 @@ end
5051

5152
Base.:(==)(b1::ShiftedChargeBasis, b2::ShiftedChargeBasis) =
5253
(b1.nmin == b2.nmin && b1.nmax == b2.nmax)
54+
Base.length(b::ShiftedChargeBasis) = b.dim
5355

5456
"""
5557
chargestate([T=ComplexF64,] b::ChargeBasis, n)

src/manybody.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ end
5050
ManyBodyBasis(onebodybasis::B, occupations::O) where {B,O} = ManyBodyBasis{B,O}(onebodybasis, occupations)
5151
ManyBodyBasis(onebodybasis::B, occupations::Vector{T}) where {B,T} = ManyBodyBasis(onebodybasis, SortedVector(occupations))
5252

53+
==(b1::ManyBodyBasis, b2::ManyBodyBasis) = b1.occupations_hash == b2.occupations_hash && b1.onebodybasis == b2.onebodybasis
54+
Base.length(b::ManyBodyBasis) = length(b.occupations)
55+
56+
5357
allocate_buffer(occ) = similar(occ)
5458
allocate_buffer(mb::ManyBodyBasis) = allocate_buffer(first(mb.occupations))
5559

@@ -89,8 +93,6 @@ bosonstates(T::Type, Nmodes::Int, Nparticles::Vector{Int}) = union((bosonstates(
8993
bosonstates(T::Type, onebodybasis::Basis, Nparticles) = bosonstates(T, length(onebodybasis), Nparticles)
9094
bosonstates(arg1, arg2) = bosonstates(OccupationNumbers{BosonStatistics,Int}, arg1, arg2)
9195

92-
==(b1::ManyBodyBasis, b2::ManyBodyBasis) = b1.occupations_hash == b2.occupations_hash && b1.onebodybasis == b2.onebodybasis
93-
9496
"""
9597
basisstate([T=ComplexF64,] mb::ManyBodyBasis, occupation::Vector)
9698

src/metrics.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,15 @@ The `indices` argument can be a single integer or a collection of integers.
199199
function ptranspose(rho::DenseOpType{B,B}, indices=1) where B<:CompositeBasis
200200
# adapted from qutip.partial_transpose (https://qutip.org/docs/4.0.2/modules/qutip/partial_transpose.html)
201201
# works as long as QuantumOptics.jl doesn't change the implementation of `tensor`, i.e. tensor(a,b).data = kron(b.data,a.data)
202-
nsys = length(rho.basis_l.shape)
202+
nsys = nsubsystems(basis_l(rho))
203203
mask = ones(Int, nsys)
204204
mask[collect(indices)] .+= 1
205205
pt_dims = reshape(1:2*nsys, (nsys,2)) # indices of the operator viewed as a tensor with 2nsys legs
206206
pt_idx = [[pt_dims[i,mask[i]] for i = 1 : nsys]; [pt_dims[i,3-mask[i]] for i = 1 : nsys] ] # permute the legs on the subsystem of `indices`
207207
# reshape the operator data into a 2nsys-legged tensor and shape it back with the legs permuted
208-
data = reshape(permutedims(reshape(rho.data, Tuple([rho.basis_l.shape; rho.basis_r.shape])), pt_idx), size(rho.data))
208+
data = reshape(permutedims(reshape(rho.data, Tuple([size(basis_l(rho)); size(basis_r(rho))])), pt_idx), size(rho.data))
209209

210-
return DenseOperator(rho.basis_l,data)
210+
return DenseOperator(basis_l(rho),data)
211211

212212
end
213213

@@ -276,7 +276,7 @@ function can be provided (for example to compute the entanglement-renyi entropy)
276276
"""
277277
function entanglement_entropy(psi::Ket{B}, partition, entropy_fun=entropy_vn) where B<:CompositeBasis
278278
# check that sites are within the range
279-
@assert all(partition .<= length(psi.basis.bases))
279+
@assert all(partition .<= nsubsystems(psi.basis))
280280

281281
rho = ptrace(psi, partition)
282282
return entropy_fun(rho)
@@ -285,17 +285,18 @@ end
285285
function entanglement_entropy(rho::DenseOpType{B,B}, partition, args...) where {B<:CompositeBasis}
286286
# check that sites is within the range
287287
hilb = rho.basis_l
288-
all(partition .<= length(hilb.bases)) || throw(ArgumentError("Indices in partition must be within the bounds of the composite basis."))
289-
length(partition) <= length(hilb.bases) || throw(ArgumentError("Partition cannot include the whole system."))
288+
N = nsubsystems(hilb)
289+
all(partition .<= N) || throw(ArgumentError("Indices in partition must be within the bounds of the composite basis."))
290+
length(partition) <= N || throw(ArgumentError("Partition cannot include the whole system."))
290291

291292
# build the doubled hilbert space for the vectorised dm, normalized like a Ket.
292293
b_doubled = hilb^2
293294
rho_vec = normalize!(Ket(b_doubled, vec(rho.data)))
294295

295296
if partition isa Tuple
296-
partition_ = tuple(partition..., (partition.+length(hilb.bases))...)
297+
partition_ = tuple(partition..., (partition.+N)...)
297298
else
298-
partition_ = vcat(partition, partition.+length(hilb.bases))
299+
partition_ = vcat(partition, partition.+N)
299300
end
300301

301302
return entanglement_entropy(rho_vec,partition_,args...)

src/operators.jl

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ Abstract type for operators with a data field.
99
This is an abstract type for operators that have a direct matrix representation
1010
stored in their `.data` field.
1111
"""
12-
abstract type BLROperator{BL,BR} <: AbstractOperator end
13-
abstract type DataOperator{BL,BR} <: BLROperator{BL,BR} end
12+
abstract type DataOperator{BL,BR} <: AbstractOperator end
1413

1514

1615
# Common error messages
@@ -24,14 +23,14 @@ Embed operator acting on a joint Hilbert space where missing indices are filled
2423
"""
2524
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
2625
indices, op::T) where T<:DataOperator
27-
N = length(basis_l.bases)
28-
@assert length(basis_r.bases) == N
26+
N = nsubsystems(basis_l)
27+
@assert nsubsystems(basis_r) == N
2928

30-
reduce(tensor, basis_l.bases[indices]) == op.basis_l || throw(IncompatibleBases())
31-
reduce(tensor, basis_r.bases[indices]) == op.basis_r || throw(IncompatibleBases())
29+
reduce(tensor, basis_l[indices]) == op.basis_l || throw(IncompatibleBases())
30+
reduce(tensor, basis_r[indices]) == op.basis_r || throw(IncompatibleBases())
3231

33-
index_order = [idx for idx in 1:length(basis_l.bases) if idx indices]
34-
all_operators = AbstractOperator[identityoperator(T, eltype(op), basis_l.bases[i], basis_r.bases[i]) for i in index_order]
32+
index_order = [idx for idx in 1:N if idx indices]
33+
all_operators = AbstractOperator[identityoperator(T, eltype(op), basis_l[i], basis_r[i]) for i in index_order]
3534

3635
for idx in indices
3736
pushfirst!(index_order, idx)
@@ -45,8 +44,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
4544

4645
# Reorient the matrix to act in the correctly ordered basis.
4746
# Get the dimensions necessary for index permuting.
48-
dims_l = [b.shape[1] for b in basis_l.bases]
49-
dims_r = [b.shape[1] for b in basis_r.bases]
47+
dims_l = size(basis_l)
48+
dims_r = size(basis_r)
5049

5150
# Get the order of indices to use in the first reshape. Julia indices go in
5251
# reverse order.
@@ -74,12 +73,12 @@ end
7473
function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
7574
index::Integer, op::T) where T<:DataOperator
7675

77-
N = length(basis_l.bases)
76+
N = nsubsystems(basis_l)
7877

7978
# Check stuff
80-
@assert N==length(basis_r.bases)
81-
basis_l.bases[index] == op.basis_l || throw(IncompatibleBases())
82-
basis_r.bases[index] == op.basis_r || throw(IncompatibleBases())
79+
@assert nsubsystems(basis_r) == N
80+
basis_l[index] == op.basis_l || throw(IncompatibleBases())
81+
basis_r[index] == op.basis_r || throw(IncompatibleBases())
8382
check_indices(N, index)
8483

8584
# Build data
@@ -92,8 +91,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
9291
data = kron(data, op.data)
9392
i -= length(index)
9493
else
95-
bl = basis_l.bases[i]
96-
br = basis_r.bases[i]
94+
bl = basis_l[i]
95+
br = basis_r[i]
9796
id = SparseMatrixCSC{Tnum}(I, length(bl), length(br))
9897
data = kron(data, id)
9998
i -= 1
@@ -110,18 +109,21 @@ Expectation value of the given operator `op` for the specified `state`.
110109
111110
`state` can either be a (density) operator or a ket.
112111
"""
113-
expect(op::BLROperator{B,B}, state::Ket{B}) where B = dot(state.data, (op * state).data)
112+
function expect(op::AbstractOperator, state::Ket)
113+
check_multiplicable(op,op); check_multiplicable(op,state)
114+
dot(state.data, (op * state).data)
115+
end
114116

115117
# TODO upstream this one
116118
# expect(op::AbstractOperator{B,B}, state::AbstractKet{B}) where B = norm(op * state) ^ 2
117119

118-
function expect(indices, op::BLROperator{B,B}, state::Ket{B2}) where {B,B2<:CompositeBasis}
119-
N = length(state.basis.shape)
120+
function expect(indices, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis}
121+
N = nsubsystems(basis(state))
120122
indices_ = complement(N, indices)
121123
expect(op, ptrace(state, indices_))
122124
end
123125

124-
expect(index::Integer, op::BLROperator{B,B}, state::Ket{B2}) where {B,B2<:CompositeBasis} = expect([index], op, state)
126+
expect(index::Integer, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis} = expect([index], op, state)
125127

126128
"""
127129
variance(op, state)
@@ -130,44 +132,42 @@ Variance of the given operator `op` for the specified `state`.
130132
131133
`state` can either be a (density) operator or a ket.
132134
"""
133-
function variance(op::BLROperator{B,B}, state::Ket{B}) where B
135+
function variance(op::AbstractOperator, state::Ket)
136+
check_multiplicable(op,op); check_multiplicable(op,state)
134137
x = op*state
135138
state.data'*(op*x).data - (state.data'*x.data)^2
136139
end
137140

138-
function variance(indices, op::BLROperator{B,B}, state::Ket{BC}) where {B,BC<:CompositeBasis}
139-
N = length(state.basis.shape)
141+
function variance(indices, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis}
142+
N = nsubsystems(basis(state))
140143
indices_ = complement(N, indices)
141144
variance(op, ptrace(state, indices_))
142145
end
143146

144-
variance(index::Integer, op::BLROperator{B,B}, state::Ket{BC}) where {B,BC<:CompositeBasis} = variance([index], op, state)
147+
variance(index::Integer, op::AbstractOperator, state::Ket{B}) where {B<:CompositeBasis} = variance([index], op, state)
145148

146149
# Helper functions to check validity of arguments
147150
function check_ptrace_arguments(a::AbstractOperator, indices)
148151
if !isa(a.basis_l, CompositeBasis) || !isa(a.basis_r, CompositeBasis)
149152
throw(ArgumentError("Partial trace can only be applied onto operators with composite bases."))
150153
end
151-
rank = length(a.basis_l.shape)
152-
if rank != length(a.basis_r.shape)
154+
rank = nsubsystems(basis_l(a))
155+
if rank != nsubsystems(basis_r(a))
153156
throw(ArgumentError("Partial trace can only be applied onto operators wich have the same number of subsystems in the left basis and right basis."))
154157
end
155158
if rank == length(indices)
156159
throw(ArgumentError("Partial trace can't be used to trace out all subsystems - use tr() instead."))
157160
end
158-
check_indices(length(a.basis_l.shape), indices)
161+
check_indices(nsubsystems(basis_l(a)), indices)
159162
for i=indices
160-
if a.basis_l.shape[i] != a.basis_r.shape[i]
163+
if size(basis_l(a))[i] != size(basis_r(a))[i]
161164
throw(ArgumentError("Partial trace can only be applied onto subsystems that have the same left and right dimension."))
162165
end
163166
end
164167
end
165168
function check_ptrace_arguments(a::StateVector, indices)
166-
if length(basis(a).shape) == length(indices)
169+
if nsubsystems(basis(a)) == length(indices)
167170
throw(ArgumentError("Partial trace can't be used to trace out all subsystems - use tr() instead."))
168171
end
169-
check_indices(length(basis(a).shape), indices)
172+
check_indices(nsubsystems(basis(a)), indices)
170173
end
171-
172-
multiplicable(a::AbstractOperator, b::Ket) = multiplicable(a.basis_r, b.basis)
173-
multiplicable(a::Bra, b::AbstractOperator) = multiplicable(a.basis, b.basis_l)

0 commit comments

Comments
 (0)