Skip to content

Commit 66914f1

Browse files
vchuravylchristmbenegee
committed
Use Adapt.jl to change storage and element type
In order to eventually support GPU computation, we need to use Adapt.jl to allow GPU backend packages to swap out host-array types like `CuArray` with device-side types like `CuDeviceArray`. Additionally, this will allow us to change the element type of a simulation by using `adapt(Array{Float32}, ...)`. Co-authored-by: Lars Christmann <[email protected]> Co-authored-by: Benedict Geihe <[email protected]>
1 parent 6349413 commit 66914f1

File tree

9 files changed

+521
-132
lines changed

9 files changed

+521
-132
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Michael Schlottke-Lakemper <[email protected]>", "Gregor
44
version = "0.9.15-DEV"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
89
CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
910
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
@@ -64,6 +65,7 @@ TrixiMakieExt = "Makie"
6465
TrixiNLsolveExt = "NLsolve"
6566

6667
[compat]
68+
Adapt = "3.7, 4.0"
6769
Accessors = "0.1.12"
6870
CodeTracking = "1.0.5"
6971
ConstructionBase = "1.3"

src/Trixi.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import SciMLBase: get_du, get_tmp_cache, u_modified!,
4343

4444
using DelimitedFiles: readdlm
4545
using Downloads: Downloads
46+
import Adapt
4647
using CodeTracking: CodeTracking
4748
using ConstructionBase: ConstructionBase
4849
using DiffEqCallbacks: PeriodicCallback, PeriodicCallbackAffect

src/auxiliary/vector_of_arrays.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2+
# Since these FMAs can increase the performance of many numerical algorithms,
3+
# we need to opt-in explicitly.
4+
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5+
@muladd begin
6+
#! format: noindent
7+
8+
# Wraps a Vector of Arrays and forwards `getindex`, `IndexStyle`, `size`, and
# `length` to the underlying Vector.
# Implements `Adapt.adapt_structure` to allow offloading to the GPU which is
# not possible for a plain Vector of Arrays.
struct VecOfArrays{T <: AbstractArray}
    arrays::Vector{T}
end
Base.getindex(v::VecOfArrays, i::Int) = Base.getindex(v.arrays, i)
Base.IndexStyle(v::VecOfArrays) = Base.IndexStyle(v.arrays)
Base.size(v::VecOfArrays) = Base.size(v.arrays)
Base.length(v::VecOfArrays) = Base.length(v.arrays)
# `eltype` is defined on the *type*, following the Base convention. The generic
# fallback `eltype(x) = eltype(typeof(x))` then covers instances as well, so both
# `eltype(v)` and `eltype(typeof(v))` return the type of the wrapped arrays.
Base.eltype(::Type{VecOfArrays{T}}) where {T} = T
19+
# Adapt each wrapped array individually via `Adapt.adapt` and wrap the adapted
# arrays in a new `VecOfArrays`.
function Adapt.adapt_structure(to, v::VecOfArrays)
    adapted_arrays = [Adapt.adapt(to, array) for array in v.arrays]
    return VecOfArrays(adapted_arrays)
end
22+
end # @muladd

src/semidiscretization/semidiscretization_hyperbolic.jl

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,19 @@ mutable struct SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
3030

3131
function SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
3232
BoundaryConditions, SourceTerms, Solver,
33-
Cache}(mesh::Mesh, equations::Equations,
33+
Cache}(mesh::Mesh,
34+
equations::Equations,
3435
initial_condition::InitialCondition,
3536
boundary_conditions::BoundaryConditions,
3637
source_terms::SourceTerms,
3738
solver::Solver,
38-
cache::Cache) where {Mesh, Equations,
39+
cache::Cache,
40+
performance_counter::PerformanceCounter) where {Mesh, Equations,
3941
InitialCondition,
4042
BoundaryConditions,
4143
SourceTerms,
4244
Solver,
4345
Cache}
44-
performance_counter = PerformanceCounter()
45-
4646
new(mesh, equations, initial_condition, boundary_conditions, source_terms,
4747
solver, cache, performance_counter)
4848
end
@@ -74,14 +74,16 @@ function SemidiscretizationHyperbolic(mesh, equations, initial_condition, solver
7474

7575
check_periodicity_mesh_boundary_conditions(mesh, _boundary_conditions)
7676

77+
performance_counter = PerformanceCounter()
78+
7779
SemidiscretizationHyperbolic{typeof(mesh), typeof(equations),
7880
typeof(initial_condition),
7981
typeof(_boundary_conditions), typeof(source_terms),
8082
typeof(solver), typeof(cache)}(mesh, equations,
8183
initial_condition,
8284
_boundary_conditions,
8385
source_terms, solver,
84-
cache)
86+
cache, performance_counter)
8587
end
8688

8789
# Create a new semidiscretization but change some parameters compared to the input.
@@ -103,6 +105,30 @@ function remake(semi::SemidiscretizationHyperbolic; uEltype = real(semi.solver),
103105
source_terms, boundary_conditions, uEltype)
104106
end
105107

