Skip to content

Commit e24807a

Browse files
committed
Implement basis interface proposed in #40
1 parent 1672e2f commit e24807a

File tree

7 files changed

+166
-49
lines changed

7 files changed

+166
-49
lines changed

src/QuantumInterface.jl

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,88 @@ module QuantumInterface
77
"""
88
basis(a)
99
10-
Return the basis of an object.
10+
Return the basis of a quantum object.
1111
12-
If it's ambiguous, e.g. if an operator has a different left and right basis,
13-
an [`IncompatibleBases`](@ref) error is thrown.
12+
If it's ambiguous, e.g. if an operator has a different
13+
left and right basis, an [`IncompatibleBases`](@ref) error is thrown.
14+
15+
See [`StateVector`](@ref) and [`AbstractOperator`](@ref)
1416
"""
1517
function basis end
1618

19+
"""
20+
basis_l(a)
21+
22+
Return the left basis of an operator.
23+
"""
24+
function basis_l end
25+
26+
"""
27+
basis_r(a)
28+
29+
Return the right basis of an operator.
30+
"""
31+
function basis_r end
32+
1733
"""
1834
Exception that should be raised for an illegal algebraic operation.
1935
"""
2036
mutable struct IncompatibleBases <: Exception end
2137

38+
#function bases end
39+
40+
function spinnumber end
41+
42+
function cutoff end
43+
44+
function offset end
2245

2346
##
2447
# Standard methods
2548
##
2649

50+
"""
51+
multiplicable(a, b)
52+
53+
Check if any two subtypes of `StateVector` or `AbstractOperator`,
54+
can be multiplied in the given order.
55+
"""
56+
function multiplicable end
57+
58+
"""
59+
check_multiplicable(a, b)
60+
61+
Throw an [`IncompatibleBases`](@ref) error if the objects are
62+
not multiplicable as determined by `multiplicable(a, b)`.
63+
64+
If the macro `@compatiblebases` is used anywhere up the call stack,
65+
this check is disabled.
66+
"""
67+
function check_multiplicable end
68+
69+
"""
70+
addible(a, b)
71+
72+
Check if any two subtypes of `StateVector` or `AbstractOperator`
73+
can be added together.
74+
75+
Spcefically this checks whether the left basis of a is equal
76+
to the left basis of b and whether the right basis of a is equal
77+
to the right basis of b.
78+
"""
79+
function addible end
80+
81+
"""
82+
check_addible(a, b)
83+
84+
Throw an [`IncompatibleBases`](@ref) error if the objects are
85+
not addible as determined by `addible(a, b)`.
86+
87+
If the macro `@compatiblebases` is used anywhere up the call stack,
88+
this check is disabled.
89+
"""
90+
function check_addible end
91+
2792
function apply! end
2893

2994
function dagger end

src/abstract_types.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,40 @@
11
"""
2-
Abstract base class for all specialized bases.
2+
Abstract type for all specialized bases of a Hilbert space.
33
4-
The Basis class is meant to specify a basis of the Hilbert space of the
5-
studied system. Besides basis specific information all subclasses must
6-
implement a shape variable which indicates the dimension of the used
7-
Hilbert space. For a spin-1/2 Hilbert space this would be the
8-
vector `[2]`. A system composed of two spins would then have a
9-
shape vector `[2 2]`.
4+
The `Basis` type specifies an orthonormal basis for the Hilbert
5+
space of the studied system. All subtypes must implement `Base.:(==)`,
6+
and `Base.size`. `size` should return a tuple representing the total dimension
7+
of the Hilbert space with any tensor product structure the basis has such that
8+
`length(b::Basis) = prod(size(b))` gives the total Hilbert dimension
109
11-
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
12-
class.
10+
Composite systems can be defined with help of [`CompositeBasis`](@ref).
11+
12+
All relevant properties of subtypes of `Basis` defined in `QuantumInterface`
13+
should be accessed using their documented functions and should not
14+
assume anything about the internal representation of instances of these
15+
types (i.e. don't access the struct's fields directly).
1316
"""
1417
abstract type Basis end
1518

1619
"""
17-
Abstract base class for `Bra` and `Ket` states.
20+
Abstract type for `Bra` and `Ket` states.
1821
19-
The state vector class stores the coefficients of an abstract state
20-
in respect to a certain basis. These coefficients are stored in the
21-
`data` field and the basis is defined in the `basis`
22-
field.
22+
The state vector class stores an abstract state with respect
23+
to a certain basis. All subtypes must implement the `basis`
24+
method which should this basis as a subtype of `Basis`.
2325
"""
2426
abstract type StateVector{B,T} end
2527
abstract type AbstractKet{B,T} <: StateVector{B,T} end
2628
abstract type AbstractBra{B,T} <: StateVector{B,T} end
2729

