Skip to content

Commit a3e4530

Browse files
committed
start working on test and use custom adaptor
1 parent 36327ce commit a3e4530

File tree

10 files changed

+136
-66
lines changed

10 files changed

+136
-66
lines changed

src/Trixi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import SciMLBase: get_du, get_tmp_cache, u_modified!,
4444

4545
using DelimitedFiles: readdlm
4646
using Downloads: Downloads
47-
using Adapt: Adapt
47+
using Adapt: Adapt, adapt
4848
using CodeTracking: CodeTracking
4949
using ConstructionBase: ConstructionBase
5050
using DiffEqBase: DiffEqBase, get_tstops, get_tstops_array

src/auxiliary/containers.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,4 +346,34 @@ function unsafe_wrap_or_alloc(to, vec, size)
346346
return unsafe_wrap(to, pointer(vec), size)
347347
end
348348
end
349+
350+
struct TrixiAdaptor{Storage, Real} end
351+
352+
function trixi_adapt(storage, real, x)
353+
adapt(TrixiAdaptor{storage, real}(), x)
354+
end
355+
356+
# Custom rules
357+
# 1. handling of StaticArrays
358+
function Adapt.adapt_storage(::TrixiAdaptor{<:Any, Real}, x::StaticArrays.StaticArray{S, T, N}) where {Real,S,T,N}
359+
StaticArrays.similar_type(x, Real)(x)
360+
end
361+
362+
# 2. Handling of Arrays
363+
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, x::AbstractArray{T}) where{Storage, Real, T<:AbstractFloat}
364+
adapt(Storage{Real}, x)
365+
end
366+
367+
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, x::AbstractArray{T}) where {Storage, Real,T<:StaticArrays.StaticArray}
368+
adapt(Storage{StaticArrays.similar_type(T, Real)},x)
369+
end
370+
371+
function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, x::AbstractArray) where{Storage, Real}
372+
adapt(Storage, x)
373+
end
374+
375+
# 3. TODO: Should we have a fallback? But that would imply implementing things for NamedTuple again
376+
377+
unsafe_wrap_or_alloc(::TrixiAdaptor{Storage}, vec, size) where {Storage} = unsafe_wrap_or_alloc(Storage, vec, size)
378+
349379
end # @muladd

src/semidiscretization/semidiscretization_hyperbolic.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,6 @@ mutable struct SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
2727
solver::Solver
2828
cache::Cache
2929
performance_counter::PerformanceCounter
30-
31-
function SemidiscretizationHyperbolic{Mesh, Equations, InitialCondition,
32-
BoundaryConditions, SourceTerms, Solver,
33-
Cache}(mesh::Mesh,
34-
equations::Equations,
35-
initial_condition::InitialCondition,
36-
boundary_conditions::BoundaryConditions,
37-
source_terms::SourceTerms,
38-
solver::Solver,
39-
cache::Cache,
40-
performance_counter::PerformanceCounter) where {
41-
Mesh,
42-
Equations,
43-
InitialCondition,
44-
BoundaryConditions,
45-
SourceTerms,
46-
Solver,
47-
Cache
48-
}
49-
new(mesh, equations, initial_condition, boundary_conditions, source_terms,
50-
solver, cache, performance_counter)
51-
end
5230
end
5331

