Skip to content

Commit 16ab274

Browse files
committed
SpinBasis can represent uniform tensor products
1 parent ba4620e commit 16ab274

File tree

3 files changed

+99
-39
lines changed

3 files changed

+99
-39
lines changed

src/bases.jl

Lines changed: 89 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ subsystems.
6060
6161
See also [`CompositeBasis`](@ref).
6262
"""
63-
Base.getindex(b::Basis, i) = i==1 ? b : throw(BoundsError("attempted to access a nonexistent subsystem basis"))
63+
Base.getindex(b::Basis, i) = i==1 ? b : throw(BoundsError(b,i))
64+
Base.firstindex(b::Basis) = 1
65+
Base.lastindex(b::Basis) = length(b)
6466

6567
Base.iterate(b::Basis, state=1) = state > length(b) ? nothing : (b[state], state+1)
6668

@@ -109,19 +111,34 @@ Stores the subbases in a vector and creates the shape vector directly from the
109111
dimensions of these subbases. Instead of creating a CompositeBasis directly,
110112
`tensor(b1, b2...)` or `b1 ⊗ b2 ⊗ …` should be used.
111113
"""
112-
struct CompositeBasis{B<:Basis,S<:Integer} <: Basis
113-
shape::Vector{S}
114+
struct CompositeBasis{B<:Basis} <: Basis
114115
bases::Vector{B}
116+
shape::Vector{Int}
117+
lengths::Vector{Int}
118+
N::Int
119+
D::Int
120+
function CompositeBasis(bases::Vector{B}) where B<:Basis
121+
# to enable this check the the lazy operators in QuantumOpticsBase need to be changed
122+
#length(bases) > 1 || throw(ArgumentError("CompositeBasis must only be used for composite systems"))
123+
shape_ = mapreduce(shape, vcat, bases)
124+
lengths = cumsum(map(length, bases))
125+
new{B}(bases, shape_, lengths, lengths[end], prod(shape_))
126+
end
115127
end
116-
CompositeBasis(bases) = CompositeBasis([dimension(b) for b in bases], bases)
117128
CompositeBasis(bases::Basis...) = CompositeBasis([bases...])
118129
CompositeBasis(bases::Tuple) = CompositeBasis([bases...])
119130

120131
Base.:(==)(b1::CompositeBasis, b2::CompositeBasis) = all(((i, j),) -> i == j, zip(b1.bases, b2.bases))
121-
Base.length(b::CompositeBasis) = length(b.bases)
122-
Base.getindex(b::CompositeBasis, i) = getindex(b.bases, i)
132+
Base.length(b::CompositeBasis) = b.N
133+
function Base.getindex(b::CompositeBasis, i::Integer)
134+
(i < 1 || i > b.N) && throw(BoundsError(b,i))
135+
bases_idx = findfirst(l -> i<=l, b.lengths)
136+
inner_idx = i - (bases_idx == 1 ? 0 : b.lengths[bases_idx-1])
137+
b.bases[bases_idx][inner_idx]
138+
end
139+
Base.getindex(b::CompositeBasis, indices) = [b[i] for i in indices]
123140
shape(b::CompositeBasis) = b.shape
124-
dimension(b::CompositeBasis) = prod(b.shape)
141+
dimension(b::CompositeBasis) = b.D
125142

126143
"""
127144
tensor(x::Basis, y::Basis, z::Basis...)
@@ -131,13 +148,40 @@ Create a [`CompositeBasis`](@ref) from the given bases.
131148
Any given CompositeBasis is expanded so that the resulting CompositeBasis never
132149
contains another CompositeBasis.
133150
"""
134-
tensor(b1::Basis, b2::Basis) = CompositeBasis([dimension(b1), dimension(b2)], [b1, b2])
135-
tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis([b1.shape; b2.shape], [b1.bases; b2.bases])
136-
tensor(b1::CompositeBasis, b2::Basis) = CompositeBasis([b1.shape; dimension(b2)], [b1.bases; b2])
137-
tensor(b1::Basis, b2::CompositeBasis) = CompositeBasis([dimension(b1); b2.shape], [b1; b2.bases])
151+
tensor(b1::Basis, b2::Basis) = CompositeBasis([b1, b2])
138152
tensor(bases::Basis...) = reduce(tensor, bases)
139153
tensor(basis::Basis) = basis
140154