2830
"""
29-
Abstract base class for all operators.
31+
Abstract type for all operators and super operators.
3032
31-
All deriving operator classes have to define the fields
32-
`basis_l` and `basis_r` defining the left and right side bases.
33+
All subtypes must implement the methods `basis_l` and
34+
`basis_r` which return subtypes of `Basis` and
35+
represent the left and right bases that the operator
36+
maps between and thus is compatible with a `Bra` defined
37+
in the left basis and a `Ket` defined in the right basis.
3338
3439
For fast time evolution also at least the function
3540
`mul!(result::Ket,op::AbstractOperator,x::Ket,alpha,beta)` should be
@@ -53,3 +58,5 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)}
5358
```
5459
"""
5560
abstract type AbstractSuperOperator{B1,B2} end
61+
62+
const AbstractQObjType = Union{<:StateVector,<:AbstractOperator}

src/bases.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
88
Total dimension of the Hilbert space.
99
"""
10-
Base.length(b::Basis) = prod(b.shape)
10+
Base.length(b::Basis) = prod(b.shape) # change to prod(size(b)) when downstream Bases are updated
1111

1212
"""
1313
GenericBasis(N)
@@ -24,7 +24,7 @@ end
2424
GenericBasis(N::Integer) = GenericBasis([N])
2525

2626
Base.:(==)(b1::GenericBasis, b2::GenericBasis) = equal_shape(b1.shape, b2.shape)
27-
27+
Base.size(b::GenericBasis) = b.shape
2828

2929
"""
3030
CompositeBasis(b1, b2...)
@@ -42,8 +42,11 @@ end
4242
CompositeBasis(bases) = CompositeBasis([length(b) for b bases], bases)
4343
CompositeBasis(bases::Basis...) = CompositeBasis((bases...,))
4444
CompositeBasis(bases::Vector) = CompositeBasis((bases...,))
45+
#bases(b::CompositeBasis) = b.bases
4546

4647
Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape)
48+
Base.size(b::CompositeBasis) = length.(b.bases)
49+
Base.getindex(b::CompositeBasis, i) = getindex(b.bases, i)
4750

4851
##
4952
# Common bases
@@ -69,6 +72,9 @@ struct FockBasis{T} <: Basis
6972
end
7073

7174
Base.:(==)(b1::FockBasis, b2::FockBasis) = (b1.N==b2.N && b1.offset==b2.offset)
75+
Base.size(b::FockBasis) = (b.N - b.offset + 1,)
76+
cutoff(b::FockBasis) = b.N
77+
offset(b::FockBasis) = b.offset
7278

7379

7480
"""
@@ -88,6 +94,7 @@ struct NLevelBasis{T} <: Basis
8894
end
8995

9096
Base.:(==)(b1::NLevelBasis, b2::NLevelBasis) = b1.N == b2.N
97+
Base.size(b::NLevelBasis) = (b.N,)
9198

9299
"""
93100
NQubitBasis(num_qubits::Int)
@@ -106,6 +113,7 @@ struct NQubitBasis{S,B} <: Basis
106113
end
107114

108115
Base.:(==)(pb1::NQubitBasis, pb2::NQubitBasis) = length(pb1.bases) == length(pb2.bases)
116+
Base.size(b::NQubitBasis) = b.shape
109117

110118
"""
111119
SpinBasis(n)
@@ -132,7 +140,8 @@ SpinBasis(spinnumber::Rational) = SpinBasis{spinnumber}(spinnumber)
132140
SpinBasis(spinnumber) = SpinBasis(convert(Rational{Int}, spinnumber))
133141

134142
Base.:(==)(b1::SpinBasis, b2::SpinBasis) = b1.spinnumber==b2.spinnumber
135-
143+
Base.size(b::SpinBasis) = (numerator(b.spinnumber*2 + 1),)
144+
spinnumber(b::SpinBasis) = b.spinnumber
136145

137146
"""
138147
SumBasis(b1, b2...)
@@ -151,3 +160,4 @@ SumBasis(bases::Basis...) = SumBasis((bases...,))
151160
Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape)
152161
Base.:(==)(b1::SumBasis, b2::SumBasis) = false
153162
Base.length(b::SumBasis) = sum(b.shape)
163+
# TODO how should `.bases` be accessed? `getindex` or a `sumbases` method?

src/deprecated.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,6 @@ function check_samebases(b1, b2)
4343
end
4444
end
4545

46-
function check_multiplicable(b1, b2)
47-
if BASES_CHECK[] && !multiplicable(b1, b2)
48-
throw(IncompatibleBases())
49-
end
50-
end
51-
5246
samebases(b1::Basis, b2::Basis) = b1==b2
5347
samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators
5448
samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12
@@ -68,5 +62,3 @@ function multiplicable(b1::CompositeBasis, b2::CompositeBasis)
6862
end
6963
return true
7064
end
71-
72-
multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12

src/expect_variance.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,35 @@
33
44
If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number.
55
"""
6-
function expect(indices, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis}
7-
N = length(state.basis_l.shape)
8-
indices_ = complement(N, indices)
9-
expect(op, ptrace(state, indices_))
10-
end
6+
expect(indices, op::AbstractOperator, state::AbstractOperator) =
7+
expect(op, ptrace(state, complement(nsubsystems(state), indices)))
8+
9+
expect(index::Integer, op::AbstractOperator, state::AbstractOperator) = expect([index], op, state)
1110