5432
"""
@@ -90,6 +68,31 @@ function SemidiscretizationHyperbolic(mesh, equations, initial_condition, solver
9068
performance_counter)
9169
end
9270

71+
@eval Adapt.@adapt_structure(SemidiscretizationHyperbolic)
72+
# function Adapt.adapt_structure(to, semi::SemidiscretizationHyperbolic)
73+
# if !(typeof(semi.mesh) <: P4estMesh)
74+
# error("Adapt.adapt is only supported for semidiscretizations based on P4estMesh")
75+
# end
76+
77+
# mesh = semi.mesh
78+
# equations = adapt(to, semi.equations)
79+
# initial_condition = adapt(to, semi.initial_condition)
80+
# boundary_conditions = adapt(to, semi.boundary_conditions)
81+
# source_terms = adapt(to, semi.source_terms)
82+
# solver = adapt(to, semi.solver)
83+
# cache = adapt(to, semi.cache)
84+
# performance_counter = semi.performance_counter
85+
86+
# SemidiscretizationHyperbolic{typeof(mesh), typeof(equations),
87+
# typeof(initial_condition),
88+
# typeof(boundary_conditions), typeof(source_terms),
89+
# typeof(solver), typeof(cache)}(mesh, equations,
90+
# initial_condition,
91+
# boundary_conditions,
92+
# source_terms, solver,
93+
# cache, performance_counter)
94+
# end
95+
9396
# Create a new semidiscretization but change some parameters compared to the input.
9497
# `Base.similar` follows a related concept but would require us to `copy` the `mesh`,
9598
# which would impact the performance. Instead, `SciMLBase.remake` has exactly the
@@ -109,9 +112,6 @@ function remake(semi::SemidiscretizationHyperbolic; uEltype = real(semi.solver),
109112
source_terms, boundary_conditions, uEltype)
110113
end
111114

112-
# @eval due to @muladd
113-
@eval Adapt.@adapt_structure(SemidiscretizationHyperbolic)
114-
115115
# general fallback
116116
function digest_boundary_conditions(boundary_conditions, mesh, solver, cache)
117117
boundary_conditions

src/solvers/dg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,8 @@ struct DG{Basis, Mortar, SurfaceIntegral, VolumeIntegral}
400400
volume_integral::VolumeIntegral
401401
end
402402

403+
@eval Adapt.@adapt_structure(DG)
404+
403405
function Base.show(io::IO, dg::DG)
404406
@nospecialize dg # reduce precompilation time
405407

src/solvers/dgsem/basis_lobatto_legendre.jl

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@ struct LobattoLegendreBasis{RealT <: Real, NNODES,
3434
# negative adjoint wrt the SBP dot product
3535
end
3636

37+
function Adapt.adapt_structure(to, basis::LobattoLegendreBasis)
38+
inverse_vandermonde_legendre = adapt(to, basis.inverse_vandermonde_legendre)
39+
RealT = eltype(inverse_vandermonde_legendre)
40+
41+
nodes = SVector{<:Any, RealT}(basis.nodes)
42+
weights = SVector{<:Any, RealT}(basis.weights)
43+
inverse_weights = SVector{<:Any, RealT}(basis.inverse_weights)
44+
boundary_interpolation = adapt(to, basis.boundary_interpolation)
45+
derivative_matrix = adapt(to, basis.derivative_matrix)
46+
derivative_split = adapt(to, basis.derivative_split)
47+
derivative_split_transpose = adapt(to,basis.derivative_split_transpose)
48+
derivative_dhat = adapt(to, basis.derivative_dhat)
49+
return LobattoLegendreBasis{RealT, nnodes(basis), typeof(nodes),
50+
typeof(inverse_vandermonde_legendre),
51+
typeof(boundary_interpolation),
52+
typeof(derivative_matrix)}(nodes,
53+
weights,
54+
inverse_weights,
55+
inverse_vandermonde_legendre,
56+
boundary_interpolation,
57+
derivative_matrix,
58+
derivative_split,
59+
derivative_split_transpose,
60+
derivative_dhat)
61+
end
62+
3763
function LobattoLegendreBasis(RealT, polydeg::Integer)
3864
nnodes_ = polydeg + 1
3965

@@ -122,9 +148,6 @@ In particular, not the nodes themselves are returned.
122148

123149
@inline get_nodes(basis::LobattoLegendreBasis) = basis.nodes
124150

125-
# @eval due to @muladd
126-
@eval Adapt.@adapt_structure(LobattoLegendreBasis)
127-
128151
"""
129152
integrate(f, u, basis::LobattoLegendreBasis)
130153
@@ -158,6 +181,16 @@ struct LobattoLegendreMortarL2{RealT <: Real, NNODES,
158181
reverse_lower::ReverseMatrix
159182
end
160183

184+
function Adapt.adapt_structure(to, mortar::LobattoLegendreMortarL2)
185+
forward_upper = adapt(to, mortar.forward_upper)
186+
forward_lower = adapt(to, mortar.forward_lower)
187+
reverse_upper = adapt(to, mortar.reverse_upper)
188+
reverse_lower = adapt(to, mortar.reverse_lower)
189+
return LobattoLegendreMortarL2{eltype(forward_upper), nnodes(mortar), typeof(forward_upper),
190+
typeof(reverse_upper)}(forward_upper, forward_lower,
191+
reverse_upper, reverse_lower)
192+
end
193+
161194
function MortarL2(basis::LobattoLegendreBasis)
162195
RealT = real(basis)
163196
nnodes_ = nnodes(basis)
@@ -208,9 +241,6 @@ end
208241

209242
@inline polydeg(mortar::LobattoLegendreMortarL2) = nnodes(mortar) - 1
210243

211-
# @eval due to @muladd
212-
@eval Adapt.@adapt_structure(LobattoLegendreMortarL2)
213-
214244
# TODO: We can create EC mortars along the lines of the following implementation.
215245
# abstract type AbstractMortarEC{RealT} <: AbstractMortar{RealT} end
216246

src/solvers/dgsem_p4est/containers.jl

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,16 @@ end
145145

146146
# Manual adapt_structure since we have aliasing memory
147147
function Adapt.adapt_structure(to,
148-
elements::P4estElementContainer{NDIMS, RealT, uEltype}) where {
149-
NDIMS,
150-
RealT,
151-
uEltype
152-
}
148+
elements::P4estElementContainer{NDIMS}) where {NDIMS}
153149
# Adapt underlying storage
154-
_node_coordinates = Adapt.adapt_structure(to, elements._node_coordinates)
155-
_jacobian_matrix = Adapt.adapt_structure(to, elements._jacobian_matrix)
156-
_contravariant_vectors = Adapt.adapt_structure(to, elements._contravariant_vectors)
157-
_inverse_jacobian = Adapt.adapt_structure(to, elements._inverse_jacobian)
158-
_surface_flux_values = Adapt.adapt_structure(to, elements._surface_flux_values)
150+
_node_coordinates = adapt(to, elements._node_coordinates)
151+
_jacobian_matrix = adapt(to, elements._jacobian_matrix)
152+
_contravariant_vectors = adapt(to, elements._contravariant_vectors)
153+
_inverse_jacobian = adapt(to, elements._inverse_jacobian)
154+
_surface_flux_values = adapt(to, elements._surface_flux_values)
155+
156+
RealT = eltype(_inverse_jacobian)
157+
uEltype = eltype(_surface_flux_values)
159158

