Skip to content

Commit a07e688

Browse files
Add Backend for SpatialHashingCellList (#111)
* Add support for different backends in SpatialHashingCellList. * Minor change, remove unnecessary call to prod(). * Add check_cell_bounds() for push_cell() for SHCL. Move check_cell_bounds() and construct_backends() to separate cell_lists_util.jl to share functionality between different cell lists. * Add constructor to SPCL to work with Adapt.jl * Minor changes. * Add SHCL to gpu.jl * Resolve requested changes: - Move functions from cell_lists_util.jl to cell_lists.jl - Change dispatch in `supported_update_strategies()` - Change doc string of SHCL (SpatialHashingCellList) * Resolve requested changes: - Add `@inbounds` in `push_cell_atomic!` - Improved type dispatch for `supported_update_strategies` - Clarified and cleaned up cell list initialization and emptying, - General code cleanup. * Resolve requested changes. * Replace hashing function for the coordinates. Make cell_list.coords primitive to use Atomix. * Add `adapt_structure()` for SHCL. * Use Atomix for updating cell's coords, * Minor changes, runs on GPU. * Add test for coordinate hash function. * Minor changes. * Resolve requested changes. - Update the tests - Improve the documentation - Minor changes --------- Co-authored-by: Erik Faulhaber <[email protected]>
1 parent 12ad0b3 commit a07e688

File tree

7 files changed

+224
-75
lines changed

7 files changed

+224
-75
lines changed

src/cell_lists/cell_lists.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,47 @@ abstract type AbstractCellList end
44
# able to be used in a threaded loop.
55
@inline each_cell_index_threadable(cell_list::AbstractCellList) = each_cell_index(cell_list)
66

7+
@inline function check_cell_bounds(cell_list::AbstractCellList, cell::Integer)
8+
(; cells) = cell_list
9+
10+
checkbounds(cells, cell)
11+
end
12+
13+
function construct_backend(::Type{Vector{Vector{T}}},
14+
max_outer_length,
15+
max_inner_length) where {T}
16+
return [T[] for _ in 1:max_outer_length]
17+
end
18+
19+
function construct_backend(::Type{DynamicVectorOfVectors{T}},
20+
max_outer_length,
21+
max_inner_length) where {T}
22+
cells = DynamicVectorOfVectors{T}(max_outer_length = max_outer_length,
23+
max_inner_length = max_inner_length)
24+
resize!(cells, max_outer_length)
25+
26+
return cells
27+
end
28+
29+
# When `typeof(cell_list.cells)` is passed, we don't pass the type
30+
# `DynamicVectorOfVectors{T}`, but a type `DynamicVectorOfVectors{T1, T2, T3, T4}`.
31+
# While `A{T} <: A{T1, T2}`, this doesn't hold for the types.
32+
# `Type{A{T}} <: Type{A{T1, T2}}` is NOT true.
33+
function construct_backend(::Type{DynamicVectorOfVectors{T1, T2, T3, T4}}, max_outer_length,
34+
max_inner_length) where {T1, T2, T3, T4}
35+
return construct_backend(DynamicVectorOfVectors{T1}, max_outer_length,
36+
max_inner_length)
37+
end
38+
39+
function max_points_per_cell(cells::DynamicVectorOfVectors)
40+
return size(cells.backend, 1)
41+
end
42+
43+
# Fallback when backend is a `Vector{Vector{T}}`. Only used for copying the cell list.
44+
function max_points_per_cell(cells)
45+
return 100
46+
end
47+
748
include("dictionary.jl")
849
include("full_grid.jl")
950
include("spatial_hashing.jl")

src/cell_lists/full_grid.jl

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -62,34 +62,13 @@ function FullGridCellList(; min_corner, max_corner,
6262
n_cells_per_dimension = ceil.(Int, (max_corner .- min_corner) ./ search_radius)
6363
linear_indices = LinearIndices(Tuple(n_cells_per_dimension))
6464

65-
cells = construct_backend(backend, n_cells_per_dimension, max_points_per_cell)
65+
cells = construct_backend(backend, prod(n_cells_per_dimension),
66+
max_points_per_cell)
6667
end
6768

6869
return FullGridCellList(cells, linear_indices, min_corner, max_corner)
6970
end
7071

71-
function construct_backend(::Type{Vector{Vector{T}}}, size, max_points_per_cell) where {T}
72-
return [T[] for _ in 1:prod(size)]
73-
end
74-
75-
function construct_backend(::Type{DynamicVectorOfVectors{T}}, size,
76-
max_points_per_cell) where {T}
77-
cells = DynamicVectorOfVectors{T}(max_outer_length = prod(size),
78-
max_inner_length = max_points_per_cell)
79-
resize!(cells, prod(size))
80-
81-
return cells
82-
end
83-
84-
# When `typeof(cell_list.cells)` is passed, we don't pass the type
85-
# `DynamicVectorOfVectors{T}`, but a type `DynamicVectorOfVectors{T1, T2, T3, T4}`.
86-
# While `A{T} <: A{T1, T2}`, this doesn't hold for the types.
87-
# `Type{A{T}} <: Type{A{T1, T2}}` is NOT true.
88-
function construct_backend(::Type{DynamicVectorOfVectors{T1, T2, T3, T4}}, size,
89-
max_points_per_cell) where {T1, T2, T3, T4}
90-
return construct_backend(DynamicVectorOfVectors{T1}, size, max_points_per_cell)
91-
end
92-
9372
@inline function cell_coords(coords, periodic_box::Nothing, cell_list::FullGridCellList,
9473
cell_size)
9574
(; min_corner) = cell_list
@@ -210,15 +189,6 @@ function copy_cell_list(cell_list::FullGridCellList, search_radius, periodic_box
210189
max_points_per_cell = max_points_per_cell(cell_list.cells))
211190
end
212191

213-
function max_points_per_cell(cells::DynamicVectorOfVectors)
214-
return size(cells.backend, 1)
215-
end
216-
217-
# Fallback when backend is a `Vector{Vector{T}}`. Only used for copying the cell list.
218-
function max_points_per_cell(cells)
219-
return 100
220-
end
221-
222192
@inline function check_cell_bounds(cell_list::FullGridCellList{<:DynamicVectorOfVectors{<:Any,
223193
<:Array}},
224194
cell::Tuple)
@@ -246,9 +216,3 @@ end
246216
error("particle coordinates are NaN or outside the domain bounds of the cell list")
247217
end
248218
end
249-
250-
@inline function check_cell_bounds(cell_list::FullGridCellList, cell::Integer)
251-
(; cells) = cell_list
252-
253-
checkbounds(cells, cell)
254-
end

src/cell_lists/spatial_hashing.jl

Lines changed: 95 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""
2-
SpatialHashingCellList{NDIMS}(; list_size)
2+
SpatialHashingCellList{NDIMS}(; list_size,
3+
backend = DynamicVectorOfVectors{Int32},
4+
max_points_per_cell = 100)
35
46
A basic spatial hashing implementation. Similar to [`DictionaryCellList`](@ref), the domain is discretized into cells,
57
and the particles in each cell are stored in a hash map. The hash is computed using the spatial location of each cell,
@@ -9,80 +11,132 @@ to balance memory consumption against the likelihood of hash collisions.
911
1012
# Arguments
1113
- `NDIMS::Int`: Number of spatial dimensions (e.g., `2` or `3`).
12-
- `list_size::Int`: Size of the hash map (e.g., `2 * n_points`) .
14+
- `list_size::Int`: Size of the hash map (e.g., `2 * n_points`).
15+
- `backend = DynamicVectorOfVectors{Int32}`: Type of the data structure to store the actual
16+
cell lists. Can be
17+
- `Vector{Vector{Int32}}`: Scattered memory, but very memory-efficient.
18+
- `DynamicVectorOfVectors{Int32}`: Contiguous memory, optimizing cache-hits.
19+
- `max_points_per_cell = 100`: Maximum number of points per cell. This will be used to
20+
allocate the `DynamicVectorOfVectors`. It is not used with
21+
the `Vector{Vector{Int32}}` backend.
1322
"""
1423

1524
struct SpatialHashingCellList{NDIMS, CL, CI, CF} <: AbstractCellList
16-
points :: CL
25+
cells :: CL
1726
coords :: CI
1827
collisions :: CF
1928
list_size :: Int
29+
30+
# This constructor is necessary for Adapt.jl to work with this struct
31+
function SpatialHashingCellList(NDIMS, cells, coords, collisions, list_size)
32+
return new{NDIMS, typeof(cells),
33+
typeof(coords), typeof(collisions)}(cells, coords,
34+
collisions, list_size)
35+
end
2036
end
2137

2238
@inline index_type(::SpatialHashingCellList) = Int32
2339

2440
@inline Base.ndims(::SpatialHashingCellList{NDIMS}) where {NDIMS} = NDIMS
2541

42+
function supported_update_strategies(::SpatialHashingCellList{<:Any,
43+
<:DynamicVectorOfVectors})
44+
return (ParallelUpdate, SerialUpdate)
45+
end
46+
2647
function supported_update_strategies(::SpatialHashingCellList)
2748
return (SerialUpdate,)
2849
end
2950

30-
function SpatialHashingCellList{NDIMS}(list_size) where {NDIMS}
31-
points = [Int[] for _ in 1:list_size]
51+
function SpatialHashingCellList{NDIMS}(; list_size,
52+
backend = DynamicVectorOfVectors{Int32},
53+
max_points_per_cell = 100) where {NDIMS}
54+
cells = construct_backend(backend, list_size,
55+
max_points_per_cell)
3256
collisions = [false for _ in 1:list_size]
33-
coords = [ntuple(_ -> typemin(Int), NDIMS) for _ in 1:list_size]
34-
return SpatialHashingCellList{NDIMS, typeof(points), typeof(coords),
35-
typeof(collisions)}(points, coords, collisions, list_size)
57+
coords = [typemin(UInt128) for _ in 1:list_size]
58+
59+
return SpatialHashingCellList(NDIMS, cells, coords, collisions, list_size)
3660
end
3761

3862
function Base.empty!(cell_list::SpatialHashingCellList)
39-
(; list_size) = cell_list
63+
(; cells) = cell_list
4064
NDIMS = ndims(cell_list)
4165

42-
Base.empty!.(cell_list.points)
43-
cell_list.coords .= [ntuple(_ -> typemin(Int), NDIMS) for _ in 1:list_size]
66+
# `Base.empty!.(cells)`, but for all backends
67+
@threaded default_backend(cells) for i in eachindex(cells)
68+
emptyat!(cells, i)
69+
end
70+
71+
fill!(cell_list.coords, typemin(UInt128))
4472
cell_list.collisions .= false
4573
return cell_list
4674
end
4775

4876
# For each entry in the hash table, store the coordinates of the cell of the first point being inserted at this entry.
4977
# If a point with a different cell coordinate is being added, we have found a collision.
78+
# We flatten the coordinate tuples to an `UInt128` number to make it work with atomics.
5079
function push_cell!(cell_list::SpatialHashingCellList, cell, point)
51-
(; points, coords, collisions, list_size) = cell_list
80+
(; cells, coords, collisions, list_size) = cell_list
5281
NDIMS = ndims(cell_list)
5382
hash_key = spatial_hash(cell, list_size)
54-
push!(points[hash_key], point)
5583

56-
cell_coord = coords[hash_key]
57-
if cell_coord == ntuple(_ -> typemin(Int), NDIMS)
84+
@boundscheck check_cell_bounds(cell_list, hash_key)
85+
@inbounds pushat!(cells, hash_key, point)
86+
87+
cell_coord_hash = coordinates_flattened(cell)
88+
previous_cell_coord = coords[hash_key]
89+
if previous_cell_coord == typemin(UInt128)
5890
# If this cell is not used yet, set cell coordinates
59-
coords[hash_key] = cell
60-
elseif cell_coord != cell
91+
coords[hash_key] = cell_coord_hash
92+
elseif previous_cell_coord != cell_coord_hash
6193
# If it is already used by a different cell, mark as collision
6294
collisions[hash_key] = true
6395
end
6496
end
6597

98+
function push_cell_atomic!(cell_list::SpatialHashingCellList, cell, point)
99+
(; cells, coords, collisions, list_size) = cell_list
100+
NDIMS = ndims(cell_list)
101+
hash_key = spatial_hash(cell, list_size)
102+
103+
cell_coord_hash = coordinates_flattened(cell)
104+
105+
@boundscheck check_cell_bounds(cell_list, hash_key)
106+
@inbounds pushat_atomic!(cells, hash_key, point)
107+
108+
cell_coord = @inbounds coords[hash_key]
109+
if cell_coord == ntuple(_ -> typemin(UInt128), Val(NDIMS))
110+
# If this cell is not used yet, set cell coordinates
111+
@inbounds Atomix.@atomic coords[hash_key] = cell_coord_hash
112+
elseif cell_coord != cell_coord_hash
113+
# If it is already used by a different cell, mark as collision
114+
@inbounds Atomix.@atomic collisions[hash_key] = true
115+
end
116+
end
117+
66118
function deleteat_cell!(cell_list::SpatialHashingCellList, cell, i)
67119
deleteat!(cell_list[cell], i)
68120
end
69121

70-
@inline each_cell_index(cell_list::SpatialHashingCellList) = eachindex(cell_list.points)
122+
@inline each_cell_index(cell_list::SpatialHashingCellList) = eachindex(cell_list.cells)
71123

72124
function copy_cell_list(cell_list::SpatialHashingCellList, search_radius,
73125
periodic_box)
74126
(; list_size) = cell_list
75127
NDIMS = ndims(cell_list)
76128

77-
return SpatialHashingCellList{NDIMS}(list_size)
129+
return SpatialHashingCellList{NDIMS}(list_size = list_size,
130+
backend = typeof(cell_list.cells),
131+
max_points_per_cell = max_points_per_cell(cell_list.cells))
78132
end
79133

80134
@inline function Base.getindex(cell_list::SpatialHashingCellList, cell::Tuple)
81-
return cell_list.points[spatial_hash(cell, length(cell_list.points))]
135+
return cell_list.cells[spatial_hash(cell, length(cell_list.cells))]
82136
end
83137

84138
@inline function Base.getindex(cell_list::SpatialHashingCellList, i::Integer)
85-
return cell_list.points[i]
139+
return cell_list.cells[i]
86140
end
87141

88142
@inline function is_correct_cell(cell_list::SpatialHashingCellList{<:Any, Nothing},
@@ -106,3 +160,23 @@ function spatial_hash(cell::NTuple{3, Real}, list_size)
106160

107161
return mod(xor(i * 73856093, j * 19349663, k * 83492791), list_size) + 1
108162
end
163+
164+
@inline function check_cell_bounds(cell_list::SpatialHashingCellList, cell::Tuple)
165+
check_cell_bounds(cell_list, spatial_hash(cell, cell_list.list_size))
166+
end
167+
168+
# Compute a compact 128-bit representation by reinterpreting each coordinate as a UInt32
169+
# and bit-shifting them into a UInt128 slot (appending the `UInt32` bitstrings).
170+
function coordinates_flattened(cell_coordinate)
171+
# Size check
172+
@assert length(cell_coordinate) <= 3
173+
174+
result = UInt128(0)
175+
for (i, coord) in enumerate(cell_coordinate)
176+
ucoord = reinterpret(UInt32, Int32(coord))
177+
# Shift the `i`-th coordinate by (i - 1) x 32 bits, so the used bits don't overlap
178+
result = (UInt128(ucoord) << ((i-1) * 32)) | result
179+
end
180+
181+
return result
182+
end

src/gpu.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,12 @@ function Adapt.adapt_structure(to, nhs::GridNeighborhoodSearch)
3131
return GridNeighborhoodSearch(cell_list, search_radius, periodic_box, n_cells,
3232
cell_size, update_buffer, nhs.update_strategy)
3333
end
34+
35+
function Adapt.adapt_structure(to, cell_list::SpatialHashingCellList{NDIMS}) where {NDIMS}
36+
(; list_size) = cell_list
37+
cells = Adapt.adapt_structure(to, cell_list.cells)
38+
coords = Adapt.adapt_structure(to, cell_list.coords)
39+
collisions = Adapt.adapt_structure(to, cell_list.collisions)
40+
41+
return SpatialHashingCellList(NDIMS, cells, coords, collisions, list_size)
42+
end

src/nhs_grid.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ end
390390
end
391391

392392
# Fully parallel incremental update with atomic push.
393+
# TODO `cell_list.cells.lengths` and `cell_list.cells.backend` are hardcoded
394+
# for `FullGridCellList`, which is currently the only implementation
393395
function update_grid!(neighborhood_search::GridNeighborhoodSearch{<:Any,
394396
ParallelIncrementalUpdate},
395397
y::AbstractMatrix; parallelization_backend = default_backend(y),
@@ -470,7 +472,6 @@ end
470472
# with the `SpatialHashingCellList` if this cell has a collision.
471473
function check_collision(neighbor_cell_::CartesianIndex, neighbor_coords,
472474
cell_list::SpatialHashingCellList, nhs)
473-
(; list_size, collisions, coords) = cell_list
474475
neighbor_cell = periodic_cell_index(Tuple(neighbor_cell_), nhs)
475476

476477
return neighbor_cell != cell_coords(neighbor_coords, nhs)
@@ -494,7 +495,8 @@ function check_cell_collision(neighbor_cell_::CartesianIndex,
494495
# `collisions[hash] == false` means points from only one cells are in this list.
495496
# We could still have a collision though, if this one cell is not `neighbor_cell`,
496497
# which is possible when `neighbor_cell` is empty.
497-
return collisions[hash] || coords[hash] != neighbor_cell
498+
return collisions[hash] ||
499+
coords[hash] != PointNeighbors.coordinates_flattened(neighbor_cell)
498500
end
499501

500502
# Specialized version of the function in `neighborhood_search.jl`, which is faster

0 commit comments

Comments
 (0)