Skip to content

Commit e7c12f6

Browse files
committed
support autodiff for SteadyStateODESolver
1 parent 75f33c4 commit e7c12f6

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

src/steadystate.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,16 @@ end
214214
function _steadystate(L::AbstractQuantumObject{SuperOperator}, solver::SteadyStateODESolver; kwargs...)
215215
ψ0 = isnothing(solver.ψ0) ? rand_ket(L.dimensions) : solver.ψ0
216216
ftype = _float_type(ψ0)
217-
tlist = [ftype(0), ftype(solver.tmax)]
217+
tmax = ftype(solver.tmax)
218+
tlist = [ftype(0), tmax]
218219

219220
# overwrite some kwargs and throw warning message to tell the users that we are ignoring these settings
220221
haskey(kwargs, :progress_bar) && @warn "Ignore keyword argument 'progress_bar' for SteadyStateODESolver"
221222
haskey(kwargs, :save_everystep) && @warn "Ignore keyword argument 'save_everystep' for SteadyStateODESolver"
222223
haskey(kwargs, :saveat) && @warn "Ignore keyword argument 'saveat' for SteadyStateODESolver"
223224
kwargs2 = merge(
224225
NamedTuple(kwargs), # we convert to NamedTuple just in case if kwargs is empty
225-
(progress_bar = Val(false), save_everystep = false, saveat = ftype[]),
226+
(progress_bar = Val(false), save_everystep = false, saveat = ftype[tmax]),
226227
)
227228

228229
# add terminate condition (callback)

test/ext-test/cpu/autodiff/autodiff.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,17 @@ function my_f_mesolve(p)
6767
return real(expect(a' * a, sol.states[end]))
6868
end
6969

70+
function my_f_steadystate(p)
71+
ρss = steadystate(
72+
L,
73+
SteadyStateODESolver(ψ0 = ψ0_mesolve, tmax = tlist_mesolve[end]);
74+
params = p,
75+
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()),
76+
)
77+
78+
return real(expect(a' * a, ρss))
79+
end
80+
7081
# Analytical solution
7182
n_ss(Δ, F, γ) = abs2(F /+ 1im * γ / 2))
7283

@@ -113,8 +124,12 @@ n_ss(Δ, F, γ) = abs2(F / (Δ + 1im * γ / 2))
113124

114125
my_f_mesolve_direct(params)
115126
my_f_mesolve(params)
127+
my_f_steadystate(params)
116128

129+
# calculate exact solution and check if steadystate works
117130
grad_exact = Zygote.gradient((p) -> n_ss(p[1], p[2], p[3]), params)[1]
131+
grad_ss = Zygote.gradient(my_f_steadystate, params)[1]
132+
@test grad_ss grad_exact atol=1e-5
118133

119134
@testset "ForwardDiff.jl" begin
120135
grad_qt = ForwardDiff.gradient(my_f_mesolve_direct, params)

0 commit comments

Comments
 (0)