160159
# Wrap arrays again
161160
node_coordinates = unsafe_wrap_or_alloc(to, _node_coordinates,
@@ -180,8 +179,7 @@ function Adapt.adapt_structure(to,
180179
typeof(jacobian_matrix), # ArrayNDIMSP3
181180
typeof(_node_coordinates), # VectorRealT
182181
typeof(_surface_flux_values), # VectoruEltype
183-
to,
184-
true)
182+
to)
185183
return P4estElementContainer{new_type_params...}(node_coordinates,
186184
jacobian_matrix,
187185
contravariant_vectors,
@@ -296,9 +294,9 @@ end
296294
# Manual adapt_structure since we have aliasing memory
297295
function Adapt.adapt_structure(to, interfaces::P4estInterfaceContainer)
298296
# Adapt underlying storage
299-
_u = Adapt.adapt_structure(to, interfaces._u)
300-
_neighbor_ids = Adapt.adapt_structure(to, interfaces._neighbor_ids)
301-
_node_indices = Adapt.adapt_structure(to, interfaces._node_indices)
297+
_u = adapt(to, interfaces._u)
298+
_neighbor_ids = adapt(to, interfaces._neighbor_ids)
299+
_node_indices = adapt(to, interfaces._node_indices)
302300
# Wrap arrays again
303301
u = unsafe_wrap_or_alloc(to, _u, size(interfaces.u))
304302
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids,
@@ -308,12 +306,11 @@ function Adapt.adapt_structure(to, interfaces::P4estInterfaceContainer)
308306

309307
NDIMS = ndims(interfaces)
310308
new_type_params = (NDIMS,
311-
eltype(interfaces),
309+
eltype(_u),
312310
NDIMS + 2,
313311
typeof(u), typeof(neighbor_ids), typeof(node_indices),
314312
typeof(_u), typeof(_neighbor_ids), typeof(_node_indices),
315-
to,
316-
true)
313+
to)
317314
return P4estInterfaceContainer{new_type_params...}(u, neighbor_ids, node_indices,
318315
_u, _neighbor_ids, _node_indices)
319316
end
@@ -439,16 +436,16 @@ end
439436

440437
# Manual adapt_structure since we have aliasing memory
441438
function Adapt.adapt_structure(to, boundaries::P4estBoundaryContainer)
442-
_u = Adapt.adapt_structure(to, boundaries._u)
439+
_u = adapt(to, boundaries._u)
443440
u = unsafe_wrap_or_alloc(to, _u, size(boundaries.u))
444-
neighbor_ids = Adapt.adapt_structure(to, boundaries.neighbor_ids)
445-
node_indices = Adapt.adapt_structure(to, boundaries.node_indices)
441+
neighbor_ids = adapt(to, boundaries.neighbor_ids)
442+
node_indices = adapt(to, boundaries.node_indices)
446443
name = boundaries.name
447444

