Skip to content

Commit 73bac88

Browse files
apkilleKrastanov
andauthored
Broadcasting Kets and Operators (#172)
Co-authored-by: Stefan Krastanov <[email protected]>
1 parent 06c2845 commit 73bac88

15 files changed

+153
-71
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
include:
2626
- os: ubuntu-latest
2727
arch: x64
28-
version: '1.6'
28+
version: '1.10'
2929
steps:
3030
- uses: actions/checkout@v4
3131
- uses: julia-actions/setup-julia@v2

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
QuantumInterface = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
RandomMatrices = "2576dda1-a324-5b11-aa66-c48ed7e3c618"
16+
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1617
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1718
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
1819
UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"
@@ -28,7 +29,8 @@ LinearAlgebra = "1"
2829
QuantumInterface = "0.3.3"
2930
Random = "1"
3031
RandomMatrices = "0.5"
32+
RecursiveArrayTools = "3"
3133
SparseArrays = "1"
3234
Strided = "1, 2"
3335
UnsafeArrays = "1"
34-
julia = "1.6"
36+
julia = "1.10"

src/QuantumOpticsBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module QuantumOpticsBase
22

33
using SparseArrays, LinearAlgebra, LRUCache, Strided, UnsafeArrays, FillArrays
44
import LinearAlgebra: mul!, rmul!
5+
import RecursiveArrayTools
56

67
import QuantumInterface: dagger, directsum, , dm, embed, nsubsystems, expect, identityoperator, identitysuperoperator,
78
permutesystems, projector, ptrace, reduced, tensor, , variance, apply!, basis, AbstractSuperOperator

src/operators_dense.jl

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -423,37 +423,50 @@ struct OperatorStyle{BL,BR} <: DataOperatorStyle{BL,BR} end
423423
Broadcast.BroadcastStyle(::Type{<:Operator{BL,BR}}) where {BL,BR} = OperatorStyle{BL,BR}()
424424
Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {B1,B2,B3,B4} = throw(IncompatibleBases())
425425

426+
# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
427+
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {Bl<:Basis, Br<:Basis, T<:OperatorStyle{Bl,Br}} = T()
428+
426429
# Out-of-place broadcasting
427430
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:OperatorStyle{BL,BR},Axes,F,Args<:Tuple}
428431
bcf = Broadcast.flatten(bc)
429432
bl,br = find_basis(bcf.args)
430-
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
431-
return Operator{BL,BR}(bl, br, copy(bc_))
433+
T = find_dType(bcf)
434+
data = zeros(T, length(bl), length(br))
435+
@inbounds @simd for I in eachindex(bcf)
436+
data[I] = bcf[I]
437+
end
438+
return Operator{BL,BR}(bl, br, data)
432439
end
433-
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)
434440

435-
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
436-
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes)
437-
args_ = Tuple(a.data for a=args)
438-
return Broadcast.Broadcasted(f, args_, axes)
439-
end
441+
find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r)
442+
find_dType(a::DataOperator, rest) = eltype(a)
443+
@inline Base.getindex(a::DataOperator, idx) = getindex(a.data, idx)
444+
Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::DataOperator, i) = x.data[i]
445+
Base.iterate(a::DataOperator) = iterate(a.data)
446+
Base.iterate(a::DataOperator, idx) = iterate(a.data, idx)
440447

