Skip to content

Commit 25474bf

Browse files
authored
Add broadcasting support for obj .= scalar (#202)
* add `copyto!` methods for `DefaultArrayStyle{0}` * add simple `.= 1.0` tests for ops and states * have fallback error for obscure broadcasting
1 parent c9c24f9 commit 25474bf

File tree

4 files changed

+42
-0
lines changed

4 files changed

+42
-0
lines changed

src/operators_dense.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,17 @@ Base.iterate(a::DataOperator, idx) = iterate(a.data, idx)
455455
end
456456
return dest
457457
end
458+
# same implementation as Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) in https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl
459+
@inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:Broadcast.DefaultArrayStyle{0},Axes,F,Args}
460+
# Typically, we must independently execute bc for every storage location in `dest`, but:
461+
# IF we're in the common no-op identity case with no nested args (like `dest .= val`),
462+
if bc.f === identity && bc.args isa Tuple{Any} && Broadcast.isflat(bc)
463+
# THEN we can just extract the argument and `fill!` the destination with it
464+
return fill!(dest, bc.args[1][])
465+
else
466+
throw(ArgumentError("no fallback implementation has been defined outside of dest .= val."))
467+
end
468+
end
458469
@inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL,BR} = (copyto!(A.data,B.data); A)
459470
@inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:DataOperatorStyle,Axes,F,Args} =
460471
throw(IncompatibleBases())

src/states.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,17 @@ Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::T, i) where {T<:U
227227
end
228228
return dest
229229
end
230+
# same implementation as Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) in https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl
231+
@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:Broadcast.DefaultArrayStyle{0},Axes,F,Args}
232+
# Typically, we must independently execute bc for every storage location in `dest`, but:
233+
# IF we're in the common no-op identity case with no nested args (like `dest .= val`),
234+
if bc.f === identity && bc.args isa Tuple{Any} && Broadcast.isflat(bc)
235+
# THEN we can just extract the argument and `fill!` the destination with it
236+
return fill!(dest, bc.args[1][])
237+
else
238+
throw(ArgumentError("no fallback implementation has been defined outside of dest .= val."))
239+
end
240+
end
230241
@inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:KetStyle{B2},Axes,F,Args} =
231242
throw(IncompatibleBases())
232243

@@ -240,6 +251,17 @@ end
240251
end
241252
return dest
242253
end
254+
# same implementation as Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) in https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl
255+
@inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:Broadcast.DefaultArrayStyle{0},Axes,F,Args}
256+
# Typically, we must independently execute bc for every storage location in `dest`, but:
257+
# IF we're in the common no-op identity case with no nested args (like `dest .= val`),
258+
if bc.f === identity && bc.args isa Tuple{Any} && Broadcast.isflat(bc)
259+
# THEN we can just extract the argument and `fill!` the destination with it
260+
return fill!(dest, bc.args[1][])
261+
else
262+
throw(ArgumentError("no fallback implementation has been defined outside of dest .= val."))
263+
end
264+
end
243265
@inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:BraStyle{B2},Axes,F,Args} =
244266
throw(IncompatibleBases())
245267

test/test_operators_dense.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ op1 .= op1_ .+ 3 * op1_
382382
bf = FockBasis(3)
383383
op3 = randoperator(bf)
384384
@test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3
385+
op .= 1.0
386+
@test op == fill!(zero(op), 1.0)
387+
@test_throws ArgumentError op .= 1.0 .+ 1.0
385388

386389
# Dimension mismatches
387390
b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter

test/test_states.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ psi_ .+= psi123
166166
bra_ = copy(bra123)
167167
bra_ .= 3*bra123
168168
@test bra_ == 3*dagger(psi123)
169+
bra .= 1.0
170+
ket .= 1.0
171+
@test bra == Bra(bf, [1.0 + 0.0im, 1.0 + 0.0im])
172+
@test ket == Ket(bf, [1.0 + 0.0im, 1.0 + 0.0im])
173+
@test_throws ArgumentError bra .= 1.0 .+ 1.0
174+
@test_throws ArgumentError ket .= 1.0 .+ 1.0
169175

170176
end # testset
171177

0 commit comments

Comments
 (0)