155+
function tensor(b1::CompositeBasis, b2::CompositeBasis)
156+
if typeof(b1.bases[end]) == typeof(b2.bases[1])
157+
t = tensor(b1.bases[end], b2.bases[1])
158+
if !(t isa CompositeBasis)
159+
return CompositeBasis([b1.bases[1:end-1]; t; b2.bases[2:end]])
160+
end
161+
end
162+
return CompositeBasis([b1.bases; b2.bases])
163+
end
164+
165+
function tensor(b1::CompositeBasis, b2::Basis)
166+
if b1.bases[end] isa typeof(b2)
167+
t = tensor(b1.bases[end], b2)
168+
if !(t isa CompositeBasis)
169+
return CompositeBasis([b1.bases[1:end-1]; t])
170+
end
171+
end
172+
return CompositeBasis([b1.bases; b2])
173+
end
174+
175+
function tensor(b1::Basis, b2::CompositeBasis)
176+
if b2.bases[1] isa typeof(b1)
177+
t = tensor(b1, b2.bases[1])
178+
if !(t isa CompositeBasis)
179+
return CompositeBasis([t; b2[2:end]])
180+
end
181+
end
182+
return CompositeBasis([b1; b2.bases])
183+
end
184+
141185
Base.:^(b::Basis, N::Integer) = tensor_pow(b, N)
142186

143187
"""
@@ -185,7 +229,8 @@ directsum(b1::Basis, b2::SumBasis) = SumBasis([dimension(b1); b2.shape], [b1; b2
185229
directsum(bases::Basis...) = reduce(directsum, bases)
186230
directsum(basis::Basis) = basis
187231

188-
embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops)
232+
# TODO: what to do about embed for SumBasis?
233+
#embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops)
189234

190235
##
191236
# Basis checks
@@ -267,13 +312,13 @@ Reduced basis, state or operator on the specified subsystems.
267312
The `indices` argument, which can be a single integer or a vector of integers,
268313
specifies which subsystems are kept. At least one index must be specified.
269314
"""
270-
function reduced(b::CompositeBasis, indices)
315+
function reduced(b::Basis, indices)
271316
if length(indices)==0
272317
throw(ArgumentError("At least one subsystem must be specified in reduced."))
273318
elseif length(indices)==1
274-
return b.bases[indices[1]]
319+
return b[indices[1]]
275320
else
276-
return CompositeBasis(b.shape[indices], b.bases[indices])
321+
return tensor(b[indices]...)
277322
end
278323
end
279324

