-
Notifications
You must be signed in to change notification settings - Fork 31
Fix Zygote autodiff errors #530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #530 +/- ##
==========================================
+ Coverage 93.09% 94.14% +1.05%
==========================================
Files 51 51
Lines 3549 3554 +5
==========================================
+ Hits 3304 3346 +42
+ Misses 245 208 -37 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes autodiff errors with Zygote.jl by modifying how kwargs are accessed in time evolution solvers. The changes address compatibility issues when using automatic differentiation with the quantum time evolution functions.
- Converts
sol.prob.kwargsfrom dictionary access to NamedTuple before accessingabstolandreltol - Applies the same fix consistently across all time evolution solver functions
- Adds explanatory comments indicating the change is for Zygote.jl compatibility
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| src/time_evolution/sesolve.jl | Converts kwargs to NamedTuple before accessing abstol/reltol |
| src/time_evolution/mesolve.jl | Converts kwargs to NamedTuple before accessing abstol/reltol |
| src/time_evolution/mcsolve.jl | Converts kwargs to NamedTuple before accessing abstol/reltol |
| src/time_evolution/ssesolve.jl | Converts kwargs to NamedTuple before accessing abstol/reltol |
| src/time_evolution/smesolve.jl | Converts kwargs to NamedTuple before accessing abstol/reltol |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
| _m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol)) | ||
| m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2) | ||
|
|
||
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility |
Copilot
AI
Aug 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NamedTuple conversion could fail if _sol_1.prob.kwargs is already a NamedTuple or contains non-symbol keys. Consider using a safer conversion or checking the type first.
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| kwargs = isa(_sol_1.prob.kwargs, NamedTuple) ? _sol_1.prob.kwargs : | |
| NamedTuple{Tuple(k for k in keys(_sol_1.prob.kwargs) if isa(k, Symbol))}( | |
| Dict(k => v for (k, v) in _sol_1.prob.kwargs if isa(k, Symbol)) | |
| ) # Safely convert to NamedTuple for Zygote.jl compatibility |
| _m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSMESolve), eachindex(sol)) | ||
| m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2) # Stack on dimension 2 to align with QuTiP | ||
|
|
||
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility |
Copilot
AI
Aug 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NamedTuple conversion could fail if _sol_1.prob.kwargs is already a NamedTuple or contains non-symbol keys. Consider using a safer conversion or checking the type first.
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| if _sol_1.prob.kwargs isa NamedTuple | |
| kwargs = _sol_1.prob.kwargs | |
| elseif all(k -> k isa Symbol, keys(_sol_1.prob.kwargs)) | |
| kwargs = NamedTuple(_sol_1.prob.kwargs) | |
| else | |
| kwargs = (;) | |
| end |
|
|
||
| ψt = map(ϕ -> QuantumObject(ϕ, type = Ket(), dims = prob.dimensions), sol.u) | ||
|
|
||
| kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility |
Copilot
AI
Aug 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NamedTuple conversion could fail if sol.prob.kwargs is already a NamedTuple or contains non-symbol keys. Consider using a safer conversion or checking the type first.
| kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| # Safely convert to NamedTuple for Zygote.jl compatibility | |
| if isa(sol.prob.kwargs, NamedTuple) | |
| kwargs = sol.prob.kwargs | |
| elseif isa(sol.prob.kwargs, Dict) | |
| # Only convert if all keys are symbols | |
| if all(isa(k, Symbol) for k in keys(sol.prob.kwargs)) | |
| kwargs = NamedTuple(sol.prob.kwargs) | |
| else | |
| error("sol.prob.kwargs contains non-symbol keys and cannot be converted to NamedTuple") | |
| end | |
| else | |
| # Fallback: try conversion, or use empty NamedTuple if not possible | |
| try | |
| kwargs = NamedTuple(sol.prob.kwargs) | |
| catch | |
| kwargs = (;) | |
| end | |
| end |
| ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator(), dims = prob.dimensions), sol.u) | ||
| end | ||
|
|
||
| kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility |
Copilot
AI
Aug 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NamedTuple conversion could fail if sol.prob.kwargs is already a NamedTuple or contains non-symbol keys. Consider using a safer conversion or checking the type first.
| kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| if sol.prob.kwargs isa NamedTuple | |
| kwargs = sol.prob.kwargs | |
| elseif all(k -> k isa Symbol, keys(sol.prob.kwargs)) | |
| kwargs = NamedTuple(sol.prob.kwargs) | |
| else | |
| error("sol.prob.kwargs contains non-symbol keys and cannot be safely converted to a NamedTuple.") | |
| end |
| col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol)) | ||
| col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol)) | ||
|
|
||
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility |
Copilot
AI
Aug 20, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NamedTuple conversion could fail if _sol_1.prob.kwargs is already a NamedTuple or contains non-symbol keys. Consider using a safer conversion or checking the type first.
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| if isa(_sol_1.prob.kwargs, NamedTuple) | |
| kwargs = _sol_1.prob.kwargs | |
| else | |
| # Convert to NamedTuple, ensuring keys are symbols | |
| kwargs = NamedTuple{Tuple(Symbol.(keys(_sol_1.prob.kwargs)))}(values(_sol_1.prob.kwargs)) | |
| end |
Checklist
Thank you for contributing to
QuantumToolbox.jl! Please make sure you have finished the following tasks before opening the PR.make test.juliaformatted by running:make format.docs/folder) related to code changes were updated and able to build locally by running:make docs.CHANGELOG.mdshould be updated (regarding to the code changes) and built by running:make changelog.Request for a review after you have completed all the tasks. If you have not finished them all, you can also open a Draft Pull Request to let the others know this on-going work.
Description
This PR fixes the recent errors on autodiff with Zygote.jl.