12-
expect(index::Integer, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = expect([index], op, state)
1311
expect(op::AbstractOperator, states::Vector) = [expect(op, state) for state=states]
12+
1413
expect(indices, op::AbstractOperator, states::Vector) = [expect(indices, op, state) for state=states]
1514

16-
expect(op::AbstractOperator{B1,B2}, state::AbstractOperator{B2,B2}) where {B1,B2} = tr(op*state)
15+
expect(op::AbstractOperator, state::AbstractOperator) =
16+
(check_multiplicable(state, state); check_multiplicable(op,state); tr(op*state))
1717

1818
"""
1919
variance(index, op, state)
2020
2121
If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number
2222
"""
23-
function variance(indices, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis}
24-
N = length(state.basis_l.shape)
25-
indices_ = complement(N, indices)
26-
variance(op, ptrace(state, indices_))
27-
end
23+
variance(indices, op::AbstractOperator, state::AbstractOperator) =
24+
variance(op, ptrace(state, complement(nsubsystems(state), indices)))
25+
26+
variance(index::Integer, op::AbstractOperator, state::AbstractOperator) = variance([index], op, state)
2827

29-
variance(index::Integer, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis} = variance([index], op, state)
3028
variance(op::AbstractOperator, states::Vector) = [variance(op, state) for state=states]
29+
3130
variance(indices, op::AbstractOperator, states::Vector) = [variance(indices, op, state) for state=states]
3231

33-
function variance(op::AbstractOperator{B,B}, state::AbstractOperator{B,B}) where B
34-
expect(op*op, state) - expect(op, state)^2
32+
function variance(op::AbstractOperator, state::AbstractOperator)
33+
check_multiplicable(op,op)
34+
check_multiplicable(state,state)
35+
check_multiplicable(op,state)
36+
@compatiblebases expect(op*op, state) - expect(op, state)^2
3537
end

src/julia_base.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op
1414
*(a::StateVector, b::Number) = b*a
1515
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12
1616
length(a::StateVector) = length(a.basis)::Int # FIXME issue #12
17-
basis(a::StateVector) = a.basis # FIXME issue #12
1817
adjoint(a::StateVector) = dagger(a)
1918

2019

@@ -33,8 +32,6 @@ Base.broadcastable(x::StateVector) = x
3332
##
3433

3534
length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12
36-
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12
37-
basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12
3835

3936
# Ensure scalar broadcasting
4037
Base.broadcastable(x::AbstractOperator) = Ref(x)

src/linalg.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,50 @@
44

55
const BASES_CHECK = Ref(true)
66

7+
"""
8+
@compatiblebases
9+
10+
Macro to skip checks for compatible bases. Useful for `*`, `expect` and similar
11+
functions.
12+
"""
13+
macro compatiblebases(ex)
14+
return quote
15+
BASES_CHECK.x = false
16+
local val = $(esc(ex))
17+
BASES_CHECK.x = true
18+
val
19+
end
20+
end
21+
22+
function check_addible(b1, b2)
23+
if BASES_CHECK[] && !addible(b1, b2)
24+
throw(IncompatibleBases())
25+
end
26+
end
27+
28+
function check_multiplicable(b1, b2)
29+
if BASES_CHECK[] && !multiplicable(b1, b2)
30+
throw(IncompatibleBases())
31+
end
32+
end
33+
34+
addible(a::AbstractQObjType, b::AbstractQObjType) = false
35+
addible(a::AbstractBra, b::AbstractBra) = (basis(a) == basis(b))
36+
addible(a::AbstractKet, b::AbstractKet) = (basis(a) == basis(b))
37+
addible(a::AbstractOperator, b::AbstractOperator) =
38+
(basis_l(a) == basis_l(b)) && (basis_r(a) == basis_r(b))
39+
40+
multiplicable(a::AbstractQObjType, b::AbstractQObjType) = false
41+
multiplicable(a::AbstractBra, b::AbstractKet) = (basis(a) == basis(b))
42+
multiplicable(a::AbstractOperator, b::AbstractKet) = (basis_r(a) == basis(b))
43+
multiplicable(a::AbstractBra, b::AbstractOperator) = (basis(a) == basis_l(b))
44+
multiplicable(a::AbstractOperator, b::AbstractOperator) = (basis_r(a) == basis_l(b))
45+
46+
basis(a::StateVector) = throw(ArgumentError("basis() is not defined for this type of state vector: $(typeof(a))."))
47+
basis_l(a::AbstractOperator) = throw(ArgumentError("basis_l() is not defined for this type of operator: $(typeof(a))."))
48+
basis_r(a::AbstractOperator) = throw(ArgumentError("basis_r() is not defined for this type of operator: $(typeof(a))."))
49+
basis(a::AbstractOperator) = (basis_l(a) == basis_r(a); basis_l(a))
50+
751
##
852
# tensor, reduce, ptrace
953
##

0 commit comments

Comments
 (0)