@@ -287,13 +332,13 @@ specifies which subsystems are traced out. The number of indices has to be
287332
smaller than the number of subsystems, i.e. it is not allowed to perform a
288333
full trace.
289334
"""
290-
function ptrace(b::CompositeBasis, indices)
291-
J = [i for i in 1:length(b.bases) if i indices]
335+
function ptrace(b::Basis, indices)
336+
J = [i for i in 1:length(b) if i indices]
292337
length(J) > 0 || throw(ArgumentError("Tracing over all indices is not allowed in ptrace."))
293338
reduced(b, J)
294339
end
295340

296-
_index_complement(b::CompositeBasis, indices) = complement(length(b.bases), indices)
341+
_index_complement(b::Basis, indices) = complement(length(b), indices)
297342
reduced(a, indices) = ptrace(a, _index_complement(basis(a), indices))
298343

299344
"""
@@ -304,10 +349,10 @@ Change the ordering of the subsystems of the given object.
304349
For a permutation vector `[2,1,3]` and a given object with basis `[b1, b2, b3]`
305350
this function results in `[b2, b1, b3]`.
306351
"""
307-
function permutesystems(b::CompositeBasis, perm)
352+
function permutesystems(b::Basis, perm)
308353
(length(b) == length(perm)) || throw(ArgumentError("Must have length(b) == length(perm) in permutesystems"))
309354
isperm(perm) || throw(ArgumentError("Must pass actual permeutation to permutesystems"))
310-
CompositeBasis(b.shape[perm], b.bases[perm])
355+
tensor(b.bases[perm]...)
311356
end
312357

313358

@@ -374,9 +419,9 @@ Base.:(==)(b1::NLevelBasis, b2::NLevelBasis) = b1.N == b2.N
374419
dimension(b::NLevelBasis) = b.N
375420

376421
"""
377-
SpinBasis(n)
422+
SpinBasis(n, N=1)
378423
379-
Basis for spin-n particles.
424+
Basis for spin-`n` particles over `N` systems.
380425
381426
The basis can be created for arbitrary spin numbers by using a rational number,
382427
e.g. `SpinBasis(3//2)`. The Pauli operators are defined for all possible spin
@@ -385,19 +430,31 @@ for a `SpinBasis`.
385430
"""
386431
struct SpinBasis{T<:Integer} <: Basis
387432
spinnumber::Rational{T}
388-
function SpinBasis(spinnumber::Rational{T}) where T
433+
D::T
434+
N::T
435+
function SpinBasis(spinnumber::Rational{T}, N=1) where T
389436
n = numerator(spinnumber)
390437
d = denominator(spinnumber)
391438
d==2 || d==1 || throw(ArgumentError("Can only construct integer or half-integer spin basis"))
392439
n >= 0 || throw(ArgumentError("Can only construct positive spin basis"))
393-
N = numerator(spinnumber*2 + 1)
394-
new{T}(spinnumber)
440+
D = numerator(spinnumber*2 + 1)
441+
new{T}(spinnumber, D, N)
395442
end
396443
end
397444
SpinBasis(spinnumber) = SpinBasis(convert(Rational{Int}, spinnumber))
398445

399-
Base.:(==)(b1::SpinBasis, b2::SpinBasis) = b1.spinnumber==b2.spinnumber
400-
dimension(b::SpinBasis) = numerator(b.spinnumber*2 + 1)
446+
Base.:(==)(b1::SpinBasis, b2::SpinBasis) = b1.D==b2.D && b1.N == b2.N
447+
Base.length(b::SpinBasis) = b.N
448+
Base.getindex(b::SpinBasis, i) = SpinBasis(b.spinnumber, length(i))
449+
shape(b::SpinBasis) = fill(b.D, b.N)
450+
dimension(b::SpinBasis) = b.D^b.N
451+
function tensor(b1::SpinBasis, b2::SpinBasis)
452+
if b1.spinnumber == b2.spinnumber
453+
return SpinBasis(b1.spinnumber, b1.N+b2.N)
454+
else
455+
return CompositeBasis([b1, b2])
456+
end
457+
end
401458

402459
"""
403460
spinnumber(b::SpinBasis)
@@ -554,6 +611,9 @@ function show(stream::IO, x::SpinBasis)
554611
else
555612
write(stream, "Spin($n/$d)")
556613
end
614+
if x.N > 1
615+
write(stream, "^$(x.N)")
616+
end
557617
end
558618

559619
function show(stream::IO, x::FockBasis)
@@ -584,9 +644,9 @@ function show(stream::IO, x::KetBraBasis)
584644
end
585645

586646
function show(stream::IO, x::PauliBasis)
587-
write(stream, "Pauli(N=$(x.N)")
647+
write(stream, "Pauli(N=$(x.N))")
588648
end
589649

590650
function show(stream::IO, x::HWPauliBasis)
591-
write(stream, "Pauli($(x.shape)")
651+
write(stream, "Pauli($(x.shape))")
592652
end

