Skip to content

Commit 3fff1fa

Browse files
committed
use adapt_structure macro where-ever possible
1 parent eed7fb2 commit 3fff1fa

File tree

5 files changed

+17
-71
lines changed

5 files changed

+17
-71
lines changed

src/solvers/dgsem/basis_lobatto_legendre.jl

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -128,30 +128,7 @@ In particular, not the nodes themselves are returned.
128128

129129
@inline get_nodes(basis::LobattoLegendreBasis) = basis.nodes
130130

131-
function Adapt.adapt_structure(to, basis::LobattoLegendreBasis)
132-
# Do not adapt SVector fields, i.e. nodes, weights and inverse_weights
133-
(; nodes, weights, inverse_weights) = basis
134-
inverse_vandermonde_legendre = Adapt.adapt_structure(to,
135-
basis.inverse_vandermonde_legendre)
136-
boundary_interpolation = basis.boundary_interpolation
137-
derivative_matrix = Adapt.adapt_structure(to, basis.derivative_matrix)
138-
derivative_split = Adapt.adapt_structure(to, basis.derivative_split)
139-
derivative_split_transpose = Adapt.adapt_structure(to,
140-
basis.derivative_split_transpose)
141-
derivative_dhat = Adapt.adapt_structure(to, basis.derivative_dhat)
142-
return LobattoLegendreBasis{real(basis), nnodes(basis), typeof(basis.nodes),
143-
typeof(inverse_vandermonde_legendre),
144-
typeof(boundary_interpolation),
145-
typeof(derivative_matrix)}(nodes,
146-
weights,
147-
inverse_weights,
148-
inverse_vandermonde_legendre,
149-
boundary_interpolation,
150-
derivative_matrix,
151-
derivative_split,
152-
derivative_split_transpose,
153-
derivative_dhat)
154-
end
131+
Adapt.@adapt_structure(LobattoLegendreBasis)
155132

