11module semiclassical
22
33using QuantumOpticsBase
4- import Base: ==
4+ import QuantumOpticsBase: IncompatibleBases
5+ import Base: == , isapprox, + , - , * , /
56import .. timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback,
67 JumpRNGState, threshold, roll!, as_vector, QO_CHECKS
78import LinearAlgebra: normalize, normalize!
9+ import RecursiveArrayTools
810
911using Random, LinearAlgebra
1012import OrdinaryDiffEq
@@ -31,26 +33,104 @@ mutable struct State{B,T,C}
3133 new {B,T,C} (quantum, classical)
3234 end
3335end
34-
35- Base. length (state:: State ) = length (state. quantum) + length (state. classical)
36- Base. copy (state:: State ) = State (copy (state. quantum), copy (state. classical))
37- Base. eltype (state:: State ) = promote_type (eltype (state. quantum),eltype (state. classical))
38- normalize! (state:: State ) = (normalize! (state. quantum); state)
39- normalize (state:: State ) = State (normalize (state. quantum),copy (state. classical))
40-
41- function == (a:: State , b:: State )
42- QuantumOpticsBase. samebases (a. quantum, b. quantum) &&
43- length (a. classical)== length (b. classical) &&
44- (a. classical== b. classical) &&
45- (a. quantum== b. quantum)
46- end
36+ State {B} (q:: T , c:: C ) where {B,T<: QuantumState{B} ,C} = State (q,c)
37+
38+ # Standard interfaces
39+ Base. zero (x:: State ) = State (zero (x. quantum), zero (x. classical))
40+ Base. length (x:: State ) = length (x. quantum) + length (x. classical)
41+ Base. axes (x:: State ) = (Base. OneTo (length (x)),)
42+ Base. size (x:: State ) = size (x. quantum)
43+ Base. ndims (x:: Type{<:State{B,T,C}} ) where {B,T<: QuantumState{B} ,C} = ndims (T)
44+ Base. copy (x:: State ) = State (copy (x. quantum), copy (x. classical))
45+ Base. copyto! (x:: State , y:: State ) = (copyto! (x. quantum, y. quantum); copyto! (x. classical, y. classical); x)
46+ Base. fill! (x:: State , a) = (fill! (x. quantum, a), fill! (x. classical, a))
47+ Base. eltype (x:: State ) = promote_type (eltype (x. quantum),eltype (x. classical))
48+ Base. eltype (x:: Type{<:State{B,T,C}} ) where {B,T<: QuantumState{B} ,C} = promote_type (eltype (T), eltype (C))
49+ Base. similar (x:: State , :: Type{T} = eltype (x)) where {T} = State (similar (x. quantum, T), similar (x. classical, T))
50+ Base. getindex (x:: State , idx) = idx <= length (x. quantum) ? getindex (x. quantum, idx) : getindex (x. classical, idx- length (x. quantum))
51+
52+ normalize! (x:: State ) = (normalize! (x. quantum); x)
53+ normalize (x:: State ) = State (normalize (x. quantum),copy (x. classical))
54+ LinearAlgebra. norm (x:: State ) = LinearAlgebra. norm (x. quantum)
55+
56+ == (x:: State{B} , y:: State{B} ) where {B} = (x. classical== y. classical) && (x. quantum== y. quantum)
57+ == (x:: State , y:: State ) = false
58+
59+ isapprox (x:: State{B} , y:: State{B} ; kwargs... ) where {B} = isapprox (x. quantum,y. quantum; kwargs... ) && isapprox (x. classical,y. classical; kwargs... )
60+ isapprox (x:: State , y:: State ; kwargs... ) = false
4761
4862QuantumOpticsBase. expect (op, state:: State ) = expect (op, state. quantum)
4963QuantumOpticsBase. variance (op, state:: State ) = variance (op, state. quantum)
5064QuantumOpticsBase. ptrace (state:: State , indices) = State (ptrace (state. quantum, indices), state. classical)
51-
5265QuantumOpticsBase. dm (x:: State ) = State (dm (x. quantum), x. classical)
5366
67+ Base. broadcastable (x:: State ) = x
68+
69+ # Custom broadcasting style
70+ struct StateStyle{B} <: Broadcast.BroadcastStyle end
71+
72+ # Style precedence rules
73+ Broadcast. BroadcastStyle (:: Type{<:State{B}} ) where {B} = StateStyle {B} ()
74+ Broadcast. BroadcastStyle (:: StateStyle{B1} , :: StateStyle{B2} ) where {B1,B2} = throw (IncompatibleBases ())
75+ Broadcast. BroadcastStyle (:: StateStyle{B} , :: Broadcast.DefaultArrayStyle{0} ) where {B} = StateStyle {B} ()
76+ Broadcast. BroadcastStyle (:: Broadcast.DefaultArrayStyle{0} , :: StateStyle{B} ) where {B} = StateStyle {B} ()
77+
78+ # Out-of-place broadcasting
79+ @inline function Base. copy (bc:: Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args} ) where {B,Axes,F,Args<: Tuple }
80+ bcf = Broadcast. flatten (bc)
81+ # extract quantum object from broadcast container
82+ qobj = find_quantum (bcf)
83+ data_q = zeros (eltype (qobj), size (qobj)... )
84+ Nq = length (qobj)
85+ # allocate quantum data from broadcast container
86+ @inbounds @simd for I in 1 : Nq
87+ data_q[I] = bcf[I]
88+ end
89+ # extract classical object from broadcast container
90+ cobj = find_classical (bcf)
91+ data_c = zeros (eltype (cobj), length (cobj))
92+ Nc = length (cobj)
93+ # allocate classical data from broadcast container
94+ @inbounds @simd for I in 1 : Nc
95+ data_c[I] = bcf[I+ Nq]
96+ end
97+ type = eval (nameof (typeof (qobj)))
98+ return State {B} (type (basis (qobj), data_q), data_c)
99+ end
100+
101+ for f ∈ [:find_quantum , :find_classical ]
102+ @eval ($ f)(bc:: Broadcast.Broadcasted ) = ($ f)(bc. args)
103+ @eval ($ f)(args:: Tuple ) = ($ f)(($ f)(args[1 ]), Base. tail (args))
104+ @eval ($ f)(x) = x
105+ @eval ($ f)(:: Any , rest) = ($ f)(rest)
106+ end
107+ find_quantum (x:: State , rest) = x. quantum
108+ find_classical (x:: State , rest) = x. classical
109+
110+ # In-place broadcasting
111+ @inline function Base. copyto! (dest:: State{B} , bc:: Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args} ) where {B,Axes,F,Args<: Tuple }
112+ axes (dest) == axes (bc) || throwdm (axes (dest), axes (bc))
113+ bc′ = Base. Broadcast. preprocess (dest, bc)
114+ # write broadcasted quantum data to dest
115+ qobj = dest. quantum
116+ @inbounds @simd for I in 1 : length (qobj)
117+ qobj. data[I] = bc′[I]
118+ end
119+ # write broadcasted classical data to dest
120+ cobj = dest. classical
121+ @inbounds @simd for I in 1 : length (cobj)
122+ cobj[I] = bc′[I+ length (qobj)]
123+ end
124+ return dest
125+ end
126+ @inline Base. copyto! (dest:: State{B1} , bc:: Broadcast.Broadcasted{<:StateStyle{B2},Axes,F,Args} ) where {B1,B2,Axes,F,Args<: Tuple } =
127+ throw (IncompatibleBases ())
128+
129+ Base. @propagate_inbounds Base. Broadcast. _broadcast_getindex (x:: State , i) = Base. getindex (x, i)
130+ RecursiveArrayTools. recursive_unitless_bottom_eltype (x:: State ) = eltype (x)
131+ RecursiveArrayTools. recursivecopy! (dest:: State , src:: State ) = copyto! (dest, src)
132+ RecursiveArrayTools. recursivecopy (x:: State ) = copy (x)
133+ RecursiveArrayTools. recursivefill! (x:: State , a) = fill! (x, a)
54134
55135"""
56136 semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...])
0 commit comments