src/embed_permute.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
`operators` is a dictionary `Dict{Vector{Int}, AbstractOperator}`. The integer vector
66
specifies in which subsystems the corresponding operator is defined.
77
"""
8-
function embed(bl::CompositeBasis, br::CompositeBasis,
8+
function embed(bl::Basis, br::Basis,
99
operators::Dict{<:Vector{<:Integer}, T}) where T<:AbstractOperator
1010
(length(bl) == length(br)) || throw(ArgumentError("Must have length(bl) == length(br) in embed"))
1111
N = length(bl)::Int # type assertion to help type inference
@@ -35,24 +35,24 @@ function embed(bl::CompositeBasis, br::CompositeBasis,
3535
return permutesystems(op, perm)
3636
end
3737
end
38-
embed(basis_l::CompositeBasis, basis_r::CompositeBasis, operators::Dict{<:Integer, T}; kwargs...) where {T<:AbstractOperator} = embed(basis_l, basis_r, Dict([i]=>op_i for (i, op_i) in operators); kwargs...)
39-
embed(basis::CompositeBasis, operators::Dict{<:Integer, T}; kwargs...) where {T<:AbstractOperator} = embed(basis, basis, operators; kwargs...)
40-
embed(basis::CompositeBasis, operators::Dict{<:Vector{<:Integer}, T}; kwargs...) where {T<:AbstractOperator} = embed(basis, basis, operators; kwargs...)
38+
embed(bl::Basis, br::Basis, operators::Dict{<:Integer, T}; kwargs...) where {T<:AbstractOperator} = embed(bl, br, Dict([i]=>op_i for (i, op_i) in operators); kwargs...)
39+
embed(b::Basis, operators::Dict{<:Integer, T}; kwargs...) where {T<:AbstractOperator} = embed(b, b, operators; kwargs...)
40+
embed(b::Basis, operators::Dict{<:Vector{<:Integer}, T}; kwargs...) where {T<:AbstractOperator} = embed(b, b, operators; kwargs...)
4141

4242
# The dictionary implementation works for non-DataOperators
43-
embed(basis_l::CompositeBasis, basis_r::CompositeBasis, indices, op::T) where T<:AbstractOperator = embed(basis_l, basis_r, Dict(indices=>op))
43+
embed(bl::Basis, br::Basis, indices, op::T) where T<:AbstractOperator = embed(bl, br, Dict(indices=>op))
4444

45-
embed(basis_l::CompositeBasis, basis_r::CompositeBasis, index::Integer, op::AbstractOperator) = embed(basis_l, basis_r, index, [op])
46-
embed(basis::CompositeBasis, indices, operators::Vector{T}) where {T<:AbstractOperator} = embed(basis, basis, indices, operators)
47-
embed(basis::CompositeBasis, indices, op::AbstractOperator) = embed(basis, basis, indices, op)
45+
embed(bl::Basis, br::Basis, index::Integer, op::AbstractOperator) = embed(bl, br, index, [op])
46+
embed(b::Basis, indices, operators::Vector{T}) where {T<:AbstractOperator} = embed(b, b, indices, operators)
47+
embed(b::Basis, indices, op::AbstractOperator) = embed(b, b, indices, op)
4848

4949

5050
"""
5151
embed(basis1[, basis2], indices::Vector, operators::Vector)
5252
5353
Tensor product of operators where missing indices are filled up with identity operators.
5454
"""
55-
function embed(bl::CompositeBasis, br::CompositeBasis,
55+
function embed(bl::Basis, br::Basis,
5656
indices, operators::Vector{T}) where T<:AbstractOperator
5757

5858
check_embed_indices(indices) || throw(ArgumentError("Must have unique indices in embed"))

test/test_bases.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ comp_b2 = tensor(b1, b1, b2)
2828

2929
@test b1^3 == CompositeBasis(b1, b1, b1)
3030
@test (b1b2)^2 == CompositeBasis(b1, b2, b1, b2)
31-
@test_throws ArgumentError b1^(0)
31+
@test_throws DomainError b1^(0)
3232

3333
comp_b1_b2 = tensor(comp_b1, comp_b2)
3434
@test shape(comp_b1_b2) == [d1, d2, d1, d1, d2]

0 commit comments

Comments
 (0)