156133
"""
157134
integrate(f, u, basis::LobattoLegendreBasis)
@@ -241,15 +218,7 @@ end
241218

242219
@inline polydeg(mortar::LobattoLegendreMortarL2) = nnodes(mortar) - 1
243220

244-
function Adapt.adapt_structure(to, mortar::LobattoLegendreMortarL2)
245-
forward_upper = Adapt.adapt_structure(to, mortar.forward_upper)
246-
forward_lower = Adapt.adapt_structure(to, mortar.forward_lower)
247-
reverse_upper = Adapt.adapt_structure(to, mortar.reverse_upper)
248-
reverse_lower = Adapt.adapt_structure(to, mortar.reverse_lower)
249-
return LobattoLegendreMortarL2{real(mortar), nnodes(mortar), typeof(forward_upper),
250-
typeof(reverse_upper)}(forward_upper, forward_lower,
251-
reverse_upper, reverse_lower)
252-
end
221+
Adapt.@adapt_structure(LobattoLegendreMortarL2)
253222

254223
# TODO: We can create EC mortars along the lines of the following implementation.
255224
# abstract type AbstractMortarEC{RealT} <: AbstractMortar{RealT} end

src/solvers/dgsem_p4est/containers.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ end
146146
function KernelAbstractions.get_backend(elements::P4estElementContainer)
147147
return KernelAbstractions.get_backend(elements.node_coordinates)
148148
end
149+
# Adapt.@adapt_structure(P4estElementContainer)
149150
function Adapt.adapt_structure(to,
150151
elements::P4estElementContainer{NDIMS, RealT, uEltype}) where {
151152
NDIMS,
@@ -158,7 +159,7 @@ function Adapt.adapt_structure(to,
158159
_contravariant_vectors = Adapt.adapt_structure(to, elements._contravariant_vectors)
159160
_inverse_jacobian = Adapt.adapt_structure(to, elements._inverse_jacobian)
160161
_surface_flux_values = Adapt.adapt_structure(to, elements._surface_flux_values)
161-
162+
162163
# Wrap arrays again
163164
node_coordinates = unsafe_wrap_or_alloc(to, _node_coordinates,
164165
size(elements.node_coordinates))
@@ -298,15 +299,18 @@ end
298299
function KernelAbstractions.get_backend(interfaces::P4estInterfaceContainer)
299300
return KernelAbstractions.get_backend(interfaces.u)
300301
end
302+
# Adapt.@adapt_structure(P4estInterfaceContainer)
301303
function Adapt.adapt_structure(to, interfaces::P4estInterfaceContainer)
302304
# Adapt underlying storage
303305
_u = Adapt.adapt_structure(to, interfaces._u)
304306
_neighbor_ids = Adapt.adapt_structure(to, interfaces._neighbor_ids)
305307
_node_indices = Adapt.adapt_structure(to, interfaces._node_indices)
306308
# Wrap arrays again
307309
u = unsafe_wrap_or_alloc(to, _u, size(interfaces.u))
308-
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids, size(interfaces.neighbor_ids))
309-
node_indices = unsafe_wrap_or_alloc(to, _node_indices, size(interfaces.node_indices))
310+
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids,
311+
size(interfaces.neighbor_ids))
312+
node_indices = unsafe_wrap_or_alloc(to, _node_indices,
313+
size(interfaces.node_indices))
310314

311315
NDIMS = ndims(interfaces)
312316
new_type_params = (NDIMS,
@@ -449,7 +453,7 @@ function Adapt.adapt_structure(to, boundaries::P4estBoundaryContainer)
449453
neighbor_ids = Adapt.adapt_structure(to, boundaries.neighbor_ids)
450454
node_indices = Adapt.adapt_structure(to, boundaries.node_indices)
451455
name = boundaries.name
452-
456+
453457
NDIMS = ndims(boundaries)
454458
return P4estBoundaryContainer{NDIMS, eltype(boundaries), NDIMS + 1, typeof(u),
455459
typeof(neighbor_ids), typeof(node_indices),
@@ -583,6 +587,7 @@ end
583587
function KernelAbstractions.get_backend(mortars::P4estMortarContainer)
584588
return KernelAbstractions.get_backend(mortars.u)
585589
end
590+
# Adapt.@adapt_structure P4estMortarContainer
586591
function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
587592
# Adapt underlying storage
588593
_u = Adapt.adapt_structure(to, mortars._u)
@@ -594,7 +599,6 @@ function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
594599
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids, size(mortars.neighbor_ids))
595600
node_indices = unsafe_wrap_or_alloc(to, _node_indices, size(mortars.node_indices))
596601

597-
598602
NDIMS = ndims(mortars)
599603
new_type_params = (NDIMS,
600604
eltype(mortars),

src/solvers/dgsem_p4est/containers_parallel.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ end
9494
function KernelAbstractions.get_backend(mpi_interfaces::P4estMPIInterfaceContainer)
9595
return KernelAbstractions.get_backend(mpi_interfaces.u)
9696
end
97+
# Adapt.@adapt_structure(P4estMPIInterfaceContainer)
9798
function Adapt.adapt_structure(to, mpi_interfaces::P4estMPIInterfaceContainer)
9899
# Adapt Vectors and underlying storage
99100
_u = Adapt.adapt_structure(to, mpi_interfaces._u)
@@ -201,7 +202,7 @@ function init_mpi_mortars(mesh::Union{ParallelP4estMesh, ParallelT8codeMesh}, eq
201202

202203
mpi_mortars = P4estMPIMortarContainer{NDIMS, uEltype, RealT, NDIMS + 1, NDIMS + 2,
203204
NDIMS + 3, typeof(u),
204-
typeof(_u),
205+
typeof(_u),
205206
Array, false}(u, local_neighbor_ids,
206207
local_neighbor_positions,
207208
node_indices, normal_directions,

src/solvers/dgsem_p4est/dg_parallel.jl

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
@muladd begin
66
#! format: noindent
77

8-
mutable struct P4estMPICache{BufferType <: DenseVector, VecInt <: DenseVector{<:Integer}}
8+
mutable struct P4estMPICache{BufferType <: DenseVector,
9+
VecInt <: DenseVector{<:Integer}}
910
mpi_neighbor_ranks::Vector{Int}
1011
mpi_neighbor_interfaces::VecOfArrays{VecInt}
1112
mpi_neighbor_mortars::VecOfArrays{VecInt}
@@ -46,27 +47,7 @@ end
4647

4748
@inline Base.eltype(::P4estMPICache{BufferType}) where {BufferType} = eltype(BufferType)
4849

49-
function Adapt.adapt_structure(to, mpi_cache::P4estMPICache)
50-
mpi_neighbor_ranks = mpi_cache.mpi_neighbor_ranks
51-
mpi_neighbor_interfaces = Adapt.adapt_structure(to, mpi_cache.mpi_neighbor_interfaces)
52-
mpi_neighbor_mortars = Adapt.adapt_structure(to, mpi_cache.mpi_neighbor_mortars)
53-
mpi_send_buffers = Adapt.adapt_structure(to, mpi_cache.mpi_send_buffers)
54-
mpi_recv_buffers = Adapt.adapt_structure(to, mpi_cache.mpi_recv_buffers)
55-
mpi_send_requests = mpi_cache.mpi_send_requests
56-
mpi_recv_requests = mpi_cache.mpi_recv_requests
57-
n_elements_by_rank = mpi_cache.n_elements_by_rank
58-
n_elements_global = mpi_cache.n_elements_global
59-
first_element_global_id = mpi_cache.first_element_global_id
60-
61-
@assert eltype(mpi_send_buffers) == eltype(mpi_recv_buffers)
62-
BufferType = eltype(mpi_send_buffers)
63-
VecInt = eltype(mpi_neighbor_interfaces)
64-
return P4estMPICache{BufferType, VecInt}(mpi_neighbor_ranks, mpi_neighbor_interfaces,
65-
mpi_neighbor_mortars, mpi_send_buffers,
66-
mpi_recv_buffers, mpi_send_requests,
67-
mpi_recv_requests, n_elements_by_rank,
68-
n_elements_global, first_element_global_id)
69-
end
50+
Adapt.@adapt_structure(P4estMPICache)
7051

7152
function start_mpi_send!(mpi_cache::P4estMPICache, mesh, equations, dg, cache)
7253
data_size = nvariables(equations) * nnodes(dg)^(ndims(mesh) - 1)

src/solvers/dgsem_unstructured/sort_boundary_conditions.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,5 @@ function initialize!(boundary_types_container::UnstructuredSortedBoundaryTypes{N
114114
return boundary_types_container
115115
end
116116

117-
function Adapt.adapt_structure(to, bcs::UnstructuredSortedBoundaryTypes)
118-
boundary_indices = Adapt.adapt_structure(to, bcs.boundary_indices)
119-
n_boundary_types = length(bcs.boundary_condition_types)
120-
return UnstructuredSortedBoundaryTypes{n_boundary_types,
121-
typeof(bcs.boundary_condition_types),
122-
eltype(boundary_indices)}(bcs.boundary_condition_types,
123-
boundary_indices,
124-
bcs.boundary_dictionary,
125-
bcs.boundary_symbol_indices)
126-
end
117+
Adapt.@adapt_structure(UnstructuredSortedBoundaryTypes)
127118
end # @muladd

0 commit comments

Comments
 (0)