448445
NDIMS = ndims(boundaries)
449-
return P4estBoundaryContainer{NDIMS, eltype(boundaries), NDIMS + 1, typeof(u),
446+
return P4estBoundaryContainer{NDIMS, eltype(_u), NDIMS + 1, typeof(u),
450447
typeof(neighbor_ids), typeof(node_indices),
451-
typeof(_u), to, true}(u, neighbor_ids, node_indices,
448+
typeof(_u), to}(u, neighbor_ids, node_indices,
452449
name, _u)
453450
end
454451

@@ -578,9 +575,9 @@ end
578575
# Manual adapt_structure since we have aliasing memory
579576
function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
580577
# Adapt underlying storage
581-
_u = Adapt.adapt_structure(to, mortars._u)
582-
_neighbor_ids = Adapt.adapt_structure(to, mortars._neighbor_ids)
583-
_node_indices = Adapt.adapt_structure(to, mortars._node_indices)
578+
_u = adapt(to, mortars._u)
579+
_neighbor_ids = adapt(to, mortars._neighbor_ids)
580+
_node_indices = adapt(to, mortars._node_indices)
584581

585582
# Wrap arrays again
586583
u = unsafe_wrap_or_alloc(to, _u, size(mortars.u))
@@ -589,13 +586,12 @@ function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
589586

590587
NDIMS = ndims(mortars)
591588
new_type_params = (NDIMS,
592-
eltype(mortars),
589+
eltype(_u),
593590
NDIMS + 1,
594591
NDIMS + 3,
595592
typeof(u), typeof(neighbor_ids), typeof(node_indices),
596593
typeof(_u), typeof(_neighbor_ids), typeof(_node_indices),
597-
to,
598-
true)
594+
to)
599595
return P4estMortarContainer{new_type_params...}(u, neighbor_ids, node_indices,
600596
_u, _neighbor_ids, _node_indices)
601597
end

src/solvers/dgsem_p4est/containers_parallel.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ end
9393
# Manual adapt_structure since we have aliasing memory
9494
function Adapt.adapt_structure(to, mpi_interfaces::P4estMPIInterfaceContainer)
9595
# Adapt Vectors and underlying storage
96-
_u = Adapt.adapt_structure(to, mpi_interfaces._u)
97-
local_neighbor_ids = Adapt.adapt_structure(to, mpi_interfaces.local_neighbor_ids)
98-
node_indices = Adapt.adapt_structure(to, mpi_interfaces.node_indices)
99-
local_sides = Adapt.adapt_structure(to, mpi_interfaces.local_sides)
96+
_u = adapt(to, mpi_interfaces._u)
97+
local_neighbor_ids = adapt(to, mpi_interfaces.local_neighbor_ids)
98+
node_indices = adapt(to, mpi_interfaces.node_indices)
99+
local_sides = adapt(to, mpi_interfaces.local_sides)
100100

101101
# Wrap array again
102102
u = unsafe_wrap_or_alloc(to, _u, size(mpi_interfaces.u))
@@ -224,7 +224,7 @@ function Adapt.adapt_structure(to, mpi_mortars::P4estMPIMortarContainer)
224224
# must be redesigned. This skeleton implementation here just exists just
225225
# for compatibility with the rest of the KA.jl solver code
226226

227-
_u = Adapt.adapt_structure(to, mpi_mortars._u)
227+
_u = adapt(to, mpi_mortars._u)
228228
_node_indices = mpi_mortars._node_indices
229229
_normal_directions = mpi_mortars._normal_directions
230230

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
34
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
45
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
56
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"

test/test_p4est_2d.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ isdir(outdir) && rm(outdir, recursive = true)
2727
du_ode = similar(u_ode)
2828
@test (@allocated Trixi.rhs!(du_ode, u_ode, semi, t)) < 1000
2929
end
30+
semi32 = Trixi.trixi_adapt(Array, Float32, semi)
31+
@test real(semi32.solver) == Float32
32+
@test real(semi32.solver.basis) == Float32
33+
@test real(semi32.solver.mortar) == Float32
34+
@test real(semi32.mesh) == Float32
3035
end
3136

3237
@trixi_testset "elixir_advection_nonconforming_flag.jl" begin

test/test_unstructured_2d.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module TestExamplesUnstructuredMesh2D
22

33
using Test
44
using Trixi
5+
using Adapt
56

67
include("test_trixi.jl")
78

@@ -32,6 +33,11 @@ isdir(outdir) && rm(outdir, recursive = true)
3233
du_ode = similar(u_ode)
3334
@test (@allocated Trixi.rhs!(du_ode, u_ode, semi, t)) < 1000
3435
end
36+
semi32 = Trixi.trixi_adapt(Array, Float32, semi)
37+
@test real(semi32.solver) == Float32
38+
@test real(semi32.solver.basis) == Float32
39+
@test real(semi32.solver.mortar) == Float32
40+
@test_broken real(semi32.mesh) == Float32
3541
end
3642

3743
@trixi_testset "elixir_euler_free_stream.jl" begin

0 commit comments

Comments
 (0)