Skip to content

Commit f1d6458

Browse files
committed
Add function to strip not GPU-compatible data structures
1 parent d114c8a commit f1d6458

File tree

4 files changed

+22
-2
lines changed

4 files changed

+22
-2
lines changed

benchmarks/plot.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ function plot_benchmarks(benchmark, n_points_per_dimension, iterations;
6868
for i in eachindex(neighborhood_searches)
6969
neighborhood_search = neighborhood_searches[i]
7070
initialize!(neighborhood_search, coordinates, coordinates)
71+
# Remove unnecessary data structures that are only used for initialization
72+
neighborhood_search_ = PointNeighbors.freeze_neighborhood_search(neighborhood_search)
7173

72-
time = benchmark(neighborhood_search, coordinates; parallelization_backend)
74+
time = benchmark(neighborhood_search_, coordinates; parallelization_backend)
7375
times[iter, i] = time
7476
time_string = BenchmarkTools.prettytime(time * 1e9)
7577
println("$(neighborhood_searches_names[i])")

src/gpu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function Adapt.adapt_structure(to, nhs::PrecomputedNeighborhoodSearch)
2828
neighbor_lists = Adapt.adapt_structure(to, nhs.neighbor_lists)
2929
search_radius = Adapt.adapt_structure(to, nhs.search_radius)
3030
periodic_box = Adapt.adapt_structure(to, nhs.periodic_box)
31-
neighborhood_search = nothing# Adapt.adapt_structure(to, nhs.neighborhood_search)
31+
neighborhood_search = Adapt.adapt_structure(to, nhs.neighborhood_search)
3232

3333
return PrecomputedNeighborhoodSearch{ndims(nhs)}(neighbor_lists, search_radius,
3434
periodic_box, neighborhood_search)

src/neighborhood_search.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ GridNeighborhoodSearch{2, Float64, ...}(...)
108108
return nothing
109109
end
110110

111+
@inline function freeze_neighborhood_search(search::AbstractNeighborhoodSearch)
112+
# Indicate that the neighborhood search is static and will not be updated anymore.
113+
# Some implementations might use this to strip unnecessary data structures for updating.
114+
# Notably, this is used for the `PrecomputedNeighborhoodSearch` to strip a potentially
115+
# not GPU-compatible inner neighborhood search that is used only for initialization.
116+
return search
117+
end
118+
111119
"""
112120
PeriodicBox(; min_corner, max_corner)
113121

src/nhs_precomputed.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,16 @@ function copy_neighborhood_search(nhs::PrecomputedNeighborhoodSearch,
190190
max_neighbors)
191191
end
192192

193+
@inline function freeze_neighborhood_search(search::PrecomputedNeighborhoodSearch)
194+
# Indicate that the neighborhood search is static and will not be updated anymore.
195+
# For the `PrecomputedNeighborhoodSearch`, strip the inner neighborhood search,
196+
# which is used only for initialization and updating.
197+
return PrecomputedNeighborhoodSearch{ndims(search)}(search.neighbor_lists,
198+
search.search_radius,
199+
search.periodic_box,
200+
nothing)
201+
end
202+
193203
# TODO move to `vector_of_vectors.jl`
194204
function max_inner_length(cells::DynamicVectorOfVectors, fallback)
195205
return size(cells.backend, 1)

0 commit comments

Comments
 (0)