Skip to content

Commit 5579426

Browse files
Add SciMLStructures.jl methods
1 parent 75acbd8 commit 5579426

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

src/time_evolution/time_evo_parameters.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,49 @@ Base.length(obj::TimeEvolutionParameters) = length(obj.params)
4242
# return TimeEvolutionParameters(merge(a.params, b), a.expvals, a.progr, a.mcsolve_params)
4343
# end
4444

45+
########## Mark the struct as a SciMLStructure ##########
46+
# The NamedTuple `params` case still doesn't work, and it should be put as a `Vector` instead
47+
48+
isscimlstructure(::TimeEvolutionParameters) = true
49+
# ismutablescimlstructure(::TimeEvolutionParameters{ParT}) where {ParT<:NamedTuple} = false
50+
ismutablescimlstructure(::TimeEvolutionParameters{ParT}) where {ParT<:AbstractVector} = true
51+
52+
hasportion(::Tunable, ::TimeEvolutionParameters) = true
53+
54+
function _vectorize_params(p::TimeEvolutionParameters{ParT}) where {ParT<:NamedTuple}
55+
buffer = isempty(p.params) ? eltype(p.expvals)[] : collect(values(p.params))
56+
return (buffer, false)
57+
end
58+
_vectorize_params(p::TimeEvolutionParameters{ParT}) where {ParT<:AbstractVector} = (p.params, true)
59+
60+
function canonicalize(::Tunable, p::TimeEvolutionParameters)
61+
buffer, aliases = _vectorize_params(p) # We are assuming that the values have the same type
62+
63+
# repack takes a new vector of the same length as `buffer`, and constructs
64+
# a new `TimeEvolutionParameters` object using the values from the new vector for tunables
65+
# and retaining old values for other parameters. This is exactly what replace does,
66+
# so we can use that instead.
67+
repack = let p = p
68+
repack(newbuffer) = replace(Tunable(), p, newbuffer)
69+
end
70+
# the canonicalized vector, the repack function, and a boolean indicating
71+
# whether the buffer aliases values in the parameter object
72+
return buffer, repack, aliases
73+
end
74+
75+
function replace(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:NamedTuple}
76+
@assert length(newbuffer) == length(p.params)
77+
new_params = NamedTuple{keys(p.params)}(Tuple(newbuffer))
78+
return TimeEvolutionParameters(new_params, p.expvals, p.progr, p.mcsolve_params)
79+
end
80+
81+
function replace(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:AbstractVector}
82+
@assert length(newbuffer) == length(p.params)
83+
return TimeEvolutionParameters(newbuffer, p.expvals, p.progr, p.mcsolve_params)
84+
end
85+
86+
function replace!(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:AbstractVector}
87+
@assert length(newbuffer) == length(p.params)
88+
copyto!(p.params, newbuffer)
89+
return p
90+
end

0 commit comments

Comments
 (0)