Skip to content

Commit 40b5f58

Browse files
committed
start working on test and use custom adaptor
1 parent a0ba9c1 commit 40b5f58

File tree

10 files changed

+136
-66
lines changed

10 files changed

+136
-66
lines changed

src/Trixi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import SciMLBase: get_du, get_tmp_cache, u_modified!,
4444

4545
using DelimitedFiles: readdlm
4646
using Downloads: Downloads
47-
using Adapt: Adapt
47+
using Adapt: Adapt, adapt
4848
using CodeTracking: CodeTracking
4949
using ConstructionBase: ConstructionBase
5050
using DiffEqBase: DiffEqBase, get_tstops, get_tstops_array

src/auxiliary/containers.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,4 +346,34 @@ function unsafe_wrap_or_alloc(to, vec, size)
346346
return unsafe_wrap(to, pointer(vec), size)
347347
end
348348
end
349+
350+
struct TrixiAdaptor{Storage, Real} end
351+
352+
function trixi_adapt(storage, real, x)
353+
adapt(TrixiAdaptor{storage, real}(), x)
354+
end
355+
356+
# Custom rules
357+
# 1. handling of StaticArrays
358+
function Adapt.adapt_storage(::TrixiAdaptor{<:Any, Real}, x::StaticArrays.StaticArray{S, T, N}) where {Real,S,T,N}
359+
StaticArrays.similar_type(x, Real)(x)
360+
end
361+
362+
# 2. Handling of Arrays
363+
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, x::AbstractArray{T}) where{Storage, Real, T<:AbstractFloat}
364+
adapt(Storage{Real}, x)
365+
end
366+
367+
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, x::AbstractArray{T}) where {Storage, Real,T<:StaticArrays.StaticArray}
368+
adapt(Storage{StaticArrays.similar_type(T, Real)},x)
369+
end
370+
371+
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, x::AbstractArray) where{Storage, Real}
372+
adapt(Storage, x)
373+
end
374+
375+
# 3. TODO: Should we have a fallback? But that would imply implementing things for NamedTuple again
376+
377+
unsafe_wrap_or_alloc(::TrixiAdaptor{Storage}, vec, size) where {Storage} = unsafe_wrap_or_alloc(Storage, vec, size)
378+
349379
end # @muladd

src/semidiscretization/semidiscretization_hyperbolic.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,6 @@ mutable struct SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
2727
solver::Solver
2828
cache::Cache
2929
performance_counter::PerformanceCounter
30-
31-
function SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
32-
BoundaryConditions, SourceTerms, Solver,
33-
Cache}(mesh::Mesh,
34-
equations::Equations,
35-
initial_condition::InitialCondition,
36-
boundary_conditions::BoundaryConditions,
37-
source_terms::SourceTerms,
38-
solver::Solver,
39-
cache::Cache,
40-
performance_counter::PerformanceCounter) where {
41-
Mesh,
42-
Equations,
43-
InitialCondition,
44-
BoundaryConditions,
45-
SourceTerms,
46-
Solver,
47-
Cache
48-
}
49-
new(mesh, equations, initial_condition, boundary_conditions, source_terms,
50-
solver, cache, performance_counter)
51-
end
5230
end
5331

