Skip to content

Commit 7cd0e0a

Browse files
Merge pull request #5 from una-auxme/dev
v0.2.0
2 parents 826321b + f94ca0f commit 7cd0e0a

File tree

4 files changed

+39
-18
lines changed

4 files changed

+39
-18
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphNetCore"
22
uuid = "7809f980-de1b-4f9a-8451-85f041491431"
33
authors = ["JT <julian.trommer@uni-a.de>"]
4-
version = "0.1.1"
4+
version = "0.2.0"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -29,11 +29,11 @@ KernelAbstractions = "0.9"
2929
Lux = "0.5"
3030
LuxCUDA = "0.3.1"
3131
NNlib = "0.9"
32+
Random = "1"
33+
Statistics = "1"
3234
Tullio = "0.3.7"
3335
Zygote = "0.6"
3436
cuDNN = "1.1.1"
35-
Random = "1"
36-
Statistics = "1"
3737
julia = "1.9"
3838

3939
[extras]

src/graph_network.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ mutable struct GraphNetwork
3131
model
3232
ps
3333
st
34-
e_norm::NormaliserOnline
34+
e_norm::Union{NormaliserOffline, NormaliserOnline}
3535
n_norm::Dict{String, Union{NormaliserOffline, NormaliserOnline}}
36-
o_norm::NormaliserOnline
36+
o_norm::Dict{String, Union{NormaliserOffline, NormaliserOnline}}
3737
end
3838

3939
function build_mlp(input_size::T, latent_size::T, output_size::T, hidden_layers::T; layer_norm=true, dev=cpu) where T <: Integer
@@ -198,22 +198,20 @@ Loads the [`GraphNetwork`](@ref) from the latest checkpoint at the given path.
198198
- `df_train`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame containing the train losses at the checkpoints.
199199
- `df_valid`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame containing the validation losses at the checkpoints (only improvements are saved).
200200
"""
201-
function load(quantities, dims, norms, output, message_steps, ls, hl, opt, device::Function, path::String)
201+
function load(quantities, dims, e_norms, n_norms, o_norms, output, message_steps, ls, hl, opt, device::Function, path::String)
202202
if isfile(joinpath(path, "checkpoints"))
203203
step = parse(Int, readlines(joinpath(path, "checkpoints"))[end])
204204
ps_data, ps_axes, st, e_norm, n_norm, o_norm, opt_state, df_train, df_valid = load(joinpath(path, "checkpoint_$step.jld2"), "ps_data", "ps_axes", "st", "e_norm", "n_norm", "o_norm", "opt_state", "df_train", "df_valid")
205205

206206
ps = ComponentArray(ps_data, ps_axes) |> device
207207
st = st |> device
208208

209-
en = NormaliserOnline(e_norm, device)
210-
for (k, n) in n_norm
211-
norms[k] = NormaliserOnline(n, device)
212-
end
213-
on = NormaliserOnline(o_norm, device)
209+
en = deserialize(e_norm, device)
210+
nn = deserialize(n_norm, device)
211+
on = deserialize(o_norm, device)
214212

215213
model = build_model(quantities, dims, output, message_steps, ls, hl, device)
216-
gn = GraphNetwork(model, ps, st, en, norms, on)
214+
gn = GraphNetwork(model, ps, st, en, nn, on)
217215

218216
if !isnothing(opt)
219217
return gn, nothing, df_train, df_valid
@@ -227,7 +225,7 @@ function load(quantities, dims, norms, output, message_steps, ls, hl, opt, devic
227225
ps = ComponentArray(ps) |> device
228226
st = st |> device
229227

230-
gn = GraphNetwork(model, ps, st, NormaliserOnline(dims + 1, device), norms, NormaliserOnline(output, device))
228+
gn = GraphNetwork(model, ps, st, e_norms, n_norms, o_norms)
231229

232230
return gn, nothing, DataFrame(step=Integer[], loss=Float32[]), DataFrame(step=Integer[], loss=Float32[])
233231
end

src/normaliser.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ function NormaliserOffline(data_min::Float32, data_max::Float32)
2626
NormaliserOffline(data_min, data_max, 0.0f0, 1.0f0)
2727
end
2828

29-
(n::NormaliserOffline)(F, is_training = false) = minmaxnorm(F, n.data_min, n.data_max, n.target_min, n.target_max)
29+
function NormaliserOffline(d::Dict{String, Any})
30+
NormaliserOffline(d["data_min"], d["data_max"], d["target_min"], d["target_max"])
31+
end
32+
33+
(n::NormaliserOffline)(F) = minmaxnorm(F, n.data_min, n.data_max, n.target_min, n.target_max)
3034

3135
"""
3236
inverse_data(n::NormaliserOffline, data)
@@ -153,9 +157,7 @@ end
153157
function serialize(ns::Dict{String, Union{NormaliserOffline, NormaliserOnline}})
154158
result = Dict{String, Any}()
155159
for (k, n) in ns
156-
if typeof(n) == NormaliserOnline
157-
result[k] = serialize(n)
158-
end
160+
result[k] = serialize(n)
159161
end
160162
return result
161163
end
@@ -170,3 +172,24 @@ function serialize(n::NormaliserOnline)
170172
"acc_sum_squared" => cpu_device()(n.acc_sum_squared)
171173
)
172174
end
175+
176+
function serialize(n::NormaliserOffline)
177+
return Dict{String, Any}(
178+
"data_min" => n.data_min,
179+
"data_max" => n.data_max,
180+
"target_min" => n.target_min,
181+
"target_max" => n.target_max
182+
)
183+
end
184+
185+
function deserialize(n::Dict{String, Any}, device::Function)
186+
if haskey(n, "max_accumulations")
187+
return NormaliserOnline(n, device)
188+
elseif haskey(n, "data_min")
189+
return NormaliserOffline(n)
190+
else
191+
features = keys(n)
192+
norms = deserialize.(values(n), device)
193+
return Dict{String, Union{NormaliserOffline, NormaliserOnline}}(features .=> norms)
194+
end
195+
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ Calculates the mean squared error of the given arguments with [Tullio](https://g
104104
- The calculated mean squared error.
105105
"""
106106
mse_reduce(target, output) = begin
107-
@assert ndims(target) == 2 && ndims(output) == 2 "Only supported dimension is 2: dims = (target => $(dims(target)), output => $(dims(output))"
107+
@assert ndims(target) == 2 && ndims(output) == 2 "Only supported number of dimensions is 2: dims = (target => $(ndims(target)), output => $(ndims(output)))"
108108
@tullio R[x] := (target[y, x] - output[y, x]) ^ 2
109109
end
110110

0 commit comments

Comments
 (0)