Skip to content

Commit 6212b13

Browse files
committed
Use contiguous memory layout with PrecomputedNeighborhoodSearch
1 parent d5f2960 commit 6212b13

File tree

5 files changed

+96
-45
lines changed

5 files changed

+96
-45
lines changed

src/cell_lists/cell_lists_util.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ end
1212

1313
# We need the prod() because FullGridCellList's size is a tuple of cells per dimension whereas
1414
# SpatialHashingCellList's size is an Integer for the number of cells in total.
15-
function construct_backend(::Type{<:AbstractCellList}, ::Type{Vector{Vector{T}}}, size,
15+
function construct_backend(::Any, ::Type{Vector{Vector{T}}}, size,
1616
max_points_per_cell) where {T}
1717
return [T[] for _ in 1:prod(size)]
1818
end
1919

20-
function construct_backend(::Type{<:AbstractCellList}, ::Type{DynamicVectorOfVectors{T}},
20+
function construct_backend(::Any, ::Type{DynamicVectorOfVectors{T}},
2121
size,
2222
max_points_per_cell) where {T}
2323
cells = DynamicVectorOfVectors{T}(max_outer_length = prod(size),
@@ -31,7 +31,7 @@ end
3131
# `DynamicVectorOfVectors{T}`, but a type `DynamicVectorOfVectors{T1, T2, T3, T4}`.
3232
# While `A{T} <: A{T1, T2}`, this doesn't hold for the types.
3333
# `Type{A{T}} <: Type{A{T1, T2}}` is NOT true.
34-
function construct_backend(cell_list::Type{<:AbstractCellList},
34+
function construct_backend(cell_list::Any,
3535
::Type{DynamicVectorOfVectors{T1, T2, T3, T4}}, size,
3636
max_points_per_cell) where {T1, T2, T3, T4}
3737
return construct_backend(cell_list, DynamicVectorOfVectors{T1}, size,

src/cell_lists/full_grid.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ See [`copy_neighborhood_search`](@ref) for more details.
2424
- `backend = DynamicVectorOfVectors{Int32}`: Type of the data structure to store the actual
2525
cell lists. Can be
2626
- `Vector{Vector{Int32}}`: Scattered memory, but very memory-efficient.
27-
- `DynamicVectorOfVectors{Int32}`: Contiguous memory, optimizing cache-hits.
27+
- `DynamicVectorOfVectors{Int32}`: Contiguous memory, optimizing cache-hits
28+
and GPU-compatible.
2829
- `max_points_per_cell = 100`: Maximum number of points per cell. This will be used to
2930
allocate the `DynamicVectorOfVectors`. It is not used with
3031
the `Vector{Vector{Int32}}` backend.

src/gpu.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
Adapt.@adapt_structure FullGridCellList
1111
Adapt.@adapt_structure SpatialHashingCellList
1212
Adapt.@adapt_structure DynamicVectorOfVectors
13+
Adapt.@adapt_structure GridNeighborhoodSearch
14+
Adapt.@adapt_structure PrecomputedNeighborhoodSearch
1315

1416
# `adapt(CuArray, ::SVector)::SVector`, but `adapt(Array, ::SVector)::Vector`.
1517
# We don't want to change the type of the `SVector` here.
@@ -22,13 +24,3 @@ end
2224
function Adapt.adapt_structure(to::typeof(Array), range::UnitRange)
2325
return range
2426
end
25-
26-
function Adapt.adapt_structure(to, nhs::GridNeighborhoodSearch)
27-
(; search_radius, periodic_box, n_cells, cell_size) = nhs
28-
29-
cell_list = Adapt.adapt_structure(to, nhs.cell_list)
30-
update_buffer = Adapt.adapt_structure(to, nhs.update_buffer)
31-
32-
return GridNeighborhoodSearch(cell_list, search_radius, periodic_box, n_cells,
33-
cell_size, update_buffer, nhs.update_strategy)
34-
end

src/nhs_precomputed.jl

Lines changed: 79 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,71 +22,93 @@ initialization and update.
2222
[`PeriodicBox`](@ref).
2323
- `update_strategy`: Strategy to parallelize `update!` of the internally used
2424
`GridNeighborhoodSearch`. See [`GridNeighborhoodSearch`](@ref)
25-
for available options.
25+
for available options. This is only used for the default value
26+
of `update_neighborhood_search` below.
27+
- `update_neighborhood_search = GridNeighborhoodSearch{NDIMS}(; search_radius, n_points, periodic_box, update_strategy)`:
28+
The neighborhood search used to compute the neighbor lists.
29+
By default, a [`GridNeighborhoodSearch`](@ref) is used.
30+
- `backend = DynamicVectorOfVectors{Int32}`: Type of the data structure to store
31+
the neighbor lists. Can be
32+
- `Vector{Vector{Int32}}`: Scattered memory, but very memory-efficient.
33+
- `DynamicVectorOfVectors{Int32}`: Contiguous memory, optimizing cache-hits
34+
and GPU-compatible.
35+
- `max_neighbors`: Maximum number of neighbors per particle. This will be used to
36+
allocate the `DynamicVectorOfVectors`. It is not used with
37+
other backends. The default is 64 in 2D and 324 in 3D.
2638
"""
27-
struct PrecomputedNeighborhoodSearch{NDIMS, NHS, NL, PB} <: AbstractNeighborhoodSearch
28-
neighborhood_search :: NHS
39+
struct PrecomputedNeighborhoodSearch{NDIMS, NL, ELTYPE, PB, NHS} <:
40+
AbstractNeighborhoodSearch
2941
neighbor_lists :: NL
42+
search_radius :: ELTYPE
3043
periodic_box :: PB
44+
neighborhood_search :: NHS
3145

3246
function PrecomputedNeighborhoodSearch{NDIMS}(; search_radius = 0.0, n_points = 0,
3347
periodic_box = nothing,
34-
update_strategy = nothing) where {NDIMS}
35-
nhs = GridNeighborhoodSearch{NDIMS}(; search_radius, n_points,
36-
periodic_box, update_strategy)
37-
38-
neighbor_lists = Vector{Vector{Int}}()
39-
40-
new{NDIMS, typeof(nhs),
41-
typeof(neighbor_lists),
42-
typeof(periodic_box)}(nhs, neighbor_lists, periodic_box)
48+
update_strategy = nothing,
49+
update_neighborhood_search = GridNeighborhoodSearch{NDIMS}(;
50+
search_radius,
51+
n_points,
52+
periodic_box,
53+
update_strategy),
54+
backend = DynamicVectorOfVectors{Int32},
55+
max_neighbors = 4 * NDIMS^4) where {NDIMS}
56+
neighbor_lists = construct_backend(nothing, backend, n_points, max_neighbors)
57+
58+
new{NDIMS, typeof(neighbor_lists),
59+
typeof(search_radius), typeof(periodic_box),
60+
typeof(update_neighborhood_search)}(neighbor_lists, search_radius,
61+
periodic_box, update_neighborhood_search)
4362
end
4463
end
4564

4665
@inline Base.ndims(::PrecomputedNeighborhoodSearch{NDIMS}) where {NDIMS} = NDIMS
4766

4867
@inline requires_update(::PrecomputedNeighborhoodSearch) = (true, true)
4968

50-
@inline function search_radius(search::PrecomputedNeighborhoodSearch)
51-
return search_radius(search.neighborhood_search)
52-
end
53-
5469
function initialize!(search::PrecomputedNeighborhoodSearch,
5570
x::AbstractMatrix, y::AbstractMatrix;
5671
parallelization_backend = default_backend(x),
5772
eachindex_y = axes(y, 2))
5873
(; neighborhood_search, neighbor_lists) = search
5974

75+
if eachindex_y != axes(y, 2)
76+
error("this neighborhood search does not support inactive points")
77+
end
78+
6079
# Initialize the internal neighborhood search
61-
initialize!(neighborhood_search, x, y; eachindex_y, parallelization_backend)
80+
initialize!(neighborhood_search, x, y; parallelization_backend)
6281

6382
initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y,
64-
parallelization_backend, eachindex_y)
83+
parallelization_backend)
84+
85+
return search
6586
end
6687

67-
# WARNING! Experimental feature:
68-
# By default, determine the parallelization backend from the type of `x`.
69-
# Optionally, pass a `KernelAbstractions.Backend` to run the KernelAbstractions.jl code
70-
# on this backend. This can be useful to run GPU kernels on the CPU by passing
71-
# `parallelization_backend = KernelAbstractions.CPU()`, even though `x isa Array`.
7288
function update!(search::PrecomputedNeighborhoodSearch,
7389
x::AbstractMatrix, y::AbstractMatrix;
7490
points_moving = (true, true), parallelization_backend = default_backend(x),
7591
eachindex_y = axes(y, 2))
7692
(; neighborhood_search, neighbor_lists) = search
7793

78-
# Update grid NHS
79-
update!(neighborhood_search, x, y; eachindex_y, points_moving, parallelization_backend)
94+
if eachindex_y != axes(y, 2)
95+
error("this neighborhood search does not support inactive points")
96+
end
97+
98+
# Update the internal neighborhood search
99+
update!(neighborhood_search, x, y; points_moving, parallelization_backend)
80100

81101
# Skip update if both point sets are static
82102
if any(points_moving)
83103
initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y,
84-
parallelization_backend, eachindex_y)
104+
parallelization_backend)
85105
end
106+
107+
return search
86108
end
87109

88110
function initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y,
89-
parallelization_backend, eachindex_y)
111+
parallelization_backend)
90112
# Initialize neighbor lists
91113
empty!(neighbor_lists)
92114
resize!(neighbor_lists, size(x, 2))
@@ -95,8 +117,19 @@ function initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y,
95117
end
96118

97119
# Fill neighbor lists
98-
foreach_point_neighbor(x, y, neighborhood_search; parallelization_backend,
99-
points = eachindex_y) do point, neighbor, _, _
120+
foreach_point_neighbor(x, y, neighborhood_search;
121+
parallelization_backend) do point, neighbor, _, _
122+
push!(neighbor_lists[point], neighbor)
123+
end
124+
end
125+
126+
function initialize_neighbor_lists!(neighbor_lists::DynamicVectorOfVectors,
127+
neighborhood_search, x, y, parallelization_backend)
128+
resize!(neighbor_lists, size(x, 2))
129+
130+
# Fill neighbor lists
131+
foreach_point_neighbor(x, y, neighborhood_search;
132+
parallelization_backend) do point, neighbor, _, _
100133
push!(neighbor_lists[point], neighbor)
101134
end
102135
end
@@ -132,8 +165,23 @@ end
132165

133166
function copy_neighborhood_search(nhs::PrecomputedNeighborhoodSearch,
134167
search_radius, n_points; eachpoint = 1:n_points)
135-
update_strategy_ = nhs.neighborhood_search.update_strategy
168+
update_neighborhood_search = copy_neighborhood_search(nhs.neighborhood_search,
169+
search_radius, n_points;
170+
eachpoint)
171+
max_neighbors = max_inner_length(nhs.neighbor_lists, 4 * ndims(nhs)^4)
136172
return PrecomputedNeighborhoodSearch{ndims(nhs)}(; search_radius, n_points,
137173
periodic_box = nhs.periodic_box,
138-
update_strategy = update_strategy_)
174+
update_neighborhood_search,
175+
backend = typeof(nhs.neighbor_lists),
176+
max_neighbors)
177+
end
178+
179+
# TODO move to `vector_of_vectors.jl`
180+
function max_inner_length(cells::DynamicVectorOfVectors, fallback)
181+
return size(cells.backend, 1)
182+
end
183+
184+
# Fallback when backend is a `Vector{Vector{T}}`. Only used for copying the neighborhood search.
185+
function max_inner_length(::Any, fallback)
186+
return fallback
139187
end

test/neighborhood_search.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
backend = Vector{Vector{Int32}})),
6060
PrecomputedNeighborhoodSearch{NDIMS}(; search_radius, n_points,
6161
periodic_box = periodic_boxes[i]),
62+
PrecomputedNeighborhoodSearch{NDIMS}(; search_radius, n_points,
63+
periodic_box = periodic_boxes[i],
64+
backend = Vector{Vector{Int32}}),
6265
GridNeighborhoodSearch{NDIMS}(; search_radius, n_points,
6366
periodic_box = periodic_boxes[i],
6467
cell_list = SpatialHashingCellList{NDIMS}(2 *
@@ -71,6 +74,7 @@
7174
"`GridNeighborhoodSearch` with `FullGridCellList` with `DynamicVectorOfVectors`",
7275
"`GridNeighborhoodSearch` with `FullGridCellList` with `Vector{Vector}`",
7376
"`PrecomputedNeighborhoodSearch`",
77+
"`PrecomputedNeighborhoodSearch` with `Vector{Vector}`",
7478
"`GridNeighborhoodSearch` with `SpatialHashingCellList`"
7579
]
7680

@@ -86,6 +90,8 @@
8690
max_corner = periodic_boxes[i].max_corner,
8791
backend = Vector{Vector{Int32}})),
8892
PrecomputedNeighborhoodSearch{NDIMS}(periodic_box = periodic_boxes[i]),
93+
PrecomputedNeighborhoodSearch{NDIMS}(periodic_box = periodic_boxes[i],
94+
backend = Vector{Vector{Int32}}),
8995
GridNeighborhoodSearch{NDIMS}(periodic_box = periodic_boxes[i],
9096
cell_list = SpatialHashingCellList{NDIMS}(2 *
9197
n_points))
@@ -193,6 +199,8 @@
193199
search_radius,
194200
backend = Vector{Vector{Int}})),
195201
PrecomputedNeighborhoodSearch{NDIMS}(; search_radius, n_points),
202+
PrecomputedNeighborhoodSearch{NDIMS}(; search_radius, n_points,
203+
backend = Vector{Vector{Int}}),
196204
GridNeighborhoodSearch{NDIMS}(; search_radius, n_points,
197205
cell_list = SpatialHashingCellList{NDIMS}(2 *
198206
n_points))
@@ -207,6 +215,7 @@
207215
"`GridNeighborhoodSearch` with `FullGridCellList` with `DynamicVectorOfVectors` and `SemiParallelUpdate`",
208216
"`GridNeighborhoodSearch` with `FullGridCellList` with `Vector{Vector}`",
209217
"`PrecomputedNeighborhoodSearch`",
218+
"`PrecomputedNeighborhoodSearch` with `Vector{Vector}`",
210219
"`GridNeighborhoodSearch` with `SpatialHashingCellList`"
211220
]
212221

@@ -227,6 +236,7 @@
227236
max_corner,
228237
backend = Vector{Vector{Int32}})),
229238
PrecomputedNeighborhoodSearch{NDIMS}(),
239+
PrecomputedNeighborhoodSearch{NDIMS}(backend = Vector{Vector{Int32}}),
230240
GridNeighborhoodSearch{NDIMS}(cell_list = SpatialHashingCellList{NDIMS}(2 *
231241
n_points))
232242
]

0 commit comments

Comments
 (0)