Skip to content

Commit a584b30

Browse files
committed
improve code coverage
1 parent d573435 commit a584b30

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/time_evolution/time_evolution.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,20 @@ _store_multitraj_expect(expvals::Nothing, keep_runs_results) = nothing
287287
Return the trajectory-wise standard deviation of the expectation values at each time point.
288288
"""
289289
function std_expect(sol::TimeEvolutionMultiTrajSol{TS,Array{T,3}}) where {TS,T<:Number}
290+
# the following standard deviation (std) is defined as the square-root of variance instead of pseudo-variance
291+
# i.e., it is equivalent to (even for complex expectation values):
292+
# dropdims(
293+
# sqrt.(mean(abs2.(sol.expect), dims = 2) .- abs2.(mean(sol.expect, dims = 2))),
294+
# dims = 2
295+
# )
296+
# [this should be included in the runtest]
290297
return dropdims(std(sol.expect, corrected = false, dims = 2), dims = 2)
291298
end
299+
std_expect(::TimeEvolutionMultiTrajSol{TS,Matrix{T}}) where {TS,T<:Number} = throw(
300+
ArgumentError(
301+
"Can not compute the standard deviation without the expectation values of each trajectory. Try to specify keyword argument `keep_runs_results=Val(true)` to the solver.",
302+
),
303+
)
292304
std_expect(::TimeEvolutionMultiTrajSol{TS,Nothing}) where {TS} = nothing
293305

294306
#######################################

test/core-test/time_evolution.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ end
248248

249249
@testitem "mcsolve" setup=[TESetup] begin
250250
using SciMLOperators
251+
using Statistics
251252

252253
# Get parameters from TESetup to simplify the code
253254
H = TESetup.H
@@ -270,6 +271,7 @@ end
270271
progress_bar = Val(false),
271272
jump_callback = DiscreteLindbladJumpCallback(),
272273
)
274+
sol_mc3 = mcsolve(H, ψ0, tlist, c_ops, e_ops = e_ops, progress_bar = Val(false), keep_runs_results = Val(true))
273275
sol_mc_states =
274276
mcsolve(H, ψ0, tlist, c_ops, saveat = saveat, progress_bar = Val(false), keep_runs_results = Val(true))
275277
sol_mc_states2 = mcsolve(
@@ -291,12 +293,17 @@ end
291293
@test prob_mc.prob.f.f isa MatrixOperator
292294
@test sum(abs, sol_mc.expect .- sol_me.expect) / length(tlist) < 0.1
293295
@test sum(abs, sol_mc2.expect .- sol_me.expect) / length(tlist) < 0.1
296+
@test sum(abs, average_expect(sol_mc3) .- sol_me.expect) / length(tlist) < 0.1
294297
@test sum(abs, expect_mc_states_mean .- vec(sol_me.expect[1, saveat_idxs])) / length(tlist) < 0.1
295298
@test sum(abs, expect_mc_states_mean2 .- vec(sol_me.expect[1, saveat_idxs])) / length(tlist) < 0.1
296299
@test length(sol_mc.times) == length(tlist)
297300
@test length(sol_mc.times_states) == 1
298301
@test size(sol_mc.expect) == (length(e_ops), length(tlist))
299302
@test size(sol_mc.states) == (1,)
303+
@test length(sol_mc3.times) == length(tlist)
304+
@test length(sol_mc3.times_states) == 1
305+
@test size(sol_mc3.expect) == (length(e_ops), 500, length(tlist)) # ntraj = 500
306+
@test size(sol_mc3.states) == (500, 1) # ntraj = 500
300307
@test length(sol_mc_states.times) == length(tlist)
301308
@test length(sol_mc_states.times_states) == length(saveat)
302309
@test size(sol_mc_states.states) == (500, length(saveat)) # ntraj = 500
@@ -331,6 +338,23 @@ end
331338
@test_throws ArgumentError mcsolve(H, ψ0, tlist, c_ops, save_idxs = [1, 2], progress_bar = Val(false))
332339
@test_throws DimensionMismatch mcsolve(H, TESetup.ψ_wrong, tlist, c_ops, progress_bar = Val(false))
333340

341+
# test average_states, average_expect, and std_expect
342+
expvals_all = sol_mc3.expect[:, :, 2:end] # ignore testing initial time point since its standard deviation is a very small value (basically zero)
343+
stdvals = std_expect(sol_mc3)
344+
@test average_states(sol_mc) == sol_mc.states
345+
@test average_expect(sol_mc) == sol_mc.expect
346+
@test size(stdvals) == (length(e_ops), length(tlist))
347+
@test all(
348+
isapprox.(
349+
stdvals[:, 2:end], # ignore testing initial time point since its standard deviation is a very small value (basically zero)
350+
dropdims(sqrt.(mean(abs2.(expvals_all), dims = 2) .- abs2.(mean(expvals_all, dims = 2))), dims = 2);
351+
atol = 1e-6,
352+
),
353+
)
354+
@test average_expect(sol_mc_states) === nothing
355+
@test std_expect(sol_mc_states) === nothing
356+
@test_throws ArgumentError std_expect(sol_mc)
357+
334358
@testset "Memory Allocations (mcsolve)" begin
335359
ntraj = 100
336360
for keep_runs_results in (Val(false), Val(true))

0 commit comments

Comments
 (0)