Skip to content

Commit f80ace8

Browse files
committed
use adapt_structure macro where-ever possible
1 parent a7bc9db commit f80ace8

File tree

5 files changed

+17
-71
lines changed

5 files changed

+17
-71
lines changed

src/solvers/dgsem/basis_lobatto_legendre.jl

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -122,30 +122,7 @@ In particular, not the nodes themselves are returned.
122122

123123
@inline get_nodes(basis::LobattoLegendreBasis) = basis.nodes
124124

125-
function Adapt.adapt_structure(to, basis::LobattoLegendreBasis)
126-
# Do not adapt SVector fields, i.e. nodes, weights and inverse_weights
127-
(; nodes, weights, inverse_weights) = basis
128-
inverse_vandermonde_legendre = Adapt.adapt_structure(to,
129-
basis.inverse_vandermonde_legendre)
130-
boundary_interpolation = basis.boundary_interpolation
131-
derivative_matrix = Adapt.adapt_structure(to, basis.derivative_matrix)
132-
derivative_split = Adapt.adapt_structure(to, basis.derivative_split)
133-
derivative_split_transpose = Adapt.adapt_structure(to,
134-
basis.derivative_split_transpose)
135-
derivative_dhat = Adapt.adapt_structure(to, basis.derivative_dhat)
136-
return LobattoLegendreBasis{real(basis), nnodes(basis), typeof(basis.nodes),
137-
typeof(inverse_vandermonde_legendre),
138-
typeof(boundary_interpolation),
139-
typeof(derivative_matrix)}(nodes,
140-
weights,
141-
inverse_weights,
142-
inverse_vandermonde_legendre,
143-
boundary_interpolation,
144-
derivative_matrix,
145-
derivative_split,
146-
derivative_split_transpose,
147-
derivative_dhat)
148-
end
125+
Adapt.@adapt_structure(LobattoLegendreBasis)
149126

