-
Notifications
You must be signed in to change notification settings - Fork 31
Make time evolution solvers compatible with automatic differentiation #311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
I think for some of the parameters which must be used in all solvers (e.g., In this case, we don't need to define BTW, maybe call it |
|
Actually I'm trying to simplify it as much as possible. The gradient calculation of |
bf3450d to
5579426
Compare
|
With the last commit, I finally succeeded to support automatic differentiation on using QuantumToolbox
using OrdinaryDiffEq
using SciMLSensitivity
using Zygote
##
const N = 20
const F = 1
const γ = 1
const a = destroy(N)
const ψ0 = fock(N, 0)
# coef1(p, t) = p.Δ
coef1(p, t) = p[1]
QobjEvo(a' * a, coef1) + F * (a' + a)
function ss_population(Δ)
# H = Δ * a' * a + F * (a' + a)
H = QobjEvo(a' * a, coef1) + F * (a' + a)
c_ops = [sqrt(γ) * a]
tlist = range(0, 1.5, 100)
ρ_ss = sesolve(H, ψ0, tlist, progress_bar=Val(false), inplace=Val(false), params = [Δ], saveat = [tlist[end]]).states[end].data
return real(sum(ρ_ss))
end
Δ = 1.0
ss_population(Δ)
##
Zygote.gradient(ss_population, Δ) |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #311 +/- ##
==========================================
- Coverage 93.69% 93.21% -0.48%
==========================================
Files 32 36 +4
Lines 2490 2581 +91
==========================================
+ Hits 2333 2406 +73
- Misses 157 175 +18 ☔ View full report in Codecov by Sentry. |
Checklist
Thank you for contributing to
QuantumToolbox.jl! Please make sure you have finished the following tasks before opening the PR.make test.juliaformatted by running:make format.docs/folder) related to code changes were updated and able to build locally by running:make docs.CHANGELOG.mdshould be updated (regarding to the code changes) and built by running:make changelog.Request for a review after you have completed all the tasks. If you have not finished them all, you can also open a Draft Pull Request to let the others know this on-going work.
Description
With this PR I change the structure of the time evolution solver in order to support automatic differentiation.
Thanks to the SciMLSensitivity.jl package, it is possible to compute the gradient of a differential equation. It is almost straightforward to do with ODE parameters as a
Vectortype, but it is not easy to implement when we have a complicated structure of the parameters as in the current case of the package, where we have many variables, progress bar, etc inside the params.The main change here is to introduce a new struct for the parameters, instead of using the current
NamedTuple. In this way, thanks to SciMLStructures.jl, we can say which part of the structure is differentiable and which not.As a first step, I'm trying to simplify the structure of the
paramsstruct. This involves the creation of a custom struct to handle theODEProblemgenerated by functions likesesolveProblem. In this way, many variables can be removed fromparams.Currently, there are some limitations on the type of the differentiable part of the
paramsstruct, and the only supported type is theVectorone. For example, theparamskwarg in themesolvehas to be aVectorand not aNamedTuple. See this issue for more information. Nonetheless, theNamedTupletype is still supported in standard simulations, where the gradient is not needed.EDIT:
The custom struct
TimeEvolutionParametersis no longer needed. We can pass all the cache and temporary variables to their respective callbacks. In this way, theparamsvariable is only composed by the true parameters (usually aVector{Number}.To Do:
ODEProblems generated by the solversSciMLStructures.jlrules for the customparamsstructsesolvedifferentiablemesolvedifferentiablemcsolvedifferentiable (maybe in another PR?)