5432
"""
@@ -90,6 +68,31 @@ function SemidiscretizationHyperbolic(mesh, equations, initial_condition, solver
9068
performance_counter)
9169
end
9270

71+
@eval Adapt.@adapt_structure(SemidiscretizationHyperbolic)
72+
# function Adapt.adapt_structure(to, semi::SemidiscretizationHyperbolic)
73+
# if !(typeof(semi.mesh) <: P4estMesh)
74+
# error("Adapt.adapt is only supported for semidiscretizations based on P4estMesh")
75+
# end
76+
77+
# mesh = semi.mesh
78+
# equations = adapt(to, semi.equations)
79+
# initial_condition = adapt(to, semi.initial_condition)
80+
# boundary_conditions = adapt(to, semi.boundary_conditions)
81+
# source_terms = adapt(to, semi.source_terms)
82+
# solver = adapt(to, semi.solver)
83+
# cache = adapt(to, semi.cache)
84+
# performance_counter = semi.performance_counter
85+
86+
# SemidiscretizationHyperbolic{typeof(mesh), typeof(equations),
87+
# typeof(initial_condition),
88+
# typeof(boundary_conditions), typeof(source_terms),
89+
# typeof(solver), typeof(cache)}(mesh, equations,
90+
# initial_condition,
91+
# boundary_conditions,
92+
# source_terms, solver,
93+
# cache, performance_counter)
94+
# end
95+
9396
# Create a new semidiscretization but change some parameters compared to the input.
9497
# `Base.similar` follows a related concept but would require us to `copy` the `mesh`,
9598
# which would impact the performance. Instead, `SciMLBase.remake` has exactly the
@@ -109,9 +112,6 @@ function remake(semi::SemidiscretizationHyperbolic; uEltype = real(semi.solver),
109112
source_terms, boundary_conditions, uEltype)
110113
end
111114

112-
# @eval due to @muladd
113-
@eval Adapt.@adapt_structure(SemidiscretizationHyperbolic)
114-
115115
# general fallback
116116
function digest_boundary_conditions(boundary_conditions, mesh, solver, cache)
117117
boundary_conditions

src/solvers/dg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,8 @@ struct DG{Basis, Mortar, SurfaceIntegral, VolumeIntegral}
400400
volume_integral::VolumeIntegral
401401
end
402402

403+
@eval Adapt.@adapt_structure(DG)
404+
403405
function Base.show(io::IO, dg::DG)
404406
@nospecialize dg # reduce precompilation time
405407

src/solvers/dgsem/basis_lobatto_legendre.jl

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@ struct LobattoLegendreBasis{RealT <: Real, NNODES,
3434
# negative adjoint wrt the SBP dot product
3535
end
3636

37+
function Adapt.adapt_structure(to, basis::LobattoLegendreBasis)
38+
inverse_vandermonde_legendre = adapt(to, basis.inverse_vandermonde_legendre)
39+
RealT = eltype(inverse_vandermonde_legendre)
40+
41+
nodes = SVector{<:Any, RealT}(basis.nodes)
42+
weights = SVector{<:Any, RealT}(basis.weights)
43+
inverse_weights = SVector{<:Any, RealT}(basis.inverse_weights)
44+
boundary_interpolation = adapt(to, basis.boundary_interpolation)
45+
derivative_matrix = adapt(to, basis.derivative_matrix)
46+
derivative_split = adapt(to, basis.derivative_split)
47+
derivative_split_transpose = adapt(to,basis.derivative_split_transpose)
48+
derivative_dhat = adapt(to, basis.derivative_dhat)
49+
return LobattoLegendreBasis{RealT, nnodes(basis), typeof(nodes),
50+
typeof(inverse_vandermonde_legendre),
51+
typeof(boundary_interpolation),
52+
typeof(derivative_matrix)}(nodes,
53+
weights,
54+
inverse_weights,
55+
inverse_vandermonde_legendre,
56+
boundary_interpolation,
57+
derivative_matrix,
58+
derivative_split,
59+
derivative_split_transpose,
60+
derivative_dhat)
61+
end
62+
3763
function LobattoLegendreBasis(RealT, polydeg::Integer)
3864
nnodes_ = polydeg + 1
3965

@@ -122,9 +148,6 @@ In particular, not the nodes themselves are returned.
122148

123149
@inline get_nodes(basis::LobattoLegendreBasis) = basis.nodes
124150

125-
# @eval due to @muladd
126-
@eval Adapt.@adapt_structure(LobattoLegendreBasis)
127-
128151
"""
129152
integrate(f, u, basis::LobattoLegendreBasis)
130153
@@ -158,6 +181,16 @@ struct LobattoLegendreMortarL2{RealT <: Real, NNODES,
158181
reverse_lower::ReverseMatrix
159182
end
160183

184+
function Adapt.adapt_structure(to, mortar::LobattoLegendreMortarL2)
185+
forward_upper = adapt(to, mortar.forward_upper)
186+
forward_lower = adapt(to, mortar.forward_lower)
187+
reverse_upper = adapt(to, mortar.reverse_upper)
188+
reverse_lower = adapt(to, mortar.reverse_lower)
189+
return LobattoLegendreMortarL2{eltype(forward_upper), nnodes(mortar), typeof(forward_upper),
190+
typeof(reverse_upper)}(forward_upper, forward_lower,
191+
reverse_upper, reverse_lower)
192+
end
193+
161194
function MortarL2(basis::LobattoLegendreBasis)
162195
RealT = real(basis)
163196
nnodes_ = nnodes(basis)
@@ -208,9 +241,6 @@ end
208241

209242
@inline polydeg(mortar::LobattoLegendreMortarL2) = nnodes(mortar) - 1
210243

211-
# @eval due to @muladd
212-
@eval Adapt.@adapt_structure(LobattoLegendreMortarL2)
213-
214244
# TODO: We can create EC mortars along the lines of the following implementation.
215245
# abstract type AbstractMortarEC{RealT} <: AbstractMortar{RealT} end
216246

src/solvers/dgsem_p4est/containers.jl

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -144,17 +144,16 @@ end
144144

145145
# Manual adapt_structure since we have aliasing memory
146146
function Adapt.adapt_structure(to,
147-
elements::P4estElementContainer{NDIMS, RealT, uEltype}) where {
148-
NDIMS,
149-
RealT,
150-
uEltype
151-
}
147+
elements::P4estElementContainer{NDIMS}) where {NDIMS}
152148
# Adapt underlying storage
153-
_node_coordinates = Adapt.adapt_structure(to, elements._node_coordinates)
154-
_jacobian_matrix = Adapt.adapt_structure(to, elements._jacobian_matrix)
155-
_contravariant_vectors = Adapt.adapt_structure(to, elements._contravariant_vectors)
156-
_inverse_jacobian = Adapt.adapt_structure(to, elements._inverse_jacobian)
157-
_surface_flux_values = Adapt.adapt_structure(to, elements._surface_flux_values)
149+
_node_coordinates = adapt(to, elements._node_coordinates)
150+
_jacobian_matrix = adapt(to, elements._jacobian_matrix)
151+
_contravariant_vectors = adapt(to, elements._contravariant_vectors)
152+
_inverse_jacobian = adapt(to, elements._inverse_jacobian)
153+
_surface_flux_values = adapt(to, elements._surface_flux_values)
154+
155+
RealT = eltype(_inverse_jacobian)
156+
uEltype = eltype(_surface_flux_values)
158157