108+
# Create a new `SemidiscretizationHyperbolic` whose equations, initial condition,
# boundary conditions, source terms, solver, and cache have been passed through
# `Adapt.adapt_structure` for the target `to`. The mesh and the performance
# counter are taken over from `semi` as they are.
# Raises an error if the mesh is not a `P4estMesh`.
function Adapt.adapt_structure(to, semi::SemidiscretizationHyperbolic)
    semi.mesh isa P4estMesh ||
        error("Adapt.adapt is only supported for semidiscretizations based on P4estMesh")

    # Components that are reused without adaptation
    mesh = semi.mesh
    performance_counter = semi.performance_counter

    # Components that are adapted to the target
    equations = Adapt.adapt_structure(to, semi.equations)
    initial_condition = Adapt.adapt_structure(to, semi.initial_condition)
    boundary_conditions = Adapt.adapt_structure(to, semi.boundary_conditions)
    source_terms = Adapt.adapt_structure(to, semi.source_terms)
    solver = Adapt.adapt_structure(to, semi.solver)
    cache = Adapt.adapt_structure(to, semi.cache)

    # The type parameters have to be recomputed since adapting may change the
    # types of the individual components.
    SemiType = SemidiscretizationHyperbolic{typeof(mesh), typeof(equations),
                                            typeof(initial_condition),
                                            typeof(boundary_conditions),
                                            typeof(source_terms),
                                            typeof(solver), typeof(cache)}
    return SemiType(mesh, equations, initial_condition, boundary_conditions,
                    source_terms, solver, cache, performance_counter)
end
131+
106132
# general fallback
107133
function digest_boundary_conditions(boundary_conditions, mesh, solver, cache)
108134
boundary_conditions

src/solvers/dgsem/basis_lobatto_legendre.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,31 @@ In particular, not the nodes themselves are returned.
122122

123123
@inline get_nodes(basis::LobattoLegendreBasis) = basis.nodes
124124

125+
# Create a new `LobattoLegendreBasis` with its operator matrices passed through
# `Adapt.adapt_structure` for the target `to`.
function Adapt.adapt_structure(to, basis::LobattoLegendreBasis)
    # Do not adapt SVector fields, i.e. nodes, weights and inverse_weights
    nodes = basis.nodes
    weights = basis.weights
    inverse_weights = basis.inverse_weights
    # `boundary_interpolation` is likewise passed on without adaptation
    boundary_interpolation = basis.boundary_interpolation

    inverse_vandermonde_legendre = Adapt.adapt_structure(to,
                                                         basis.inverse_vandermonde_legendre)
    derivative_matrix = Adapt.adapt_structure(to, basis.derivative_matrix)
    derivative_split = Adapt.adapt_structure(to, basis.derivative_split)
    derivative_split_transpose = Adapt.adapt_structure(to,
                                                       basis.derivative_split_transpose)
    derivative_dhat = Adapt.adapt_structure(to, basis.derivative_dhat)

    # The type parameters are rebuilt from the (possibly adapted) fields
    BasisType = LobattoLegendreBasis{real(basis), nnodes(basis), typeof(nodes),
                                     typeof(inverse_vandermonde_legendre),
                                     typeof(boundary_interpolation),
                                     typeof(derivative_matrix)}
    return BasisType(nodes, weights, inverse_weights,
                     inverse_vandermonde_legendre, boundary_interpolation,
                     derivative_matrix, derivative_split,
                     derivative_split_transpose, derivative_dhat)
end
149+
125150
"""
126151
integrate(f, u, basis::LobattoLegendreBasis)
127152
@@ -209,6 +234,16 @@ end
209234

210235
@inline polydeg(mortar::LobattoLegendreMortarL2) = nnodes(mortar) - 1
211236

237+
# Create a new `LobattoLegendreMortarL2` whose four operator fields have been
# passed through `Adapt.adapt_structure` for the target `to`.
function Adapt.adapt_structure(to, mortar::LobattoLegendreMortarL2)
    forward_upper = Adapt.adapt_structure(to, mortar.forward_upper)
    forward_lower = Adapt.adapt_structure(to, mortar.forward_lower)
    reverse_upper = Adapt.adapt_structure(to, mortar.reverse_upper)
    reverse_lower = Adapt.adapt_structure(to, mortar.reverse_lower)

    # The type parameters are derived from the adapted upper operators
    MortarType = LobattoLegendreMortarL2{real(mortar), nnodes(mortar),
                                         typeof(forward_upper),
                                         typeof(reverse_upper)}
    return MortarType(forward_upper, forward_lower, reverse_upper, reverse_lower)
end
246+
212247
# TODO: We can create EC mortars along the lines of the following implementation.
213248
# abstract type AbstractMortarEC{RealT} <: AbstractMortar{RealT} end
214249

0 commit comments

Comments
 (0)