diff --git a/src/time_evolution/mcsolve.jl b/src/time_evolution/mcsolve.jl index ddd4379cb..396fecafd 100644 --- a/src/time_evolution/mcsolve.jl +++ b/src/time_evolution/mcsolve.jl @@ -414,6 +414,8 @@ function mcsolve( col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol)) col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol)) + kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility + return TimeEvolutionMCSol( ntraj, ens_prob_mc.times, @@ -424,7 +426,7 @@ function mcsolve( col_which, sol.converged, _sol_1.alg, - _sol_1.prob.kwargs[:abstol], - _sol_1.prob.kwargs[:reltol], + kwargs.abstol, + kwargs.reltol, ) end diff --git a/src/time_evolution/mesolve.jl b/src/time_evolution/mesolve.jl index 1ab033bd3..8e790eeea 100644 --- a/src/time_evolution/mesolve.jl +++ b/src/time_evolution/mesolve.jl @@ -204,6 +204,8 @@ function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator(), dims = prob.dimensions), sol.u) end + kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility + return TimeEvolutionSol( prob.times, sol.t, @@ -211,7 +213,7 @@ function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit _get_expvals(sol, SaveFuncMESolve), sol.retcode, sol.alg, - sol.prob.kwargs[:abstol], - sol.prob.kwargs[:reltol], + kwargs.abstol, + kwargs.reltol, ) end diff --git a/src/time_evolution/sesolve.jl b/src/time_evolution/sesolve.jl index b447307a3..c30302e7d 100644 --- a/src/time_evolution/sesolve.jl +++ b/src/time_evolution/sesolve.jl @@ -154,6 +154,8 @@ function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit ψt = map(ϕ -> QuantumObject(ϕ, type = Ket(), dims = prob.dimensions), sol.u) + kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility + return TimeEvolutionSol( prob.times, sol.t, @@ -161,7 +163,7 @@ function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit _get_expvals(sol, SaveFuncSESolve), sol.retcode, sol.alg, - sol.prob.kwargs[:abstol], - sol.prob.kwargs[:reltol], + kwargs.abstol, + kwargs.reltol, ) end diff --git a/src/time_evolution/smesolve.jl b/src/time_evolution/smesolve.jl index 1f1aa3363..a09ba9437 100644 --- a/src/time_evolution/smesolve.jl +++ b/src/time_evolution/smesolve.jl @@ -426,6 +426,8 @@ function smesolve( _m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSMESolve), eachindex(sol)) m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2) # Stack on dimension 2 to align with QuTiP + kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility + return TimeEvolutionStochasticSol( ntraj, ens_prob.times, @@ -435,7 +437,7 @@ function smesolve( m_expvals, # Measurement expectation values sol.converged, _sol_1.alg, - _sol_1.prob.kwargs[:abstol], - _sol_1.prob.kwargs[:reltol], + kwargs.abstol, + kwargs.reltol, ) end diff --git a/src/time_evolution/ssesolve.jl b/src/time_evolution/ssesolve.jl index 962a831ad..28625e6c4 100644 --- a/src/time_evolution/ssesolve.jl +++ b/src/time_evolution/ssesolve.jl @@ -418,6 +418,8 @@ function ssesolve( _m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol)) m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2) + kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility + return TimeEvolutionStochasticSol( ntraj, ens_prob.times, @@ -427,7 +429,7 @@ function ssesolve( m_expvals, # Measurement expectation values sol.converged, _sol_1.alg, - _sol_1.prob.kwargs[:abstol], - _sol_1.prob.kwargs[:reltol], + kwargs.abstol, + kwargs.reltol, ) end