Skip to content

Commit 0de3051

Browse files
vchuravylchristmbenegee
committed
Use Adapt.jl to change storage and element type
In order to eventually support GPU computation, we need to use Adapt.jl to allow GPU backend packages to swap out host-array types like `CuArray` with device-side types like `CuDeviceArray`. Additionally, this will allow us to change the element type of a simulation by using `adapt(Array{Float32}, x)`. Co-authored-by: Lars Christmann <[email protected]> Co-authored-by: Benedict Geihe <[email protected]>
1 parent aad2399 commit 0de3051

File tree

14 files changed

+567
-138
lines changed

14 files changed

+567
-138
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Michael Schlottke-Lakemper <[email protected]>", "
44
version = "0.11.12-DEV"
55

66
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
78
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
89
CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
910
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
@@ -63,6 +64,7 @@ TrixiMakieExt = "Makie"
6364
TrixiNLsolveExt = "NLsolve"
6465

6566
[compat]
67+
Adapt = "4"
6668
Accessors = "0.1.36"
6769
CodeTracking = "1.0.5"
6870
ConstructionBase = "1.5"

src/Trixi.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import SciMLBase: get_du, get_tmp_cache, u_modified!,
4444

4545
using DelimitedFiles: readdlm
4646
using Downloads: Downloads
47+
using Adapt: Adapt, adapt
4748
using CodeTracking: CodeTracking
4849
using ConstructionBase: ConstructionBase
4950
using DiffEqBase: DiffEqBase, get_tstops, get_tstops_array
@@ -125,6 +126,7 @@ include("basic_types.jl")
125126

126127
# Include all top-level source files
127128
include("auxiliary/auxiliary.jl")
129+
include("auxiliary/vector_of_arrays.jl")
128130
include("auxiliary/mpi.jl")
129131
include("auxiliary/p4est.jl")
130132
include("auxiliary/t8code.jl")

src/auxiliary/containers.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,4 +314,88 @@ end
314314
# Three-argument convenience method: forwards to the five-argument `raw_copy!`,
# passing `c` for both container arguments and `from` for both of the index
# arguments that precede `destination`.
function raw_copy!(c::AbstractContainer, from::Int, destination::Int)
    return raw_copy!(c, c, from, from, destination)
end
317+
318+
# Interface requirement: every Trixi storage type has to provide its own methods
# for `Adapt.adapt_structure` and `Adapt.parent_type`.
# This generic fallback only raises an error that names the missing method.
function Adapt.adapt_structure(to, c::AbstractContainer)
    msg = "Interface: Must implement Adapt.adapt_structure(to, ::$(typeof(c)))"
    error(msg)
end
322+
323+
# Fallback for `Adapt.parent_type` on container types. Concrete container types
# are expected to provide their own method; reaching this one raises an
# `ErrorException` whose message spells out the signature that is missing.
function Adapt.parent_type(C::Type{<:AbstractContainer})
    # The suggested signature in the message is written with balanced parentheses
    # so that it can be copied verbatim.
    error("Interface: Must implement Adapt.parent_type(::Type{$C})")
end
326+
327+
# Unwrap a container type by first looking up its parent type and then
# delegating to `Adapt.unwrap_type` for that parent type.
function Adapt.unwrap_type(C::Type{<:AbstractContainer})
    ParentT = Adapt.parent_type(C)
    return Adapt.unwrap_type(ParentT)
end
330+
331+
# TODO: Upstream to Adapt
# Instance method of `storage_type`: the query is answered purely based on the
# type of `x`, i.e., it is forwarded to the `storage_type(::Type)` methods.
function storage_type(x)
    T = typeof(x)
    return storage_type(T)
end
335+
336+
# Generic fallback of `storage_type` for types without a specialized method.
# Always throws an `ErrorException` whose message names the method that needs to
# be implemented for `T`.
function storage_type(T::Type)
    # The suggested signature in the message is written with balanced parentheses
    # so that it can be copied verbatim.
    error("Interface: Must implement storage_type(::Type{$T})")