159158
# Wrap arrays again
160159
node_coordinates = unsafe_wrap_or_alloc(to, _node_coordinates,
@@ -179,8 +178,7 @@ function Adapt.adapt_structure(to,
179178
typeof(jacobian_matrix), # ArrayNDIMSP3
180179
typeof(_node_coordinates), # VectorRealT
181180
typeof(_surface_flux_values), # VectoruEltype
182-
to,
183-
true)
181+
to)
184182
return P4estElementContainer{new_type_params...}(node_coordinates,
185183
jacobian_matrix,
186184
contravariant_vectors,
@@ -294,9 +292,9 @@ end
294292
# Manual adapt_structure since we have aliasing memory
295293
function Adapt.adapt_structure(to, interfaces::P4estInterfaceContainer)
296294
# Adapt underlying storage
297-
_u = Adapt.adapt_structure(to, interfaces._u)
298-
_neighbor_ids = Adapt.adapt_structure(to, interfaces._neighbor_ids)
299-
_node_indices = Adapt.adapt_structure(to, interfaces._node_indices)
295+
_u = adapt(to, interfaces._u)
296+
_neighbor_ids = adapt(to, interfaces._neighbor_ids)
297+
_node_indices = adapt(to, interfaces._node_indices)
300298
# Wrap arrays again
301299
u = unsafe_wrap_or_alloc(to, _u, size(interfaces.u))
302300
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids,
@@ -306,12 +304,11 @@ function Adapt.adapt_structure(to, interfaces::P4estInterfaceContainer)
306304

307305
NDIMS = ndims(interfaces)
308306
new_type_params = (NDIMS,
309-
eltype(interfaces),
307+
eltype(_u),
310308
NDIMS + 2,
311309
typeof(u), typeof(neighbor_ids), typeof(node_indices),
312310
typeof(_u), typeof(_neighbor_ids), typeof(_node_indices),
313-
to,
314-
true)
311+
to)
315312
return P4estInterfaceContainer{new_type_params...}(u, neighbor_ids, node_indices,
316313
_u, _neighbor_ids, _node_indices)
317314
end
@@ -436,16 +433,16 @@ end
436433

437434
# Manual adapt_structure since we have aliasing memory
438435
function Adapt.adapt_structure(to, boundaries::P4estBoundaryContainer)
439-
_u = Adapt.adapt_structure(to, boundaries._u)
436+
_u = adapt(to, boundaries._u)
440437
u = unsafe_wrap_or_alloc(to, _u, size(boundaries.u))
441-
neighbor_ids = Adapt.adapt_structure(to, boundaries.neighbor_ids)
442-
node_indices = Adapt.adapt_structure(to, boundaries.node_indices)
438+
neighbor_ids = adapt(to, boundaries.neighbor_ids)
439+
node_indices = adapt(to, boundaries.node_indices)
443440
name = boundaries.name
444441

