src/graph_network.jl (15 additions, 15 deletions)

@@ -19,7 +19,7 @@ include("graph_net_blocks.jl")
 The central data structure that contains the neural network and the normalisers corresponding to the components of the GNN (edge features, node features and output).
 
 # Arguments
-- `model`: The Enocde-Process-Decode model as a [Lux.Chain](@ref).
+- `model`: The Encode-Process-Decode model as a [Lux](https://github.com/LuxDL/Lux.jl) Chain.
 - `ps`: Parameters of the model.
 - `st`: State of the model.
 - `e_norm`: Normaliser for the edge features of the GNN.
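The `model`, `ps` and `st` fields above follow Lux's convention of keeping parameters and state separate from the model definition. A minimal sketch of that pattern, assuming an illustrative architecture (the layer sizes here are not GraphNetCore's actual Encode-Process-Decode model):

```julia
using Lux, Random

# Illustrative only: the layer sizes are assumptions, not the actual
# Encode-Process-Decode model from GraphNetCore.
rng = Random.default_rng()
model = Chain(Dense(3 => 16, relu), Dense(16 => 16, relu), Dense(16 => 2))

# Lux keeps parameters (`ps`) and state (`st`) outside the model, which is
# why the struct stores them as separate fields alongside the Chain.
ps, st = Lux.setup(rng, model)

x = rand(rng, Float32, 3, 10)
y, st = model(x, ps, st)
size(y)  # (2, 10)
```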
@@ -127,10 +127,10 @@
-Creates a checkpoint of the [GraphNetCore.GraphNetwork](@ref) at the given training step.
+Creates a checkpoint of the [`GraphNetwork`](@ref) at the given training step.
 
 # Arguments
-- `gn`: The [GraphNetCore.GraphNetwork](@ref) that a checkpoint is created of.
+- `gn`: The [`GraphNetwork`](@ref) that a checkpoint is created of.
 - `opt_state`: State of the optimiser.
-- `df_train`: [DataFrames.DataFram](@ref) that stores the train losses at the checkpoints.
-- `df_valid`: [DataFrames.DataFram](@ref) that stores the validation losses at the checkpoints (only improvements are saved).
+- `df_train`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame that stores the train losses at the checkpoints.
+- `df_valid`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame that stores the validation losses at the checkpoints (only improvements are saved).
 - `step`: Current training step where the checkpoint is created.
 - `train_loss`: Current training loss.
 - `path`: Path to the folder where checkpoints are saved.
@@ -181,4 +181,4 @@
-Loads the [GraphNetCore.GraphNetwork](@ref) from the latest checkpoint at the given path.
+Loads the [`GraphNetwork`](@ref) from the latest checkpoint at the given path.
 
 # Arguments
 - `quantities`: Sum of dimensions of each node feature.
@@ -189,14 +189,14 @@ Loads the [GraphNetCore.GraphNetwork](@ref) from the latest checkpoint at the gi
 - `ls`: Size of hidden layers.
 - `hl`: Number of hidden layers.
 - `opt`: Optimiser that is used for training. Set this to `nothing` if you want to use the optimiser from the checkpoint.
-- `device`: Device where the model should be loaded (see [Lux.gpu_device()](@ref) and [Lux.cpu_device()](@ref)).
+- `device`: Device where the model should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).
 - `path`: Path to the folder where the checkpoint is.
 
 # Returns
-- `gn`: The loaded [GraphNetCore.GraphNetwork](@ref) from the checkpoint.
+- `gn`: The loaded [`GraphNetwork`](@ref) from the checkpoint.
 - `opt_state`: The loaded optimiser state. Is nothing if no checkpoint was found or an optimiser was passed as an argument.
-- `df_train`: [DataFrames.DataFram](@ref) containing the train losses at the checkpoints.
-- `df_valid`: [DataFrames.DataFram](@ref) containing the validation losses at the checkpoints (only improvements are saved).
+- `df_train`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame containing the train losses at the checkpoints.
+- `df_valid`: [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) DataFrame containing the validation losses at the checkpoints (only improvements are saved).
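The `df_train`/`df_valid` bookkeeping these docstrings describe can be sketched with plain DataFrames.jl. The column names and update logic below are illustrative assumptions, not GraphNetCore's actual implementation:

```julia
using DataFrames

# Hypothetical column layout; GraphNetCore's checkpoints may differ.
df_train = DataFrame(step = Int[], loss = Float64[])
df_valid = DataFrame(step = Int[], loss = Float64[])

# Every checkpoint records the current training loss ...
push!(df_train, (step = 100, loss = 0.42))
push!(df_train, (step = 200, loss = 0.31))

# ... but validation losses are only kept when they improve
# ("only improvements are saved").
let best = Inf
    for (step, loss) in [(100, 0.50), (200, 0.55), (300, 0.44)]
        if loss < best
            best = loss
            push!(df_valid, (step = step, loss = loss))
        end
    end
end

df_valid.step  # [100, 300]
```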
src/normaliser.jl (5 additions, 5 deletions)

@@ -34,7 +34,7 @@ end
 Inverses the normalised data.
 
 # Arguments
-- `n`: The used [GraphNetCore.NormaliserOffline](@ref).
+- `n`: The used [`NormaliserOffline`](@ref).
 - `data`: Data to be converted back.
 
 # Returns
@@ -75,7 +75,7 @@ It is recommended to use offline normalization since the minimum and maximum do
 
 # Arguments
 - `dims`: Dimension of the quantity to normalize.
-- `device`: Device where the Normaliser should be loaded (see [Lux.gpu_device()](@ref) and [Lux.cpu_device()](@ref)).
+- `device`: Device where the Normaliser should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).
 
 # Keyword Arguments
 - `max_acc = 10f6`: Maximum number of accumulation steps.
@@ -92,8 +92,8 @@ Online normalization if the minimum and maximum of the quantity is not known.
 It is recommended to use offline normalization since the minimum and maximum do not need to be inferred from data.
 
 # Arguments
-- `d`: Dictionary containing the fields of the struct [GraphNetCore.NormaliserOnline](@ref).
-- `device`: Device where the Normaliser should be loaded (see [Lux.gpu_device()](@ref) and [Lux.cpu_device()](@ref)).
+- `d`: Dictionary containing the fields of the struct [`NormaliserOnline`](@ref).
+- `device`: Device where the Normaliser should be loaded (see [Lux GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management#gpu-management)).