Skip to content

Commit ddc98fc

Browse files
Improve random number generation on mcsolve and ssesolve (#263)
1 parent dca631a commit ddc98fc

File tree

5 files changed

+101
-30
lines changed

5 files changed

+101
-30
lines changed

src/QuantumToolbox.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import FFTW: fft, fftshift
4747
import Graphs: connected_components, DiGraph
4848
import IncompleteLU: ilu
4949
import Pkg
50-
import Random
50+
import Random: AbstractRNG, default_rng, seed!
5151
import SpecialFunctions: loggamma
5252
import StaticArraysCore: MVector
5353

src/time_evolution/mcsolve.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,20 @@ function LindbladJumpAffect!(integrator)
2929
random_n = internal_params.random_n
3030
jump_times = internal_params.jump_times
3131
jump_which = internal_params.jump_which
32+
traj_rng = internal_params.traj_rng
3233
ψ = integrator.u
3334

3435
@inbounds for i in eachindex(weights_mc)
3536
mul!(cache_mc, c_ops[i], ψ)
3637
weights_mc[i] = real(dot(cache_mc, cache_mc))
3738
end
3839
cumsum!(cumsum_weights_mc, weights_mc)
39-
collaps_idx = getindex(1:length(weights_mc), findfirst(>(rand() * sum(weights_mc)), cumsum_weights_mc))
40+
collaps_idx = getindex(1:length(weights_mc), findfirst(>(rand(traj_rng) * sum(weights_mc)), cumsum_weights_mc))
4041
mul!(cache_mc, c_ops[collaps_idx], ψ)
4142
normalize!(cache_mc)
4243
copyto!(integrator.u, cache_mc)
4344

44-
# push!(jump_times, integrator.t)
45-
# push!(jump_which, collaps_idx)
46-
random_n[] = rand()
45+
random_n[] = rand(traj_rng)
4746
jump_times[internal_params.jump_times_which_idx[]] = integrator.t
4847
jump_which[internal_params.jump_times_which_idx[]] = collaps_idx
4948
internal_params.jump_times_which_idx[] += 1
@@ -59,8 +58,11 @@ LindbladJumpDiscreteCondition(u, t, integrator) = real(dot(u, u)) < integrator.p
5958

6059
function _mcsolve_prob_func(prob, i, repeat)
6160
internal_params = prob.p
62-
seeds = internal_params.seeds
63-
!isnothing(seeds) && Random.seed!(seeds[i])
61+
62+
global_rng = internal_params.global_rng
63+
seed = internal_params.seeds[i]
64+
traj_rng = typeof(global_rng)()
65+
seed!(traj_rng, seed)
6466

6567
prm = merge(
6668
internal_params,
@@ -69,7 +71,8 @@ function _mcsolve_prob_func(prob, i, repeat)
6971
cache_mc = similar(internal_params.cache_mc),
7072
weights_mc = similar(internal_params.weights_mc),
7173
cumsum_weights_mc = similar(internal_params.weights_mc),
72-
random_n = Ref(rand()),
74+
traj_rng = traj_rng,
75+
random_n = Ref(rand(traj_rng)),
7376
progr_mc = ProgressBar(size(internal_params.expvals, 2), enable = false),
7477
jump_times_which_idx = Ref(1),
7578
jump_times = similar(internal_params.jump_times),
@@ -122,6 +125,7 @@ end
122125
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
123126
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
124127
params::NamedTuple=NamedTuple(),
128+
rng::AbstractRNG=default_rng(),
125129
jump_callback::TJC=ContinuousLindbladJumpCallback(),
126130
kwargs...)
127131
@@ -169,7 +173,7 @@ If the environmental measurements register a quantum jump, the wave function und
169173
- `e_ops::Union{Nothing,AbstractVector,Tuple}`: List of operators for which to calculate expectation values.
170174
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
171175
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
172-
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
176+
- `rng::AbstractRNG`: Random number generator for reproducibility.
173177
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
174178
- `kwargs...`: Additional keyword arguments to pass to the solver.
175179
@@ -194,7 +198,7 @@ function mcsolveProblem(
194198
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
195199
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
196200
params::NamedTuple = NamedTuple(),
197-
seeds::Union{Nothing,Vector{Int}} = nothing,
201+
rng::AbstractRNG = default_rng(),
198202
jump_callback::TJC = ContinuousLindbladJumpCallback(),
199203
kwargs...,
200204
) where {MT1<:AbstractMatrix,TJC<:LindbladJumpCallbackType}
@@ -238,8 +242,7 @@ function mcsolveProblem(
238242
e_ops_mc = e_ops2,
239243
is_empty_e_ops_mc = is_empty_e_ops_mc,
240244
progr_mc = ProgressBar(length(t_l), enable = false),
241-
seeds = seeds,
242-
random_n = Ref(rand()),
245+
traj_rng = rng,
243246
c_ops = get_data.(c_ops),
244247
cache_mc = cache_mc,
245248
weights_mc = weights_mc,
@@ -361,7 +364,7 @@ If the environmental measurements register a quantum jump, the wave function und
361364
- `e_ops::Union{Nothing,AbstractVector,Tuple}`: List of operators for which to calculate expectation values.
362365
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
363366
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
364-
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
367+
- `rng::AbstractRNG`: Random number generator for reproducibility.
365368
- `ntraj::Int`: Number of trajectories to use.
366369
- `ensemble_method`: Ensemble method to use.
367370
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
@@ -391,10 +394,10 @@ function mcsolveEnsembleProblem(
391394
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
392395
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
393396
params::NamedTuple = NamedTuple(),
397+
rng::AbstractRNG = default_rng(),
394398
ntraj::Int = 1,
395399
ensemble_method = EnsembleThreads(),
396400
jump_callback::TJC = ContinuousLindbladJumpCallback(),
397-
seeds::Union{Nothing,Vector{Int}} = nothing,
398401
prob_func::Function = _mcsolve_prob_func,
399402
output_func::Function = _mcsolve_dispatch_output_func(ensemble_method),
400403
progress_bar::Union{Val,Bool} = Val(true),
@@ -413,6 +416,7 @@ function mcsolveEnsembleProblem(
413416

414417
# Stop the async task if an error occurs
415418
try
419+
seeds = map(i -> rand(rng, UInt64), 1:ntraj)
416420
prob_mc = mcsolveProblem(
417421
H,
418422
ψ0,
@@ -421,8 +425,8 @@ function mcsolveEnsembleProblem(
421425
alg = alg,
422426
e_ops = e_ops,
423427
H_t = H_t,
424-
params = params,
425-
seeds = seeds,
428+
params = merge(params, (global_rng = rng, seeds = seeds)),
429+
rng = rng,
426430
jump_callback = jump_callback,
427431
kwargs...,
428432
)
@@ -447,7 +451,7 @@ end
447451
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
448452
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
449453
params::NamedTuple = NamedTuple(),
450-
seeds::Union{Nothing,Vector{Int}} = nothing,
454+
rng::AbstractRNG = default_rng(),
451455
ntraj::Int = 1,
452456
ensemble_method = EnsembleThreads(),
453457
jump_callback::TJC = ContinuousLindbladJumpCallback(),
@@ -501,7 +505,7 @@ If the environmental measurements register a quantum jump, the wave function und
501505
- `e_ops::Union{Nothing,AbstractVector,Tuple}`: List of operators for which to calculate expectation values.
502506
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
503507
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
504-
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
508+
- `rng::AbstractRNG`: Random number generator for reproducibility.
505509
- `ntraj::Int`: Number of trajectories to use.
506510
- `ensemble_method`: Ensemble method to use.
507511
- `jump_callback::LindbladJumpCallbackType`: The Jump Callback type: Discrete or Continuous.
@@ -532,7 +536,7 @@ function mcsolve(
532536
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
533537
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
534538
params::NamedTuple = NamedTuple(),
535-
seeds::Union{Nothing,Vector{Int}} = nothing,
539+
rng::AbstractRNG = default_rng(),
536540
ntraj::Int = 1,
537541
ensemble_method = EnsembleThreads(),
538542
jump_callback::TJC = ContinuousLindbladJumpCallback(),
@@ -541,10 +545,6 @@ function mcsolve(
541545
progress_bar::Union{Val,Bool} = Val(true),
542546
kwargs...,
543547
) where {MT1<:AbstractMatrix,T2,TJC<:LindbladJumpCallbackType}
544-
if !isnothing(seeds) && length(seeds) != ntraj
545-
throw(ArgumentError("Length of seeds must match ntraj ($ntraj), but got $(length(seeds))"))
546-
end
547-
548548
ens_prob_mc = mcsolveEnsembleProblem(
549549
H,
550550
ψ0,
@@ -554,7 +554,7 @@ function mcsolve(
554554
e_ops = e_ops,
555555
H_t = H_t,
556556
params = params,
557-
seeds = seeds,
557+
rng = rng,
558558
ntraj = ntraj,
559559
ensemble_method = ensemble_method,
560560
jump_callback = jump_callback,

src/time_evolution/ssesolve.jl

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,17 @@ end
3232
function _ssesolve_prob_func(prob, i, repeat)
3333
internal_params = prob.p
3434

35+
global_rng = internal_params.global_rng
36+
seed = internal_params.seeds[i]
37+
traj_rng = typeof(global_rng)()
38+
seed!(traj_rng, seed)
39+
3540
noise = RealWienerProcess(
3641
prob.tspan[1],
3742
zeros(length(internal_params.sc_ops)),
3843
zeros(length(internal_params.sc_ops)),
3944
save_everystep = false,
45+
rng = traj_rng,
4046
)
4147

4248
noise_rate_prototype = similar(prob.u0, length(prob.u0), length(internal_params.sc_ops))
@@ -49,7 +55,7 @@ function _ssesolve_prob_func(prob, i, repeat)
4955
),
5056
)
5157

52-
return remake(prob, p = prm, noise = noise, noise_rate_prototype = noise_rate_prototype)
58+
return remake(prob, p = prm, noise = noise, noise_rate_prototype = noise_rate_prototype, seed = seed)
5359
end
5460

5561
# Standard output function
@@ -89,6 +95,7 @@ end
8995
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
9096
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
9197
params::NamedTuple=NamedTuple(),
98+
rng::AbstractRNG=default_rng(),
9299
kwargs...)
93100
94101
Generates the SDEProblem for the Stochastic Schrödinger time evolution of a quantum system. This is defined by the following stochastic differential equation:
@@ -122,6 +129,7 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
122129
- `e_ops::Union{Nothing,AbstractVector,Tuple}=nothing`: The list of operators to be evaluated during the evolution.
123130
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent.
124131
- `params::NamedTuple`: The parameters of the system.
132+
- `rng::AbstractRNG`: The random number generator for reproducibility.
125133
- `kwargs...`: The keyword arguments passed to the `SDEProblem` constructor.
126134
127135
# Notes
@@ -145,6 +153,7 @@ function ssesolveProblem(
145153
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
146154
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
147155
params::NamedTuple = NamedTuple(),
156+
rng::AbstractRNG = default_rng(),
148157
kwargs...,
149158
) where {MT1<:AbstractMatrix,T2}
150159
H.dims != ψ0.dims && throw(DimensionMismatch("The two quantum objects are not of the same Hilbert dimension."))
@@ -200,7 +209,7 @@ function ssesolveProblem(
200209
kwargs3 = _generate_sesolve_kwargs(e_ops, Val(false), t_l, kwargs2)
201210

202211
tspan = (t_l[1], t_l[end])
203-
noise = RealWienerProcess(t_l[1], zeros(length(sc_ops)), zeros(length(sc_ops)), save_everystep = false)
212+
noise = RealWienerProcess(t_l[1], zeros(length(sc_ops)), zeros(length(sc_ops)), save_everystep = false, rng = rng)
204213
noise_rate_prototype = similar(ϕ0, length(ϕ0), length(sc_ops))
205214
return SDEProblem{true}(
206215
ssesolve_drift!,
@@ -223,6 +232,7 @@ end
223232
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
224233
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
225234
params::NamedTuple=NamedTuple(),
235+
rng::AbstractRNG=default_rng(),
226236
ntraj::Int=1,
227237
ensemble_method=EnsembleThreads(),
228238
prob_func::Function=_mcsolve_prob_func,
@@ -261,6 +271,7 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
261271
- `e_ops::Union{Nothing,AbstractVector,Tuple}=nothing`: The list of operators to be evaluated during the evolution.
262272
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: The time-dependent Hamiltonian of the system. If `nothing`, the Hamiltonian is time-independent.
263273
- `params::NamedTuple`: The parameters of the system.
274+
- `rng::AbstractRNG`: The random number generator for reproducibility.
264275
- `ntraj::Int`: Number of trajectories to use.
265276
- `ensemble_method`: Ensemble method to use.
266277
- `prob_func::Function`: Function to use for generating the SDEProblem.
@@ -289,6 +300,7 @@ function ssesolveEnsembleProblem(
289300
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
290301
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
291302
params::NamedTuple = NamedTuple(),
303+
rng::AbstractRNG = default_rng(),
292304
ntraj::Int = 1,
293305
ensemble_method = EnsembleThreads(),
294306
prob_func::Function = _ssesolve_prob_func,
@@ -309,10 +321,21 @@ function ssesolveEnsembleProblem(
309321

310322
# Stop the async task if an error occurs
311323
try
312-
prob_sse =
313-
ssesolveProblem(H, ψ0, tlist, sc_ops; alg = alg, e_ops = e_ops, H_t = H_t, params = params, kwargs...)
324+
seeds = map(i -> rand(rng, UInt64), 1:ntraj)
325+
prob_sse = ssesolveProblem(
326+
H,
327+
ψ0,
328+
tlist,
329+
sc_ops;
330+
alg = alg,
331+
e_ops = e_ops,
332+
H_t = H_t,
333+
params = merge(params, (global_rng = rng, seeds = seeds)),
334+
rng = rng,
335+
kwargs...,
336+
)
314337

315-
ensemble_prob = EnsembleProblem(prob_sse, prob_func = prob_func, output_func = output_func, safetycopy = false)
338+
ensemble_prob = EnsembleProblem(prob_sse, prob_func = prob_func, output_func = output_func, safetycopy = true)
316339

317340
return ensemble_prob
318341
catch e
@@ -332,6 +355,7 @@ end
332355
e_ops::Union{Nothing,AbstractVector,Tuple}=nothing,
333356
H_t::Union{Nothing,Function,TimeDependentOperatorSum}=nothing,
334357
params::NamedTuple=NamedTuple(),
358+
rng::AbstractRNG=default_rng(),
335359
ntraj::Int=1,
336360
ensemble_method=EnsembleThreads(),
337361
prob_func::Function=_ssesolve_prob_func,
@@ -373,7 +397,7 @@ Above, `C_n` is the `n`-th collapse operator and `dW_j(t)` is the real Wiener i
373397
- `e_ops::Union{Nothing,AbstractVector,Tuple}`: List of operators for which to calculate expectation values.
374398
- `H_t::Union{Nothing,Function,TimeDependentOperatorSum}`: Time-dependent part of the Hamiltonian.
375399
- `params::NamedTuple`: Dictionary of parameters to pass to the solver.
376-
- `seeds::Union{Nothing, Vector{Int}}`: List of seeds for the random number generator. Length must be equal to the number of trajectories provided.
400+
- `rng::AbstractRNG`: Random number generator for reproducibility.
377401
- `ntraj::Int`: Number of trajectories to use.
378402
- `ensemble_method`: Ensemble method to use.
379403
- `prob_func::Function`: Function to use for generating the SDEProblem.
@@ -403,6 +427,7 @@ function ssesolve(
403427
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
404428
H_t::Union{Nothing,Function,TimeDependentOperatorSum} = nothing,
405429
params::NamedTuple = NamedTuple(),
430+
rng::AbstractRNG = default_rng(),
406431
ntraj::Int = 1,
407432
ensemble_method = EnsembleThreads(),
408433
prob_func::Function = _ssesolve_prob_func,
@@ -425,6 +450,7 @@ function ssesolve(
425450
e_ops = e_ops,
426451
H_t = H_t,
427452
params = params,
453+
rng = rng,
428454
ntraj = ntraj,
429455
ensemble_method = ensemble_method,
430456
prob_func = prob_func,

test/core-test/time_evolution.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,50 @@
149149
@inferred ssesolve(H, psi0, t_l, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false))
150150
@inferred ssesolve(H, psi0, t_l, c_ops, ntraj = 500, progress_bar = Val(true))
151151
end
152+
153+
@testset "mcsolve and ssesolve reproducibility" begin
154+
N = 10
155+
a = tensor(destroy(N), qeye(2))
156+
σm = tensor(qeye(N), sigmam())
157+
σp = σm'
158+
σz = tensor(qeye(N), sigmaz())
159+
160+
ω = 1.0
161+
g = 0.1
162+
γ = 0.01
163+
nth = 0.1
164+
165+
H = ω * a' * a + ω * σz / 2 + g * (a' * σm + a * σp)
166+
c_ops = [sqrt* (1 + nth)) * a, sqrt* nth) * a', sqrt* (1 + nth)) * σm, sqrt* nth) * σp]
167+
e_ops = [a' * a, σz]
168+
169+
psi0 = tensor(basis(N, 0), basis(2, 0))
170+
tlist = range(0, 20 / γ, 1000)
171+
172+
rng = MersenneTwister(1234)
173+
sleep(0.1) # If we don't sleep, we get an error (why?)
174+
sol_mc1 = mcsolve(H, psi0, tlist, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false), rng = rng)
175+
sol_sse1 = ssesolve(H, psi0, tlist, c_ops, ntraj = 50, e_ops = e_ops, progress_bar = Val(false), rng = rng)
176+
177+
rng = MersenneTwister(1234)
178+
sleep(0.1)
179+
sol_mc2 = mcsolve(H, psi0, tlist, c_ops, ntraj = 500, e_ops = e_ops, progress_bar = Val(false), rng = rng)
180+
sol_sse2 = ssesolve(H, psi0, tlist, c_ops, ntraj = 50, e_ops = e_ops, progress_bar = Val(false), rng = rng)
181+
182+
rng = MersenneTwister(1234)
183+
sleep(0.1)
184+
sol_mc3 = mcsolve(H, psi0, tlist, c_ops, ntraj = 510, e_ops = e_ops, progress_bar = Val(false), rng = rng)
185+
186+
@test sol_mc1.expect sol_mc2.expect atol = 1e-10
187+
@test sol_mc1.expect_all sol_mc2.expect_all atol = 1e-10
188+
@test sol_mc1.jump_times sol_mc2.jump_times atol = 1e-10
189+
@test sol_mc1.jump_which sol_mc2.jump_which atol = 1e-10
190+
191+
@test sol_mc1.expect_all sol_mc3.expect_all[1:500, :, :] atol = 1e-10
192+
193+
@test sol_sse1.expect sol_sse2.expect atol = 1e-10
194+
@test sol_sse1.expect_all sol_sse2.expect_all atol = 1e-10
195+
end
152196
end
153197

154198
@testset "exceptions" begin

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Test
22
using Pkg
33
using QuantumToolbox
44
using QuantumToolbox: position, momentum
5+
using Random
56

67
const GROUP = get(ENV, "GROUP", "All")
78

0 commit comments

Comments
 (0)