From c6f91387aea9ad17baf6a38a938c3b38bbbf702d Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sat, 12 Oct 2024 11:55:13 +0200 Subject: [PATCH] Improve mcsolve performance --- src/time_evolution/mcsolve.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index e828118a9..fb08a28d1 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -23,6 +23,7 @@ end function LindbladJumpAffect!(integrator) internal_params = integrator.p c_ops = internal_params.c_ops + c_ops_herm = internal_params.c_ops_herm cache_mc = internal_params.cache_mc weights_mc = internal_params.weights_mc cumsum_weights_mc = internal_params.cumsum_weights_mc @@ -33,11 +34,11 @@ function LindbladJumpAffect!(integrator) ψ = integrator.u @inbounds for i in eachindex(weights_mc) - mul!(cache_mc, c_ops[i], ψ) - weights_mc[i] = real(dot(cache_mc, cache_mc)) + weights_mc[i] = real(dot(ψ, c_ops_herm[i], ψ)) end cumsum!(cumsum_weights_mc, weights_mc) - collaps_idx = getindex(1:length(weights_mc), findfirst(>(rand(traj_rng) * sum(weights_mc)), cumsum_weights_mc)) + r = rand(traj_rng) * sum(weights_mc) + collaps_idx = getindex(1:length(weights_mc), findfirst(>(r), cumsum_weights_mc)) mul!(cache_mc, c_ops[collaps_idx], ψ) normalize!(cache_mc) copyto!(integrator.u, cache_mc) @@ -237,13 +238,17 @@ function mcsolveProblem( jump_times = Vector{Float64}(undef, jump_times_which_init_size) jump_which = Vector{Int16}(undef, jump_times_which_init_size) + c_ops_data = get_data.(c_ops) + c_ops_herm_data = map(op -> op' * op, c_ops_data) + params2 = ( expvals = expvals, e_ops_mc = e_ops2, is_empty_e_ops_mc = is_empty_e_ops_mc, progr_mc = ProgressBar(length(t_l), enable = false), traj_rng = rng, - c_ops = get_data.(c_ops), + c_ops = c_ops_data, + c_ops_herm = c_ops_herm_data, cache_mc = cache_mc, weights_mc = weights_mc, cumsum_weights_mc = cumsum_weights_mc,