Skip to content

Commit d789c87

Browse files
committed
add field rng for multi-trajectory solutions
1 parent 4cbc469 commit d789c87

File tree

5 files changed

+24
-17
lines changed

5 files changed

+24
-17
lines changed

src/time_evolution/mcsolve.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ function mcsolveEnsembleProblem(
256256
EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
257257
prob_mc.times,
258258
prob_mc.dimensions,
259-
(progr = _output_func[2], channel = _output_func[3]),
259+
(progr = _output_func[2], channel = _output_func[3], rng = rng),
260260
)
261261

262262
return ensemble_prob
@@ -412,11 +412,12 @@ function mcsolve(
412412
ens_prob_mc.times,
413413
states,
414414
expvals_all,
415+
ens_prob_mc.kwargs.rng,
415416
col_times,
416417
col_which,
417418
sol.converged,
418419
_sol_1.alg,
419-
NamedTuple(_sol_1.prob.kwargs).abstol,
420-
NamedTuple(_sol_1.prob.kwargs).reltol,
420+
_sol_1.prob.kwargs[:abstol],
421+
_sol_1.prob.kwargs[:reltol],
421422
)
422423
end

src/time_evolution/smesolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ function smesolveEnsembleProblem(
271271
EnsembleProblem(prob_sme, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true),
272272
prob_sme.times,
273273
prob_sme.dimensions,
274-
merge(prob_sme.kwargs, (progr = _output_func[2], channel = _output_func[3])),
274+
merge(prob_sme.kwargs, (progr = _output_func[2], channel = _output_func[3], rng = rng)),
275275
)
276276

