Skip to content

Commit 2463de8

Browse files
add SciMLStructures and relax params type
1 parent 4aff2a5 commit 2463de8

File tree

6 files changed

+75
-56
lines changed

6 files changed

+75
-56
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2121
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2222
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2323
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
24+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
2425
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2526
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2627
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
@@ -53,6 +54,7 @@ Random = "1"
5354
Reexport = "1"
5455
SciMLBase = "2"
5556
SciMLOperators = "0.3"
57+
SciMLStructures = "1.5.0"
5658
SparseArrays = "1"
5759
SpecialFunctions = "2"
5860
StaticArraysCore = "1"

src/QuantumToolbox.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ import SciMLOperators:
4545
IdentityOperator,
4646
update_coefficients!,
4747
concretize
48+
import SciMLStructures: isscimlstructure, ismutablescimlstructure, hasportion, canonicalize, replace, replace!, Tunable
4849
import LinearSolve: LinearProblem, SciMLLinearSolveAlgorithm, KrylovJL_MINRES, KrylovJL_GMRES
4950
import DiffEqBase: get_tstops
5051
import DiffEqCallbacks: PeriodicCallback, PresetTimeCallback, TerminateSteadyState
@@ -87,6 +88,7 @@ include("qobj/superoperators.jl")
8788
include("qobj/synonyms.jl")
8889

8990
# time evolution
91+
include("time_evolution/time_evo_parameters.jl")
9092
include("time_evolution/time_evolution.jl")
9193
include("time_evolution/mesolve.jl")
9294
include("time_evolution/lr_mesolve.jl")

