Skip to content

Commit dc39462

Browse files
Change order of expect_all
1 parent e3ea221 commit dc39462

File tree

4 files changed

+9
-9
lines changed

4 files changed

+9
-9
lines changed

src/time_evolution/mcsolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,12 +401,12 @@ function mcsolve(
401401
_expvals_sol_1 = _mcsolve_get_expvals(_sol_1)
402402

403403
_expvals_all = _expvals_sol_1 isa Nothing ? nothing : map(i -> _mcsolve_get_expvals(sol[:, i]), eachindex(sol))
404-
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all)
404+
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
405405
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
406406
col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
407407
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))
408408

409-
expvals = _expvals_sol_1 isa Nothing ? nothing : dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
409+
expvals = _expvals_sol_1 isa Nothing ? nothing : dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
410410

411411
return TimeEvolutionMCSol(
412412
ntraj,

src/time_evolution/smesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,12 +354,12 @@ function smesolve(
354354
normalize_states = Val(false)
355355
dims = ens_prob.dimensions
356356
_expvals_all = _expvals_sol_1 isa Nothing ? nothing : map(i -> _se_me_sse_get_expvals(sol[:, i]), eachindex(sol))
357-
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all)
357+
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
358358
states = map(i -> _smesolve_generate_state.(sol[:, i].u, Ref(dims)), eachindex(sol))
359359

360360
expvals =
361361
_se_me_sse_get_expvals(_sol_1) isa Nothing ? nothing :
362-
dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
362+
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
363363

364364
return TimeEvolutionStochasticSol(
365365
ntraj,

src/time_evolution/ssesolve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,12 +444,12 @@ function ssesolve(
444444
dims = _sol_1.prob.p.Hdims
445445
_expvals_all =
446446
_expvals_sol_1 isa Nothing ? nothing : map(i -> _se_me_sse_get_expvals(sol[:, i]), eachindex(sol))
447-
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all)
447+
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP
448448
states = map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol))
449449

450450
expvals =
451451
_se_me_sse_get_expvals(_sol_1) isa Nothing ? nothing :
452-
dropdims(sum(expvals_all, dims = 3), dims = 3) ./ length(sol)
452+
dropdims(sum(expvals_all, dims = 2), dims = 2) ./ length(sol)
453453

454454
return TimeEvolutionStochasticSol(
455455
ntraj,

test/core-test/time_evolution.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,17 +664,17 @@
664664
@test sol_mc1.col_times sol_mc2.col_times atol = 1e-10
665665
@test sol_mc1.col_which sol_mc2.col_which atol = 1e-10
666666

667-
@test sol_mc1.expect_all sol_mc3.expect_all[:, :, 1:500] atol = 1e-10
667+
@test sol_mc1.expect_all sol_mc3.expect_all[:, 1:500, :] atol = 1e-10
668668

669669
@test sol_sse1.expect sol_sse2.expect atol = 1e-10
670670
@test sol_sse1.expect_all sol_sse2.expect_all atol = 1e-10
671671

672-
@test sol_sse1.expect_all sol_sse3.expect_all[:, :, 1:50] atol = 1e-10
672+
@test sol_sse1.expect_all sol_sse3.expect_all[:, 1:50, :] atol = 1e-10
673673

674674
@test sol_sme1.expect sol_sme2.expect atol = 1e-10
675675
@test sol_sme1.expect_all sol_sme2.expect_all atol = 1e-10
676676

677-
@test sol_sme1.expect_all sol_sme3.expect_all[:, :, 1:50] atol = 1e-10
677+
@test sol_sme1.expect_all sol_sme3.expect_all[:, 1:50, :] atol = 1e-10
678678
end
679679
end
680680

0 commit comments

Comments
 (0)