150127
"""
151128
integrate(f, u, basis::LobattoLegendreBasis)
@@ -230,15 +207,7 @@ end
230207

231208
@inline polydeg(mortar::LobattoLegendreMortarL2) = nnodes(mortar) - 1
232209

233-
function Adapt.adapt_structure(to, mortar::LobattoLegendreMortarL2)
234-
forward_upper = Adapt.adapt_structure(to, mortar.forward_upper)
235-
forward_lower = Adapt.adapt_structure(to, mortar.forward_lower)
236-
reverse_upper = Adapt.adapt_structure(to, mortar.reverse_upper)
237-
reverse_lower = Adapt.adapt_structure(to, mortar.reverse_lower)
238-
return LobattoLegendreMortarL2{real(mortar), nnodes(mortar), typeof(forward_upper),
239-
typeof(reverse_upper)}(forward_upper, forward_lower,
240-
reverse_upper, reverse_lower)
241-
end
210+
Adapt.@adapt_structure(LobattoLegendreMortarL2)
242211

243212
# TODO: We can create EC mortars along the lines of the following implementation.
244213
# abstract type AbstractMortarEC{RealT} <: AbstractMortar{RealT} end

src/solvers/dgsem_p4est/containers.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ end
147147
function KernelAbstractions.get_backend(elements::P4estElementContainer)
148148
return KernelAbstractions.get_backend(elements.node_coordinates)
149149
end
150+
# Adapt.@adapt_structure(P4estElementContainer)
150151
function Adapt.adapt_structure(to,
151152
elements::P4estElementContainer{NDIMS, RealT, uEltype}) where {
152153
NDIMS,
@@ -159,7 +160,7 @@ function Adapt.adapt_structure(to,
159160
_contravariant_vectors = Adapt.adapt_structure(to, elements._contravariant_vectors)
160161
_inverse_jacobian = Adapt.adapt_structure(to, elements._inverse_jacobian)
161162
_surface_flux_values = Adapt.adapt_structure(to, elements._surface_flux_values)
162-
163+
163164
# Wrap arrays again
164165
node_coordinates = unsafe_wrap_or_alloc(to, _node_coordinates,
165166
size(elements.node_coordinates))
@@ -300,15 +301,18 @@ end
300301
function KernelAbstractions.get_backend(interfaces::P4estInterfaceContainer)
301302
return KernelAbstractions.get_backend(interfaces.u)
302303
end
304+
# Adapt.@adapt_structure(P4estInterfaceContainer)
303305
function Adapt.adapt_structure(to, interfaces::P4estInterfaceContainer)
304306
# Adapt underlying storage
305307
_u = Adapt.adapt_structure(to, interfaces._u)
306308
_neighbor_ids = Adapt.adapt_structure(to, interfaces._neighbor_ids)
307309
_node_indices = Adapt.adapt_structure(to, interfaces._node_indices)
308310
# Wrap arrays again
309311
u = unsafe_wrap_or_alloc(to, _u, size(interfaces.u))
310-
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids, size(interfaces.neighbor_ids))
311-
node_indices = unsafe_wrap_or_alloc(to, _node_indices, size(interfaces.node_indices))
312+
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids,
313+
size(interfaces.neighbor_ids))
314+
node_indices = unsafe_wrap_or_alloc(to, _node_indices,
315+
size(interfaces.node_indices))
312316

313317
NDIMS = ndims(interfaces)
314318
new_type_params = (NDIMS,
@@ -452,7 +456,7 @@ function Adapt.adapt_structure(to, boundaries::P4estBoundaryContainer)
452456
neighbor_ids = Adapt.adapt_structure(to, boundaries.neighbor_ids)
453457
node_indices = Adapt.adapt_structure(to, boundaries.node_indices)
454458
name = boundaries.name
455-
459+
456460
NDIMS = ndims(boundaries)
457461
return P4estBoundaryContainer{NDIMS, eltype(boundaries), NDIMS + 1, typeof(u),
458462
typeof(neighbor_ids), typeof(node_indices),
@@ -587,6 +591,7 @@ end
587591
function KernelAbstractions.get_backend(mortars::P4estMortarContainer)
588592
return KernelAbstractions.get_backend(mortars.u)
589593
end
594+
# Adapt.@adapt_structure P4estMortarContainer
590595
function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
591596
# Adapt underlying storage
592597
_u = Adapt.adapt_structure(to, mortars._u)
@@ -598,7 +603,6 @@ function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
598603
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids, size(mortars.neighbor_ids))
599604
node_indices = unsafe_wrap_or_alloc(to, _node_indices, size(mortars.node_indices))
600605

601-
602606
NDIMS = ndims(mortars)
603607
new_type_params = (NDIMS,
604608
eltype(mortars),

src/solvers/dgsem_p4est/containers_parallel.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ end
9494
function KernelAbstractions.get_backend(mpi_interfaces::P4estMPIInterfaceContainer)
9595
return KernelAbstractions.get_backend(mpi_interfaces.u)
9696
end
97+
# Adapt.@adapt_structure(P4estMPIInterfaceContainer)
9798
function Adapt.adapt_structure(to, mpi_interfaces::P4estMPIInterfaceContainer)
9899
# Adapt Vectors and underlying storage
99100
_u = Adapt.adapt_structure(to, mpi_interfaces._u)
@@ -201,7 +202,7 @@ function init_mpi_mortars(mesh::Union{ParallelP4estMesh, ParallelT8codeMesh}, eq
201202

202203
mpi_mortars = P4estMPIMortarContainer{NDIMS, uEltype, RealT, NDIMS + 1, NDIMS + 2,
203204
NDIMS + 3, typeof(u),
204-
typeof(_u),
205+
typeof(_u),
205206
Array, false}(u, local_neighbor_ids,
206207
local_neighbor_positions,
207208
node_indices, normal_directions,

src/solvers/dgsem_p4est/dg_parallel.jl

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
@muladd begin
66
#! format: noindent
77

8-
mutable struct P4estMPICache{BufferType <: DenseVector, VecInt <: DenseVector{<:Integer}}
8+
mutable struct P4estMPICache{BufferType <: DenseVector,
9+
VecInt <: DenseVector{<:Integer}}
910
mpi_neighbor_ranks::Vector{Int}
1011
mpi_neighbor_interfaces::VecOfArrays{VecInt}
1112
mpi_neighbor_mortars::VecOfArrays{VecInt}
@@ -46,27 +47,7 @@ end
4647

4748
@inline Base.eltype(::P4estMPICache{BufferType}) where {BufferType} = eltype(BufferType)
4849

49-
function Adapt.adapt_structure(to, mpi_cache::P4estMPICache)
50-
mpi_neighbor_ranks = mpi_cache.mpi_neighbor_ranks
51-
mpi_neighbor_interfaces = Adapt.adapt_structure(to, mpi_cache.mpi_neighbor_interfaces)
52-
mpi_neighbor_mortars = Adapt.adapt_structure(to, mpi_cache.mpi_neighbor_mortars)
53-
mpi_send_buffers = Adapt.adapt_structure(to, mpi_cache.mpi_send_buffers)
54-
mpi_recv_buffers = Adapt.adapt_structure(to, mpi_cache.mpi_recv_buffers)
55-
mpi_send_requests = mpi_cache.mpi_send_requests
56-
mpi_recv_requests = mpi_cache.mpi_recv_requests
57-
n_elements_by_rank = mpi_cache.n_elements_by_rank
58-
n_elements_global = mpi_cache.n_elements_global
59-
first_element_global_id = mpi_cache.first_element_global_id
60-
61-
@assert eltype(mpi_send_buffers) == eltype(mpi_recv_buffers)
62-
BufferType = eltype(mpi_send_buffers)
63-
VecInt = eltype(mpi_neighbor_interfaces)
64-
return P4estMPICache{BufferType, VecInt}(mpi_neighbor_ranks, mpi_neighbor_interfaces,
65-
mpi_neighbor_mortars, mpi_send_buffers,
66-
mpi_recv_buffers, mpi_send_requests,
67-
mpi_recv_requests, n_elements_by_rank,
68-
n_elements_global, first_element_global_id)
69-
end
50+
Adapt.@adapt_structure(P4estMPICache)
7051

7152
##
7253
# Note that the code in `start_mpi_send`/`finish_mpi_receive!` is sensitive to inference on (at least) Julia 1.10.

src/solvers/dgsem_unstructured/sort_boundary_conditions.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,5 @@ function initialize!(boundary_types_container::UnstructuredSortedBoundaryTypes{N
122122
return boundary_types_container
123123
end
124124

125-
function Adapt.adapt_structure(to, bcs::UnstructuredSortedBoundaryTypes)
126-
boundary_indices = Adapt.adapt_structure(to, bcs.boundary_indices)
127-
n_boundary_types = length(bcs.boundary_condition_types)
128-
return UnstructuredSortedBoundaryTypes{n_boundary_types,
129-
typeof(bcs.boundary_condition_types),
130-
eltype(boundary_indices)}(bcs.boundary_condition_types,
131-
boundary_indices,
132-
bcs.boundary_dictionary,
133-
bcs.boundary_symbol_indices)
134-
end
125+
Adapt.@adapt_structure(UnstructuredSortedBoundaryTypes)
135126
end # @muladd

0 commit comments

Comments
 (0)