src/time_evolution/sesolve.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ _sesolve_make_U_QobjEvo(H) = QobjEvo(H, -1im)
6262
ψ0::QuantumObject{DT2,KetQuantumObject},
6363
tlist::AbstractVector;
6464
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
65-
params::NamedTuple = NamedTuple(),
65+
params::Union{NamedTuple, AbstractVector} = NamedTuple(),
6666
progress_bar::Union{Val,Bool} = Val(true),
6767
inplace::Union{Val,Bool} = Val(true),
6868
kwargs...,
@@ -80,7 +80,7 @@ Generate the ODEProblem for the Schrödinger time evolution of a quantum system:
8080
- `ψ0`: Initial state of the system ``|\psi(0)\rangle``.
8181
- `tlist`: List of times at which to save either the state or the expectation values of the system.
8282
- `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`.
83-
- `params`: `NamedTuple` of parameters to pass to the solver.
83+
- `params`: `NamedTuple` or `AbstractVector` of parameters to pass to the solver.
8484
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
8585
- `inplace`: Whether to use the inplace version of the ODEProblem. The default is `Val(true)`.
8686
- `kwargs`: The keyword arguments for the ODEProblem.
@@ -101,7 +101,7 @@ function sesolveProblem(
101101
ψ0::QuantumObject{DT2,KetQuantumObject},
102102
tlist::AbstractVector;
103103
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
104-
params::NamedTuple = NamedTuple(),
104+
params::Union{NamedTuple,AbstractVector} = NamedTuple(),
105105
progress_bar::Union{Val,Bool} = Val(true),
106106
inplace::Union{Val,Bool} = Val(true),
107107
kwargs...,
@@ -148,7 +148,7 @@ end
148148
tlist::AbstractVector;
149149
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
150150
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
151-
params::NamedTuple = NamedTuple(),
151+
params::Union{NamedTuple, AbstractVector} = NamedTuple(),
152152
progress_bar::Union{Val,Bool} = Val(true),
153153
inplace::Union{Val,Bool} = Val(true),
154154
kwargs...,
@@ -167,7 +167,7 @@ Time evolution of a closed quantum system using the Schrödinger equation:
167167
- `tlist`: List of times at which to save either the state or the expectation values of the system.
168168
- `alg`: The algorithm for the ODE solver. The default is `Tsit5()`.
169169
- `e_ops`: List of operators for which to calculate expectation values. It can be either a `Vector` or a `Tuple`.
170-
- `params`: `NamedTuple` of parameters to pass to the solver.
170+
- `params`: `NamedTuple` or `AbstractVector` of parameters to pass to the solver.
171171
- `progress_bar`: Whether to show the progress bar. Using non-`Val` types might lead to type instabilities.
172172
- `inplace`: Whether to use the inplace version of the ODEProblem. The default is `Val(true)`.
173173
- `kwargs`: The keyword arguments for the ODEProblem.
@@ -190,12 +190,21 @@ function sesolve(
190190
tlist::AbstractVector;
191191
alg::OrdinaryDiffEqAlgorithm = Tsit5(),
192192
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
193-
params::NamedTuple = NamedTuple(),
193+
params::Union{NamedTuple,AbstractVector} = NamedTuple(),
194194
progress_bar::Union{Val,Bool} = Val(true),
195195
inplace::Union{Val,Bool} = Val(true),
196196
kwargs...,
197197
) where {DT1,DT2}
198-
prob = sesolveProblem(H, ψ0, tlist; e_ops = e_ops, params = params, progress_bar = progress_bar, inplace = inplace, kwargs...)
198+
prob = sesolveProblem(
199+
H,
200+
ψ0,
201+
tlist;
202+
e_ops = e_ops,
203+
params = params,
204+
progress_bar = progress_bar,
205+
inplace = inplace,
206+
kwargs...,
207+
)
199208

200209
return sesolve(prob, alg)
201210
end
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# This function should be implemented after Julia v1.12
2+
Base.@constprop :aggressive function _delete_field(a::NamedTuple{an}, field::Symbol) where {an}
3+
names = Base.diff_names(an, (field,))
4+
return NamedTuple{names}(a)
5+
end
6+
7+
struct QuantumTimeEvoParameters{TE<:AbstractMatrix,PT<:ProgressBar,ParT}
8+
expvals::TE
9+
progr::PT
10+
params::ParT
11+
12+
function QuantumTimeEvoParameters(expvals, progr, params)
13+
_expvals = expvals
14+
_progr = progr
15+
_params = params
16+
17+
# We replace the fields if they are aleady in the `params` struct
18+
# Then, we remove them from the `params` struct
19+
if :expvals fieldnames(typeof(_params))
20+
_expvals = _params.expvals
21+
_params = _delete_field(_params, :expvals)
22+
end
23+
if :progr fieldnames(typeof(_params))
24+
_progr = _params.progr
25+
_params = _delete_field(_params, :progr)
26+
end
27+
28+
return new{typeof(_expvals),typeof(_progr),typeof(_params)}(_expvals, _progr, _params)
29+
end
30+
end
31+
32+
#=
33+
By defining a custom `getproperty` method for the `QuantumTimeEvoParameters` struct, we can access the fields of `params` directly.
34+
=#
35+
function Base.getproperty(obj::QuantumTimeEvoParameters, field::Symbol)
36+
if field fieldnames(typeof(obj))
37+
getfield(obj, field)
38+
elseif field fieldnames(typeof(obj.params))
39+
getfield(obj.params, field)
40+
else
41+
throw(KeyError("Field $field not found in QuantumTimeEvoParameters or params."))
42+
end
43+
end
44+
45+
#=
46+
It also supports `params` as a `Vector`, so we implement the `getindex` method for the `QuantumTimeEvoParameters` struct.
47+
=#
48+
Base.getindex(obj::QuantumTimeEvoParameters, i::Int) = getindex(obj.params, i)
49+
50+
Base.length(obj::QuantumTimeEvoParameters) = length(obj.params)
51+
52+
function Base.merge(a::QuantumTimeEvoParameters, b::NamedTuple)
53+
return QuantumTimeEvoParameters(a.expvals, a.progr, merge(a.params, b))
54+
end

src/time_evolution/time_evolution.jl

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,54 +5,6 @@ export liouvillian_floquet, liouvillian_generalized
55
const DEFAULT_ODE_SOLVER_OPTIONS = (abstol = 1e-8, reltol = 1e-6, save_everystep = false, save_end = true)
66
const DEFAULT_SDE_SOLVER_OPTIONS = (abstol = 1e-2, reltol = 1e-2, save_everystep = false, save_end = true)
77

8-
# This function should be implemented after Julia v1.12
9-
Base.@constprop :aggressive function _delete_field(a::NamedTuple{an}, field::Symbol) where {an}
10-
names = Base.diff_names(an, (field,))
11-
return NamedTuple{names}(a)
12-
end
13-
14-
struct QuantumTimeEvoParameters{TE<:AbstractMatrix,PT<:ProgressBar,ParT}
15-
expvals::TE
16-
progr::PT
17-
params::ParT
18-
19-
function QuantumTimeEvoParameters(expvals, progr, params)
20-
_expvals = expvals
21-
_progr = progr
22-
_params = params
23-
24-
# We replace the fields if they are aleady in the `params` struct
25-
# Then, we remove them from the `params` struct
26-
if :expvals fieldnames(typeof(_params))
27-
_expvals = _params.expvals
28-
_params = _delete_field(_params, :expvals)
29-
end
30-
if :progr fieldnames(typeof(_params))
31-
_progr = _params.progr
32-
_params = _delete_field(_params, :progr)
33-
end
34-
35-
return new{typeof(_expvals),typeof(_progr),typeof(_params)}(_expvals, _progr, _params)
36-
end
37-
end
38-
39-
#=
40-
By defining a custom `getproperty` method for the `QuantumTimeEvoParameters` struct, we can access the fields of `params` directly.
41-
=#
42-
function Base.getproperty(obj::QuantumTimeEvoParameters, field::Symbol)
43-
if field fieldnames(typeof(obj))
44-
getfield(obj, field)
45-
elseif field fieldnames(typeof(obj.params))
46-
getfield(obj.params, field)
47-
else
48-
throw(KeyError("Field $field not found in QuantumTimeEvoParameters or params."))
49-
end
50-
end
51-
52-
function Base.merge(a::QuantumTimeEvoParameters, b::NamedTuple)
53-
return QuantumTimeEvoParameters(a.expvals, a.progr, merge(a.params, b))
54-
end
55-
568
struct QuantumTimeEvoProblem{PT,TT<:AbstractVector,DT<:AbstractVector}
579
prob::PT
5810
times::TT

test/core-test/time_evolution.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
sol2 = sesolve(H, psi0, t_l, progress_bar = Val(false))
1717
sol3 = sesolve(H, psi0, t_l, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
1818
sol_string = sprint((t, s) -> show(t, "text/plain", s), sol)
19-
@test prob.f.f isa MatrixOperator
19+
@test prob.prob.f.f isa MatrixOperator
2020
@test sum(abs.(sol.expect[1, :] .- sin.(η * t_l) .^ 2)) / length(t_l) < 0.1
2121
@test length(sol.times) == length(t_l)
2222
@test length(sol.states) == 1

0 commit comments

Comments
 (0)