Skip to content

Commit d44d251

Browse files
authored
Introduce new methods of sesolve_map and mesolve_map for advanced usage (#565)
1 parent 5cd2c31 commit d44d251

File tree

3 files changed

+77
-44
lines changed

3 files changed

+77
-44
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased](https://github.com/qutip/QuantumToolbox.jl/tree/main)
99

10+
- Introduce new methods of `sesolve_map` and `mesolve_map` for advanced usage. Users can now customize their own `iter`ator structure, `prob_func` and `output_func`. ([#565])
11+
1012
## [v0.37.0]
1113
Release date: 2025-10-12
1214

@@ -336,3 +338,4 @@ Release date: 2024-11-13
336338
[#554]: https://github.com/qutip/QuantumToolbox.jl/issues/554
337339
[#555]: https://github.com/qutip/QuantumToolbox.jl/issues/555
338340
[#557]: https://github.com/qutip/QuantumToolbox.jl/issues/557
341+
[#565]: https://github.com/qutip/QuantumToolbox.jl/issues/565

src/time_evolution/mesolve.jl

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,11 @@ function mesolve_map(
319319
to_dense(T, mat2vec(ket2dm(state).data))
320320
end
321321
end
322-
iter =
323-
params isa NullParameters ? collect(Iterators.product(ψ0_iter, [params])) :
324-
collect(Iterators.product(ψ0_iter, params...))
325-
ntraj = length(iter)
322+
if params isa NullParameters
323+
iter = collect(Iterators.product(ψ0_iter, [params])) |> vec # convert nx1 Matrix into Vector
324+
else
325+
iter = collect(Iterators.product(ψ0_iter, params...))
326+
end
326327

327328
# we disable the progress bar of the mesolveProblem because we use a global progress bar for all the trajectories
328329
prob = mesolveProblem(
@@ -336,35 +337,49 @@ function mesolve_map(
336337
kwargs...,
337338
)
338339

339-
# generate and solve ensemble problem
340-
_output_func = _ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) # setup global progress bar
340+
return mesolve_map(prob, iter, alg, ensemblealg; progress_bar = progress_bar)
341+
end
342+
mesolve_map(
343+
H::Union{AbstractQuantumObject{HOpType},Tuple},
344+
ψ0::QuantumObject{StateOpType},
345+
tlist::AbstractVector,
346+
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
347+
kwargs...,
348+
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} =
349+
mesolve_map(H, [ψ0], tlist, c_ops; kwargs...)
350+
351+
# this method is for advanced usage
352+
# User can define their own iterator structure, prob_func and output_func
353+
# - `prob_func`: Function to use for generating the ODEProblem.
354+
# - `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
355+
#
356+
# Return: An array of TimeEvolutionSol objects with the size same as the given iter.
357+
function mesolve_map(
358+
prob::TimeEvolutionProblem{<:ODEProblem},
359+
iter::AbstractArray,
360+
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
361+
ensemblealg::EnsembleAlgorithm = EnsembleThreads();
362+
prob_func::Union{Function,Nothing} = nothing,
363+
output_func::Union{Tuple,Nothing} = nothing,
364+
progress_bar::Union{Val,Bool} = Val(true),
365+
)
366+
# generate ensemble problem
367+
ntraj = length(iter)
368+
_prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func
369+
_output_func =
370+
isnothing(output_func) ?
371+
_ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) : output_func
341372
ens_prob = TimeEvolutionProblem(
342-
EnsembleProblem(
343-
prob.prob,
344-
prob_func = (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter),
345-
output_func = _output_func[1],
346-
safetycopy = false,
347-
),
373+
EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
348374
prob.times,
349375
prob.dimensions,
350376
(progr = _output_func[2], channel = _output_func[3], isoperket = prob.kwargs.isoperket),
351377
)
378+
352379
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
353380

354381
# handle solution and make it become an Array of TimeEvolutionSol
355382
sol_vec =
356383
[_gen_mesolve_solution(sol[:, i], prob.times, prob.dimensions, prob.kwargs.isoperket) for i in eachindex(sol)] # map is type unstable
357-
if params isa NullParameters # if no parameters specified, just return a Vector
358-
return sol_vec
359-
else
360-
return reshape(sol_vec, size(iter))
361-
end
384+
return reshape(sol_vec, size(iter))
362385
end
363-
mesolve_map(
364-
H::Union{AbstractQuantumObject{HOpType},Tuple},
365-
ψ0::QuantumObject{StateOpType},
366-
tlist::AbstractVector,
367-
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
368-
kwargs...,
369-
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} =
370-
mesolve_map(H, [ψ0], tlist, c_ops; kwargs...)

src/time_evolution/sesolve.jl

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,11 @@ function sesolve_map(
235235
)
236236
# mapping initial states and parameters
237237
ψ0_iter = map(get_data, ψ0)
238-
iter =
239-
params isa NullParameters ? collect(Iterators.product(ψ0_iter, [params])) :
240-
collect(Iterators.product(ψ0_iter, params...))
241-
ntraj = length(iter)
238+
if params isa NullParameters
239+
iter = collect(Iterators.product(ψ0_iter, [params])) |> vec # convert nx1 Matrix into Vector
240+
else
241+
iter = collect(Iterators.product(ψ0_iter, params...))
242+
end
242243

243244
# we disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
244245
prob = sesolveProblem(
@@ -251,28 +252,42 @@ function sesolve_map(
251252
kwargs...,
252253
)
253254

254-
# generate and solve ensemble problem
255-
_output_func = _ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) # setup global progress bar
255+
return sesolve_map(prob, iter, alg, ensemblealg; progress_bar = progress_bar)
256+
end
257+
sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{Ket}, tlist::AbstractVector; kwargs...) =
258+
sesolve_map(H, [ψ0], tlist; kwargs...)
259+
260+
# this method is for advanced usage
261+
# User can define their own iterator structure, prob_func and output_func
262+
# - `prob_func`: Function to use for generating the ODEProblem.
263+
# - `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
264+
#
265+
# Return: An array of TimeEvolutionSol objects with the size same as the given iter.
266+
function sesolve_map(
267+
prob::TimeEvolutionProblem{<:ODEProblem},
268+
iter::AbstractArray,
269+
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
270+
ensemblealg::EnsembleAlgorithm = EnsembleThreads();
271+
prob_func::Union{Function,Nothing} = nothing,
272+
output_func::Union{Tuple,Nothing} = nothing,
273+
progress_bar::Union{Val,Bool} = Val(true),
274+
)
275+
# generate ensemble problem
276+
ntraj = length(iter)
277+
_prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func
278+
_output_func =
279+
isnothing(output_func) ?
280+
_ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) : output_func
256281
ens_prob = TimeEvolutionProblem(
257-
EnsembleProblem(
258-
prob.prob,
259-
prob_func = (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter),
260-
output_func = _output_func[1],
261-
safetycopy = false,
262-
),
282+
EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
263283
prob.times,
264284
prob.dimensions,
265285
(progr = _output_func[2], channel = _output_func[3]),
266286
)
287+
267288
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
268289

269290
# handle solution and make it become an Array of TimeEvolutionSol
270291
sol_vec = [_gen_sesolve_solution(sol[:, i], prob.times, prob.dimensions) for i in eachindex(sol)] # map is type unstable
271-
if params isa NullParameters # if no parameters specified, just return a Vector
272-
return sol_vec
273-
else
274-
return reshape(sol_vec, size(iter))
275-
end
292+
return reshape(sol_vec, size(iter))
276293
end
277-
sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{Ket}, tlist::AbstractVector; kwargs...) =
278-
sesolve_map(H, [ψ0], tlist; kwargs...)

0 commit comments

Comments
 (0)