end
339+
340+
# Every `Array` type, regardless of element type and dimension, reports the
# unparameterized `Array` as its storage type.
storage_type(::Type{<:Array}) = Array
343+
344+
# Containers report the storage type of the type obtained by unwrapping them
# with `Adapt.unwrap_type`.
function storage_type(C::Type{<:AbstractContainer})
    Unwrapped = Adapt.unwrap_type(C)
    return storage_type(Unwrapped)
end
347+
348+
# Some storage backends such as CUDA.jl appear to represent empty arrays simply
# by null pointers, which can make `unsafe_wrap` fail when calling `Adapt.adapt`
# (ArgumentError, see
# https://github.com/JuliaGPU/CUDA.jl/blob/v5.4.2/src/array.jl#L212-L229).
# As a workaround, a separate empty array is allocated via `similar` whenever
# `vector` has length zero; only non-empty vectors are actually wrapped.
# Since zero-length arrays are not used in calculations, it should be okay if
# the underlying storage vector and the wrapped array are not the same object,
# as long as they are properly wrapped again when `resize!`d etc.
function unsafe_wrap_or_alloc(to, vector, size)
    # Guard clause: nothing to wrap for an empty storage vector
    iszero(length(vector)) && return similar(vector, size)

    return unsafe_wrap(to, pointer(vector), size)
end
364+
365+
# Field-less adaptor type handed to Adapt.jl. Its two type parameters carry the
# requested storage type and the requested real type.
struct TrixiAdaptor{Storage, Real}
end

# Adapt `x` with a `TrixiAdaptor` parameterized by `storage` and `real`.
function trixi_adapt(storage, real, x)
    adaptor = TrixiAdaptor{storage, real}()
    return adapt(adaptor, x)
end
370+
371+
# Custom rules
# 1. Handling of StaticArrays: build the static array type that
#    `StaticArrays.similar_type` returns for element type `Real` and convert `x`
#    to it. The storage parameter of the adaptor is ignored here.
function Adapt.adapt_storage(::TrixiAdaptor{<:Any, Real},
                             x::StaticArrays.StaticArray{S, T, N}) where {Real, S, T, N}
    ConvertedT = StaticArrays.similar_type(x, Real)
    return ConvertedT(x)
end
377+
378+
# 2. Handling of Arrays
# Arrays of floating point numbers are adapted to `Storage{Real}`, i.e., the
# requested storage type parameterized by the requested real type.
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real},
                             x::AbstractArray{T}) where {Storage, Real,
                                                         T <: AbstractFloat}
    TargetT = Storage{Real}
    return adapt(TargetT, x)
end
384+
385+
# Arrays whose elements are static arrays: the new element type is the one that
# `StaticArrays.similar_type` returns for `T` and `Real`; the array itself is
# adapted to `Storage` parameterized by that element type.
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real},
                             x::AbstractArray{T}) where {Storage, Real,
                                                         T <: StaticArrays.StaticArray}
    ElementT = StaticArrays.similar_type(T, Real)
    return adapt(Storage{ElementT}, x)
end
390+
391+
# All remaining arrays are adapted to the plain `Storage` type; the `Real`
# parameter of the adaptor is not used by this method.
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real},
                             x::AbstractArray) where {Storage, Real}
    return adapt(Storage, x)
end
395+
396+
# 3. TODO: Should we have a fallback? But that would imply implementing things for NamedTuple again
397+
398+
# Convenience method for adaptors: extracts the storage type from the first type
# parameter of the `TrixiAdaptor` and forwards to `unsafe_wrap_or_alloc` with
# that storage type.
function unsafe_wrap_or_alloc(::TrixiAdaptor{Storage}, vec, size) where {Storage}
    target = Storage
    return unsafe_wrap_or_alloc(target, vec, size)
end
317401
end # @muladd

src/auxiliary/vector_of_arrays.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2+
# Since these FMAs can increase the performance of many numerical algorithms,
3+
# we need to opt-in explicitly.
4+
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5+
@muladd begin
6+
#! format: noindent
7+
8+
# Wraps a Vector of Arrays and forwards `getindex`, `IndexStyle`, `size`, and
# `length` to the underlying Vector.
# Implements `Adapt.adapt_structure` to allow offloading to the GPU which is
# not possible for a plain Vector of Arrays.
struct VecOfArrays{T <: AbstractArray}
    arrays::Vector{T}