441448
# In-place broadcasting
442449
@inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:DataOperatorStyle{BL,BR},Axes,F,Args}
443450
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
444-
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
445-
if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast!
446-
A = bc.args[1]
447-
if axes(dest) == axes(A)
448-
return copyto!(dest, A)
449-
end
451+
bc′ = Base.Broadcast.preprocess(dest, bc)
452+
dest′ = dest.data
453+
@inbounds @simd for I in eachindex(bc′)
454+
dest′[I] = bc′[I]
450455
end
451-
# Get the underlying data fields of operators and broadcast them as arrays
452-
bcf = Broadcast.flatten(bc)
453-
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
454-
copyto!(dest.data, bc_)
455456
return dest
456457
end
457458
@inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL,BR} = (copyto!(A.data,B.data); A)
458459
@inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:DataOperatorStyle,Axes,F,Args} =
459460
throw(IncompatibleBases())
461+
462+
# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
463+
Base.eltype(::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = N # ODE init
464+
Base.any(f::Function, x::Operator; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks
465+
Base.all(f::Function, x::Operator; kwargs...) = all(f, x.data; kwargs...)
466+
Base.fill!(x::Operator, a) = typeof(x)(x.basis_l, x.basis_r, fill!(x.data, a))
467+
Base.ndims(x::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = ndims(A)
468+
Base.similar(x::Operator, t) = typeof(x)(x.basis_l, x.basis_r, copy(x.data))
469+
RecursiveArrayTools.recursivecopy!(dest::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copyto!(dest,src) # ODE in-place equations
470+
RecursiveArrayTools.recursivecopy(x::Operator) = copy(x)
471+
RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Operator} = copy(x)
472+
RecursiveArrayTools.recursivefill!(x::Operator, a) = fill!(x, a)

src/states.jl

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -180,52 +180,51 @@ Broadcast.BroadcastStyle(::Type{<:Bra{B}}) where {B} = BraStyle{B}()
180180
Broadcast.BroadcastStyle(::KetStyle{B1}, ::KetStyle{B2}) where {B1,B2} = throw(IncompatibleBases())
181181
Broadcast.BroadcastStyle(::BraStyle{B1}, ::BraStyle{B2}) where {B1,B2} = throw(IncompatibleBases())
182182

183+
# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`)
184+
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:KetStyle{B}} = T()
185+
Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:BraStyle{B}} = T()
186+
183187
# Out-of-place broadcasting
184188
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:KetStyle{B},Axes,F,Args<:Tuple}
185189
bcf = Broadcast.flatten(bc)
186-
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
187190
b = find_basis(bcf)
188-
return Ket{B}(b, copy(bc_))
191+
T = find_dType(bcf)
192+
data = zeros(T, length(b))
193+
@inbounds @simd for I in eachindex(bcf)
194+
data[I] = bcf[I]
195+
end
196+
return Ket{B}(b, data)
189197
end
190198
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:BraStyle{B},Axes,F,Args<:Tuple}
191199
bcf = Broadcast.flatten(bc)
192-
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
193200
b = find_basis(bcf)
194-
return Bra{B}(b, copy(bc_))
195-
end
196-
find_basis(bc::Broadcast.Broadcasted) = find_basis(bc.args)
197-
find_basis(args::Tuple) = find_basis(find_basis(args[1]), Base.tail(args))
198-
find_basis(x) = x
199-
find_basis(a::StateVector, rest) = a.basis
200-
find_basis(::Any, rest) = find_basis(rest)
201-
202-
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
203-
function Broadcasted_restrict_f(f::BasicMathFunc, args::NTuple{N,<:T}, axes) where {T<:StateVector,N}
204-
args_ = Tuple(a.data for a=args)
205-
return Broadcast.Broadcasted(f, args_, axes)
206-
end
207-
function Broadcasted_restrict_f(f, args::Tuple, axes)
208-
error("Cannot broadcast function `$f` on $(typeof(args))")
201+
T = find_dType(bcf)
202+
data = zeros(T, length(b))
203+
@inbounds @simd for I in eachindex(bcf)
204+
data[I] = bcf[I]
205+
end
206+
return Bra{B}(b, data)
209207
end
210-
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{}, axes) # Defined to avoid method ambiguities
211-
error("Cannot broadcast function `$f` on an empty set of arguments")
208+
for f [:find_basis,:find_dType]
209+
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
210+
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
211+
@eval ($f)(x) = x
212+
@eval ($f)(::Any, rest) = ($f)(rest)
212213
end
213214

