Skip to content

Commit 015e34e

Browse files
committed
fix mapped solvers with NullParameters
1 parent 9309fc9 commit 015e34e

File tree

4 files changed

+32
-12
lines changed

4 files changed

+32
-12
lines changed

src/time_evolution/mesolve.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ end
240240
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
241241
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
242242
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
243-
params::Tuple = (NullParameters(),),
243+
params::Union{NullParameters,Tuple} = NullParameters(),
244244
progress_bar::Union{Val,Bool} = Val(true),
245245
kwargs...,
246246
)
@@ -293,7 +293,7 @@ function mesolve_map(
293293
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
294294
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
295295
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
296-
params::Tuple = (NullParameters(),),
296+
params::Union{NullParameters,Tuple} = NullParameters(),
297297
progress_bar::Union{Val,Bool} = Val(true),
298298
kwargs...,
299299
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
@@ -319,7 +319,9 @@ function mesolve_map(
319319
to_dense(T, mat2vec(ket2dm(state).data))
320320
end
321321
end
322-
iter = collect(Iterators.product(ψ0_iter, params...))
322+
iter =
323+
params isa NullParameters ? collect(Iterators.product(ψ0_iter, [params])) :
324+
collect(Iterators.product(ψ0_iter, params...))
323325
ntraj = length(iter)
324326

325327
# we disable the progress bar of the mesolveProblem because we use a global progress bar for all the trajectories
@@ -352,7 +354,11 @@ function mesolve_map(
352354
# handle solution and make it become an Array of TimeEvolutionSol
353355
sol_vec =
354356
[_gen_mesolve_solution(sol[:, i], prob.times, prob.dimensions, prob.kwargs.isoperket) for i in eachindex(sol)] # map is type unstable
355-
return reshape(sol_vec, size(iter))
357+
if params isa NullParameters # if no parameters specified, just return a Vector
358+
return sol_vec
359+
else
360+
return reshape(sol_vec, size(iter))
361+
end
356362
end
357363
mesolve_map(
358364
H::Union{AbstractQuantumObject{HOpType},Tuple},

src/time_evolution/sesolve.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ end
185185
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
186186
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
187187
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
188-
params::Tuple = (NullParameters(),),
188+
params::Union{NullParameters,Tuple} = NullParameters(),
189189
progress_bar::Union{Val,Bool} = Val(true),
190190
kwargs...,
191191
)
@@ -229,13 +229,15 @@ function sesolve_map(
229229
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
230230
ensemblealg::EnsembleAlgorithm = EnsembleThreads(),
231231
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
232-
params::Tuple = (NullParameters(),),
232+
params::Union{NullParameters,Tuple} = NullParameters(),
233233
progress_bar::Union{Val,Bool} = Val(true),
234234
kwargs...,
235235
)
236236
# mapping initial states and parameters
237237
ψ0_iter = map(get_data, ψ0)
238-
iter = collect(Iterators.product(ψ0_iter, params...))
238+
iter =
239+
params isa NullParameters ? collect(Iterators.product(ψ0_iter, [params])) :
240+
collect(Iterators.product(ψ0_iter, params...))
239241
ntraj = length(iter)
240242

241243
# we disable the progress bar of the sesolveProblem because we use a global progress bar for all the trajectories
@@ -266,7 +268,11 @@ function sesolve_map(
266268

267269
# handle solution and make it become an Array of TimeEvolutionSol
268270
sol_vec = [_gen_sesolve_solution(sol[:, i], prob.times, prob.dimensions) for i in eachindex(sol)] # map is type unstable
269-
return reshape(sol_vec, size(iter))
271+
if params isa NullParameters # if no parameters specified, just return a Vector
272+
return sol_vec
273+
else
274+
return reshape(sol_vec, size(iter))
275+
end
270276
end
271277
sesolve_map(H::Union{AbstractQuantumObject{Operator},Tuple}, ψ0::QuantumObject{Ket}, tlist::AbstractVector; kwargs...) =
272278
sesolve_map(H, [ψ0], tlist; kwargs...)

test/core-test/time_evolution.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,11 @@ end
171171
ωq_fun(p, t) = p[2]
172172
H = QobjEvo(a' * a, ωc_fun) + QobjEvo(σz / 2, ωq_fun) + g * (a' * σm + a * σm')
173173

174-
sols1 = sesolve_map(H, ψ_0_e, tlist; e_ops = e_ops, params = (ωc_list, ωq_list))
174+
sols0 = sesolve_map(TESetup.H, ψ0_list, tlist; e_ops = e_ops) # no params, but test progress_bar
175+
sols1 = sesolve_map(H, ψ_0_e, tlist; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
175176
sols2 = sesolve_map(H, ψ0_list, tlist; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
176-
177+
@test size(sols0) == (2,)
178+
@test sols0 isa Vector{<:TimeEvolutionSol}
177179
@test size(sols1) == (1, 3, 4)
178180
@test sols1 isa Array{<:TimeEvolutionSol}
179181
@test size(sols2) == (2, 3, 4)
@@ -193,6 +195,7 @@ end
193195
end
194196

195197
@testset "Type Inference sesolve_map" begin
198+
@inferred sesolve_map(TESetup.H, ψ0_list, tlist; e_ops = e_ops, progress_bar = Val(false)) # no params
196199
@inferred sesolve_map(H, ψ0_list, tlist; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
197200
end
198201
end
@@ -324,14 +327,18 @@ end
324327
ωq_fun(p, t) = p[2]
325328
H = QobjEvo(a' * a, ωc_fun) + QobjEvo(σz / 2, ωq_fun) + g * (a' * σm + a * σm')
326329

330+
# Test with multiple initial states but no params (this also tests progress_bar)
331+
sols0 = mesolve_map(TESetup.H, ψ0_list, tlist, c_ops; e_ops = e_ops)
327332
# Test with single initial state
328-
sols1 = mesolve_map(H, ψ_0_e, tlist, c_ops; e_ops = e_ops, params = (ωc_list, ωq_list))
333+
sols1 = mesolve_map(H, ψ_0_e, tlist, c_ops; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
329334
# Test with multiple initial states
330335
sols2 = mesolve_map(H, ψ0_list, tlist, c_ops; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
331336

332337
# Test redirect to sesolve_map when c_ops is nothing
333338
sols3 = mesolve_map(H, ψ0_list, tlist; e_ops = e_ops, params = (ωc_list, ωq_list), progress_bar = Val(false))
334339

340+
@test size(sols0) == (2,)
341+
@test sols0 isa Vector{<:TimeEvolutionSol}
335342
@test size(sols1) == (1, 3, 4)
336343
@test sols1 isa Array{<:TimeEvolutionSol}
337344
@test size(sols2) == (2, 3, 4)
@@ -368,6 +375,7 @@ end
368375
@test sols5 isa Array{<:TimeEvolutionSol}
369376

370377
@testset "Type Inference mesolve_map" begin
378+
@inferred mesolve_map(TESetup.H, ψ0_list, tlist, c_ops; e_ops = e_ops, progress_bar = Val(false)) # no params
371379
@inferred mesolve_map(
372380
H,
373381
ψ0_list,

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Pkg
55
const GROUP_LIST = String["All", "Core", "Code-Quality", "AutoDiff_Ext", "Makie_Ext", "CUDA_Ext"]
66

77
const GROUP = get(ENV, "GROUP", "All")
8-
(GROUP in GROUP_LIST) || throw(ArgumentError("Unknown GROUP = $GROUP"))
8+
(GROUP in GROUP_LIST) || throw(ArgumentError("Unknown GROUP = $GROUP\nThe allowed groups are: $GROUP_LIST\n"))
99

1010
# Core tests
1111
if (GROUP == "All") || (GROUP == "Core")

0 commit comments

Comments
 (0)