445442
NDIMS = ndims(boundaries)
446-
return P4estBoundaryContainer{NDIMS, eltype(boundaries), NDIMS + 1, typeof(u),
443+
return P4estBoundaryContainer{NDIMS, eltype(_u), NDIMS + 1, typeof(u),
447444
typeof(neighbor_ids), typeof(node_indices),
448-
typeof(_u), to, true}(u, neighbor_ids, node_indices,
445+
typeof(_u), to}(u, neighbor_ids, node_indices,
449446
name, _u)
450447
end
451448

@@ -574,9 +571,9 @@ end
574571
# Manual adapt_structure since we have aliasing memory
575572
function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
576573
# Adapt underlying storage
577-
_u = Adapt.adapt_structure(to, mortars._u)
578-
_neighbor_ids = Adapt.adapt_structure(to, mortars._neighbor_ids)
579-
_node_indices = Adapt.adapt_structure(to, mortars._node_indices)
574+
_u = adapt(to, mortars._u)
575+
_neighbor_ids = adapt(to, mortars._neighbor_ids)
576+
_node_indices = adapt(to, mortars._node_indices)
580577

581578
# Wrap arrays again
582579
u = unsafe_wrap_or_alloc(to, _u, size(mortars.u))
@@ -585,13 +582,12 @@ function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
585582

586583
NDIMS = ndims(mortars)
587584
new_type_params = (NDIMS,
588-
eltype(mortars),
585+
eltype(_u),
589586
NDIMS + 1,
590587
NDIMS + 3,
591588
typeof(u), typeof(neighbor_ids), typeof(node_indices),
592589
typeof(_u), typeof(_neighbor_ids), typeof(_node_indices),
593-
to,
594-
true)
590+
to)
595591
return P4estMortarContainer{new_type_params...}(u, neighbor_ids, node_indices,
596592
_u, _neighbor_ids, _node_indices)
597593
end

src/solvers/dgsem_p4est/containers_parallel.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ end
9393
# Manual adapt_structure since we have aliasing memory
9494
function Adapt.adapt_structure(to, mpi_interfaces::P4estMPIInterfaceContainer)
9595
# Adapt Vectors and underlying storage
96-
_u = Adapt.adapt_structure(to, mpi_interfaces._u)
97-
local_neighbor_ids = Adapt.adapt_structure(to, mpi_interfaces.local_neighbor_ids)
98-
node_indices = Adapt.adapt_structure(to, mpi_interfaces.node_indices)
99-
local_sides = Adapt.adapt_structure(to, mpi_interfaces.local_sides)
96+
_u = adapt(to, mpi_interfaces._u)
97+
local_neighbor_ids = adapt(to, mpi_interfaces.local_neighbor_ids)
98+
node_indices = adapt(to, mpi_interfaces.node_indices)
99+
local_sides = adapt(to, mpi_interfaces.local_sides)
100100

101101
# Wrap array again
102102
u = unsafe_wrap_or_alloc(to, _u, size(mpi_interfaces.u))
@@ -224,7 +224,7 @@ function Adapt.adapt_structure(to, mpi_mortars::P4estMPIMortarContainer)
224224
# must be redesigned. This skeleton implementation here just exists just
225225
# for compatibility with the rest of the KA.jl solver code
226226

227-
_u = Adapt.adapt_structure(to, mpi_mortars._u)
227+
_u = adapt(to, mpi_mortars._u)
228228
_node_indices = mpi_mortars._node_indices
229229
_normal_directions = mpi_mortars._normal_directions
230230

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
34
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
45
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
56
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"

test/test_p4est_2d.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ isdir(outdir) && rm(outdir, recursive = true)
2727
du_ode = similar(u_ode)
2828
@test (@allocated Trixi.rhs!(du_ode, u_ode, semi, t)) < 1000
2929
end
30+
semi32 = Trixi.trixi_adapt(Array, Float32, semi)
31+
@test real(semi32.solver) == Float32
32+
@test real(semi32.solver.basis) == Float32
33+
@test real(semi32.solver.mortar) == Float32
34+
@test real(semi32.mesh) == Float32
3035
end
3136

3237
@trixi_testset "elixir_advection_nonconforming_flag.jl" begin

test/test_unstructured_2d.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module TestExamplesUnstructuredMesh2D
22

33
using Test
44
using Trixi
5+
using Adapt
56

67
include("test_trixi.jl")
78

@@ -32,6 +33,11 @@ isdir(outdir) && rm(outdir, recursive = true)
3233
du_ode = similar(u_ode)
3334
@test (@allocated Trixi.rhs!(du_ode, u_ode, semi, t)) < 1000
3435
end
36+
semi32 = Trixi.trixi_adapt(Array, Float32, semi)
37+
@test real(semi32.solver) == Float32
38+
@test real(semi32.solver.basis) == Float32
39+
@test real(semi32.solver.mortar) == Float32
40+
@test_broken real(semi32.mesh) == Float32
3541
end
3642

3743
@trixi_testset "elixir_euler_free_stream.jl" begin

0 commit comments

Comments
 (0)