end
Base.getindex(v::VecOfArrays, i::Int) = Base.getindex(v.arrays, i)
Base.IndexStyle(v::VecOfArrays) = Base.IndexStyle(v.arrays)
Base.size(v::VecOfArrays) = Base.size(v.arrays)
Base.length(v::VecOfArrays) = Base.length(v.arrays)
# `eltype` is defined for the type rather than for instances. Instances are
# covered by Base's generic fallback `eltype(x) = eltype(typeof(x))`, so both
# `eltype(v)` and `eltype(typeof(v))` return `T`.
Base.eltype(::Type{<:VecOfArrays{T}}) where {T} = T
19+
# Adapt each of the wrapped arrays individually with `to` and wrap the adapted
# arrays in a new `VecOfArrays`.
function Adapt.adapt_structure(to, v::VecOfArrays)
    adapted_arrays = map(arr -> Adapt.adapt(to, arr), v.arrays)
    return VecOfArrays(adapted_arrays)
end
22+
# The parent type of a `VecOfArrays{T}` is the type `T` of the wrapped arrays.
Adapt.parent_type(::Type{<:VecOfArrays{T}}) where {T} = T
25+
# Unwrap a `VecOfArrays` type by delegating to `Adapt.unwrap_type` for its
# parent type.
function Adapt.unwrap_type(A::Type{<:VecOfArrays})
    ParentT = Adapt.parent_type(A)
    return Adapt.unwrap_type(ParentT)
end
28+
# Conversion of a plain `Vector` of arrays: the vector is wrapped in a
# `VecOfArrays` by calling its constructor.
function Base.convert(::Type{<:VecOfArrays}, v::Vector{<:AbstractArray})
    wrapped = VecOfArrays(v)
    return wrapped
end
31+
end # @muladd

src/semidiscretization/semidiscretization_hyperbolic.jl

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,6 @@ mutable struct SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
2727
solver::Solver
2828
cache::Cache
2929
performance_counter::PerformanceCounter
30-
31-
function SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
32-
BoundaryConditions, SourceTerms, Solver,
33-
Cache}(mesh::Mesh, equations::Equations,
34-
initial_condition::InitialCondition,
35-
boundary_conditions::BoundaryConditions,
36-
source_terms::SourceTerms,
37-
solver::Solver,
38-
cache::Cache) where {Mesh, Equations,
39-
InitialCondition,
40-
BoundaryConditions,
41-
SourceTerms,
42-
Solver,
43-
Cache}
44-
performance_counter = PerformanceCounter()
45-
46-
new(mesh, equations, initial_condition, boundary_conditions, source_terms,
47-
solver, cache, performance_counter)
48-
end
4930
end
5031