215+
find_basis(x::T, rest) where {T<:Union{Ket, Bra}} = x.basis
216+
find_dType(x::T, rest) where {T<:Union{Ket, Bra}} = eltype(x)
217+
@inline Base.getindex(x::T, idx) where {T<:Union{Ket, Bra}} = getindex(x.data, idx)
218+
Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::T, i) where {T<:Union{Ket, Bra}} = x.data[i]
219+
214220
# In-place broadcasting for Kets
215221
@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:KetStyle{B},Axes,F,Args}
216222
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
217-
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
218-
if bc.f === identity && isa(bc.args, Tuple{<:Ket{B}}) # only a single input argument to broadcast!
219-
A = bc.args[1]
220-
if axes(dest) == axes(A)
221-
return copyto!(dest, A)
222-
end
223+
bc′ = Base.Broadcast.preprocess(dest, bc)
224+
dest′ = dest.data
225+
@inbounds @simd for I in eachindex(bc′)
226+
dest′[I] = bc′[I]
223227
end
224-
# Get the underlying data fields of kets and broadcast them as arrays
225-
bcf = Broadcast.flatten(bc)
226-
args_ = Tuple(a.data for a=bcf.args)
227-
bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf))
228-
copyto!(dest.data, bc_)
229228
return dest
230229
end
231230
@inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:KetStyle{B2},Axes,F,Args} =
@@ -234,20 +233,27 @@ end
234233
# In-place broadcasting for Bras
235234
@inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:BraStyle{B},Axes,F,Args}
236235
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
237-
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
238-
if bc.f === identity && isa(bc.args, Tuple{<:Bra{B}}) # only a single input argument to broadcast!
239-
A = bc.args[1]
240-
if axes(dest) == axes(A)
241-
return copyto!(dest, A)
242-
end
236+
bc′ = Base.Broadcast.preprocess(dest, bc)
237+
dest′ = dest.data
238+
@inbounds @simd for I in eachindex(bc′)
239+
dest′[I] = bc′[I]
243240
end
244-
# Get the underlying data fields of bras and broadcast them as arrays
245-
bcf = Broadcast.flatten(bc)
246-
bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf))
247-
copyto!(dest.data, bc_)
248241
return dest
249242
end
250243
@inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:BraStyle{B2},Axes,F,Args} =
251244
throw(IncompatibleBases())
252245

253-
@inline Base.copyto!(A::T,B::T) where T<:Union{Ket, Bra} = (copyto!(A.data,B.data); A) # Can not use T<:QuantumInterface.StateVector, because StateVector does not imply the existence of a data property
246+
@inline Base.copyto!(dest::T,src::T) where {T<:Union{Ket, Bra}} = (copyto!(dest.data,src.data); dest) # Can not use T<:QuantumInterface.StateVector, because StateVector does not imply the existence of a data property
247+
248+
# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl
249+
Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init
250+
Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N
251+
Base.any(f::Function, x::T; kwargs...) where {T<:Union{Ket, Bra}} = any(f, x.data; kwargs...) # ODE nan checks
252+
Base.all(f::Function, x::T; kwargs...) where {T<:Union{Ket, Bra}} = all(f, x.data; kwargs...)
253+
Base.fill!(x::T, a) where {T<:Union{Ket, Bra}} = typeof(x)(x.basis, fill!(x.data, a))
254+
Base.similar(x::T, t) where {T<:Union{Ket, Bra}} = typeof(x)(x.basis, similar(x.data))
255+
RecursiveArrayTools.recursivecopy!(dest::Ket{B,A},src::Ket{B,A}) where {B,A} = copyto!(dest, src) # ODE in-place equations
256+
RecursiveArrayTools.recursivecopy!(dest::Bra{B,A},src::Bra{B,A}) where {B,A} = copyto!(dest, src)
257+
RecursiveArrayTools.recursivecopy(x::T) where {T<:Union{Ket, Bra}} = copy(x)
258+
RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Union{Ket, Bra}} = copy(x)
259+
RecursiveArrayTools.recursivefill!(x::T, a) where {T<:Union{Ket, Bra}} = fill!(x, a)

src/superoperators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ end
287287
# end
288288
find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r)
289289

290-
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}
290+
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)}
291291
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:SuperOperator}}, axes)
292292
args_ = Tuple(a.data for a=args)
293293
return Broadcast.Broadcasted(f, args_, axes)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
77
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
88
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1011
QuantumInterface = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5"
12+
QuantumOptics = "6e0679c1-51ea-5a7c-ac74-d61b76210b0c"
1113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1214
RandomMatrices = "2576dda1-a324-5b11-aa66-c48ed7e3c618"
1315
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ names = [
2222
"test_subspace.jl",
2323
"test_state_definitions.jl",
2424

25+
"test_sciml_broadcast_interfaces.jl",
26+
2527
"test_transformations.jl",
2628

2729
"test_metrics.jl",

test/test_abstractdata.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@ op1 .= op1_ .+ 3 * op1_
340340
bf = FockBasis(3)
341341
op3 = randtestoperator(bf)
342342
@test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3
343-
@test_throws ErrorException cos.(op1)
344343

345344
####################
346345
# Test lazy tensor #

test/test_jet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ using LinearAlgebra, LRUCache, Strided, StridedViews, Dates, SparseArrays, Rando
3535
AnyFrameModule(RandomMatrices))
3636
)
3737
@show rep
38-
@test length(JET.get_reports(rep)) <= 24
38+
@test length(JET.get_reports(rep)) <= 28
3939
@test_broken length(JET.get_reports(rep)) == 0
4040
end
4141
end # testset

0 commit comments

Comments
 (0)