-
Notifications
You must be signed in to change notification settings - Fork 31
Fix Zygote autodiff errors #530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -204,14 +204,16 @@ function mesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit | |||||||||||||||||
| ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator(), dims = prob.dimensions), sol.u) | ||||||||||||||||||
| end | ||||||||||||||||||
|
|
||||||||||||||||||
| kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | ||||||||||||||||||
|
||||||||||||||||||
| kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| if sol.prob.kwargs isa NamedTuple | |
| kwargs = sol.prob.kwargs | |
| elseif all(k -> k isa Symbol, keys(sol.prob.kwargs)) | |
| kwargs = NamedTuple(sol.prob.kwargs) | |
| else | |
| error("sol.prob.kwargs contains non-symbol keys and cannot be safely converted to a NamedTuple.") | |
| end |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -154,14 +154,16 @@ function sesolve(prob::TimeEvolutionProblem, alg::OrdinaryDiffEqAlgorithm = Tsit | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ψt = map(ϕ -> QuantumObject(ϕ, type = Ket(), dims = prob.dimensions), sol.u) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
| kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| # Safely convert to NamedTuple for Zygote.jl compatibility | |
| if isa(sol.prob.kwargs, NamedTuple) | |
| kwargs = sol.prob.kwargs | |
| elseif isa(sol.prob.kwargs, Dict) | |
| # Only convert if all keys are symbols | |
| if all(isa(k, Symbol) for k in keys(sol.prob.kwargs)) | |
| kwargs = NamedTuple(sol.prob.kwargs) | |
| else | |
| error("sol.prob.kwargs contains non-symbol keys and cannot be converted to NamedTuple") | |
| end | |
| else | |
| # Fallback: try conversion, or use empty NamedTuple if not possible | |
| try | |
| kwargs = NamedTuple(sol.prob.kwargs) | |
| catch | |
| kwargs = (;) | |
| end | |
| end |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -426,6 +426,8 @@ function smesolve( | |||||||||||||||||
| _m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSMESolve), eachindex(sol)) | ||||||||||||||||||
| m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2) # Stack on dimension 2 to align with QuTiP | ||||||||||||||||||
|
|
||||||||||||||||||
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | ||||||||||||||||||
|
||||||||||||||||||
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| if _sol_1.prob.kwargs isa NamedTuple | |
| kwargs = _sol_1.prob.kwargs | |
| elseif all(k -> k isa Symbol, keys(_sol_1.prob.kwargs)) | |
| kwargs = NamedTuple(_sol_1.prob.kwargs) | |
| else | |
| kwargs = (;) | |
| end |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -418,6 +418,8 @@ function ssesolve( | |||||||||||
| _m_expvals_sol_1 isa Nothing ? nothing : map(i -> _get_m_expvals(sol[:, i], SaveFuncSSESolve), eachindex(sol)) | ||||||||||||
| m_expvals = _m_expvals isa Nothing ? nothing : stack(_m_expvals, dims = 2) | ||||||||||||
|
|
||||||||||||
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | ||||||||||||
|
||||||||||||
| kwargs = NamedTuple(_sol_1.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility | |
| kwargs = isa(_sol_1.prob.kwargs, NamedTuple) ? _sol_1.prob.kwargs : | |
| NamedTuple{Tuple(k for k in keys(_sol_1.prob.kwargs) if isa(k, Symbol))}( | |
| Dict(k => v for (k, v) in _sol_1.prob.kwargs if isa(k, Symbol)) | |
| ) # Safely convert to NamedTuple for Zygote.jl compatibility |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The NamedTuple conversion could fail if
_sol_1.prob.kwargsis already a NamedTuple or contains non-symbol keys. Consider using a safer conversion or checking the type first.