5132
"""
@@ -74,16 +55,22 @@ function SemidiscretizationHyperbolic(mesh, equations, initial_condition, solver
7455

7556
check_periodicity_mesh_boundary_conditions(mesh, _boundary_conditions)
7657

58+
performance_counter = PerformanceCounter()
59+
7760
SemidiscretizationHyperbolic{typeof(mesh), typeof(equations),
7861
typeof(initial_condition),
7962
typeof(_boundary_conditions), typeof(source_terms),
8063
typeof(solver), typeof(cache)}(mesh, equations,
8164
initial_condition,
8265
_boundary_conditions,
8366
source_terms, solver,
84-
cache)
67+
cache,
68+
performance_counter)
8569
end
8670

71+
# Apply Adapt's `@adapt_structure` macro to `SemidiscretizationHyperbolic`.
# NOTE(review): presumably this generates an `Adapt.adapt_structure` method that
# adapts every field and calls the constructor with all fields - confirm against
# the Adapt.jl documentation.
# @eval due to @muladd
@eval Adapt.@adapt_structure(SemidiscretizationHyperbolic)
73+
8774
# Create a new semidiscretization but change some parameters compared to the input.
8875
# `Base.similar` follows a related concept but would require us to `copy` the `mesh`,
8976
# which would impact the performance. Instead, `SciMLBase.remake` has exactly the

src/solvers/dg.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,9 @@ struct DG{Basis, Mortar, SurfaceIntegral, VolumeIntegral}
400400
volume_integral::VolumeIntegral
401401
end
402402

403+
# Apply Adapt's `@adapt_structure` macro to `DG`.
# NOTE(review): presumably this generates an `Adapt.adapt_structure` method that
# adapts every field of `DG` - confirm against the Adapt.jl documentation.
# @eval due to @muladd
@eval Adapt.@adapt_structure(DG)
405+
403406
function Base.show(io::IO, dg::DG)
404407
@nospecialize dg # reduce precompilation time
405408

src/solvers/dgsem/basis_lobatto_legendre.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@ struct LobattoLegendreBasis{RealT <: Real, NNODES,
3434
# negative adjoint wrt the SBP dot product
3535
end
3636

37+
# Adapt a `LobattoLegendreBasis` with `to` and rebuild it from the adapted parts.
# The real type of the new basis is the element type of the adapted
# `inverse_vandermonde_legendre`. `nodes`, `weights`, and `inverse_weights` are
# converted to `SVector`s with that element type, all other fields are passed
# through `adapt`.
function Adapt.adapt_structure(to, basis::LobattoLegendreBasis)
    # Adapted first since its element type determines `RealT`
    inverse_vandermonde_legendre = adapt(to, basis.inverse_vandermonde_legendre)
    RealT = eltype(inverse_vandermonde_legendre)

    # `SVector` type with element type `RealT` and unspecified length
    StaticVectorT = SVector{<:Any, RealT}
    nodes = StaticVectorT(basis.nodes)
    weights = StaticVectorT(basis.weights)
    inverse_weights = StaticVectorT(basis.inverse_weights)

    # Remaining fields are handed to `adapt`
    boundary_interpolation = adapt(to, basis.boundary_interpolation)
    derivative_matrix = adapt(to, basis.derivative_matrix)
    derivative_split = adapt(to, basis.derivative_split)
    derivative_split_transpose = adapt(to, basis.derivative_split_transpose)
    derivative_dhat = adapt(to, basis.derivative_dhat)

    # Fully parameterized type of the adapted basis
    BasisT = LobattoLegendreBasis{RealT, nnodes(basis), typeof(nodes),
                                  typeof(inverse_vandermonde_legendre),
                                  typeof(boundary_interpolation),
                                  typeof(derivative_matrix)}
    return BasisT(nodes, weights, inverse_weights,
                  inverse_vandermonde_legendre,
                  boundary_interpolation,
                  derivative_matrix, derivative_split,
                  derivative_split_transpose, derivative_dhat)
end
62+
3763
function LobattoLegendreBasis(RealT, polydeg::Integer)
3864
nnodes_ = polydeg + 1
3965

@@ -155,6 +181,17 @@ struct LobattoLegendreMortarL2{RealT <: Real, NNODES,
155181
reverse_lower::ReverseMatrix
156182
end
157183

184+
# Adapt a `LobattoLegendreMortarL2` with `to`: all four mortar operators are
# passed through `adapt` and a new mortar is built from them. The real type of
# the new mortar is the element type of the adapted `forward_upper`, the number
# of nodes is taken over from `mortar`.
function Adapt.adapt_structure(to, mortar::LobattoLegendreMortarL2)
    operators = (mortar.forward_upper, mortar.forward_lower,
                 mortar.reverse_upper, mortar.reverse_lower)
    forward_upper, forward_lower, reverse_upper, reverse_lower = map(op -> adapt(to, op),
                                                                     operators)

    MortarT = LobattoLegendreMortarL2{eltype(forward_upper), nnodes(mortar),
                                      typeof(forward_upper),
                                      typeof(reverse_upper)}
    return MortarT(forward_upper, forward_lower, reverse_upper, reverse_lower)
end
194+
158195
function MortarL2(basis::LobattoLegendreBasis)
159196
RealT = real(basis)
160197
nnodes_ = nnodes(basis)

0 commit comments

Comments
 (0)