277277
return ensemble_prob
@@ -423,6 +423,7 @@ function smesolve(
423423
ens_prob.times,
424424
states,
425425
expvals_all,
426+
ens_prob.kwargs.rng,
426427
m_expvals, # Measurement expectation values
427428
sol.converged,
428429
_sol_1.alg,

src/time_evolution/ssesolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ function ssesolveEnsembleProblem(
264264
EnsembleProblem(prob_sme, prob_func = _prob_func, output_func = _output_func[1], safetycopy = true),
265265
prob_sme.times,
266266
prob_sme.dimensions,
267-
(progr = _output_func[2], channel = _output_func[3]),
267+
(progr = _output_func[2], channel = _output_func[3], rng = rng),
268268
)
269269

270270
return ensemble_prob
@@ -417,6 +417,7 @@ function ssesolve(
417417
ens_prob.times,
418418
states,
419419
expvals_all,
420+
ens_prob.kwargs.rng,
420421
m_expvals, # Measurement expectation values
421422
sol.converged,
422423
_sol_1.alg,

src/time_evolution/time_evolution.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ A structure storing the results and some information from solving quantum trajec
103103
- `times::AbstractVector`: The time list of the evolution.
104104
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory and each time point in `times`.
105105
- `expect::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`.
106+
- `rng::AbstractRNG`: Random number generator for reproducibility.
106107
- `col_times::Vector{Vector{Real}}`: The time records of every quantum jump occurred in each trajectory.
107108
- `col_which::Vector{Vector{Int}}`: The indices of which collapse operator was responsible for each quantum jump in `col_times`.
108109
- `converged::Bool`: Whether the solution is converged or not.
@@ -127,6 +128,7 @@ struct TimeEvolutionMCSol{
127128
TT<:AbstractVector{<:Real},
128129
TS<:AbstractVector,
129130
TE<:Union{AbstractArray,Nothing},
131+
TR<:AbstractRNG,
130132
TJT<:Vector{<:Vector{<:Real}},
131133
TJW<:Vector{<:Vector{<:Integer}},
132134
AlgT<:OrdinaryDiffEqAlgorithm,
@@ -137,6 +139,7 @@ struct TimeEvolutionMCSol{
137139
times::TT
138140
states::TS
139141
expect::TE
142+
rng::TR
140143
col_times::TJT
141144
col_which::TJW
142145
converged::Bool
@@ -173,6 +176,8 @@ A structure storing the results and some information from solving trajectories o
173176
- `times::AbstractVector`: The time list of the evolution.
174177
- `states::Vector{Vector{QuantumObject}}`: The list of result states in each trajectory and each time point in `times`.
175178
- `expect::Union{AbstractArray,Nothing}`: The expectation values corresponding to each trajectory and each time point in `times`.
179+
- `rng::AbstractRNG`: Random number generator for reproducibility.
180+
- `measurement::Union{AbstractArray,Nothing}`: Measurements for each trajectories and stochastic collapse operators (`sc_ops`).
176181
- `converged::Bool`: Whether the solution is converged or not.
177182
- `alg`: The algorithm which is used during the solving process.
178183
- `abstol::Real`: The absolute tolerance which is used during the solving process.
@@ -195,6 +200,7 @@ struct TimeEvolutionStochasticSol{
195200
TT<:AbstractVector{<:Real},
196201
TS<:AbstractVector,
197202
TE<:Union{AbstractArray,Nothing},
203+
TR<:AbstractRNG,
198204
TEM<:Union{AbstractArray,Nothing},
199205
AlgT<:StochasticDiffEqAlgorithm,
200206
AT<:Real,
@@ -204,6 +210,7 @@ struct TimeEvolutionStochasticSol{
204210
times::TT
205211
states::TS
206212
expect::TE
213+
rng::TR
207214
measurement::TEM
208215
converged::Bool
209216
alg::AlgT

test/core-test/time_evolution.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -751,11 +751,9 @@
751751
rng = rng,
752752
)
753753

754-
rng = MersenneTwister(1234)
755-
sol_mc2 = mcsolve(H, ψ0, tlist, c_ops, e_ops = e_ops, progress_bar = Val(false), rng = rng)
756-
rng = MersenneTwister(1234)
757-
sol_sse2 = ssesolve(H, ψ0, tlist, c_ops, ntraj = 50, e_ops = e_ops, progress_bar = Val(false), rng = rng)
758-
rng = MersenneTwister(1234)
754+
sol_mc2 = mcsolve(H, ψ0, tlist, c_ops, e_ops = e_ops, progress_bar = Val(false), rng = sol_mc1.rng)
755+
sol_sse2 =
756+
ssesolve(H, ψ0, tlist, c_ops, ntraj = 50, e_ops = e_ops, progress_bar = Val(false), rng = sol_sse1.rng)
759757
sol_sme2 = smesolve(
760758
H,
761759
ψ0,
@@ -765,14 +763,13 @@
765763
ntraj = 50,
766764
e_ops = e_ops,
767765
progress_bar = Val(false),
768-
rng = rng,
766+
rng = sol_sme1.rng,
769767
)
770768

771-
rng = MersenneTwister(1234)
772-
sol_mc3 = mcsolve(H, ψ0, tlist, c_ops, ntraj = 510, e_ops = e_ops, progress_bar = Val(false), rng = rng)
773-
rng = MersenneTwister(1234)
774-
sol_sse3 = ssesolve(H, ψ0, tlist, c_ops, ntraj = 60, e_ops = e_ops, progress_bar = Val(false), rng = rng)
775-
rng = MersenneTwister(1234)
769+
sol_mc3 =
770+
mcsolve(H, ψ0, tlist, c_ops, ntraj = 510, e_ops = e_ops, progress_bar = Val(false), rng = sol_mc1.rng)
771+
sol_sse3 =
772+
ssesolve(H, ψ0, tlist, c_ops, ntraj = 60, e_ops = e_ops, progress_bar = Val(false), rng = sol_sse1.rng)
776773
sol_sme3 = smesolve(
777774
H,
778775
ψ0,
@@ -782,7 +779,7 @@
782779
ntraj = 60,
783780
e_ops = e_ops,
784781
progress_bar = Val(false),
785-
rng = rng,
782+
rng = sol_sme1.rng,
786783
)
787784

788785
@test sol_mc1.expect sol_mc2.expect atol = 1e-10

0 commit comments

Comments
 (0)