Skip to content

Commit a65dc9a

Browse files
authored
Fix the threaded update buffer (#114)
* Fix the threaded update buffer
* Fix chunk size
1 parent 145853c commit a65dc9a

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

src/nhs_grid.jl

Lines changed: 27 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -259,8 +259,8 @@ function update!(neighborhood_search::GridNeighborhoodSearch,
259259
end
260260

261261
# Update only with neighbor coordinates
262-
function update_grid!(neighborhood_search::GridNeighborhoodSearch{NDIMS}, y::AbstractMatrix;
263-
parallelization_backend = default_backend(y)) where {NDIMS}
262+
function update_grid!(neighborhood_search::GridNeighborhoodSearch, y::AbstractMatrix;
263+
parallelization_backend = default_backend(y))
264264
(; cell_list, update_buffer) = neighborhood_search
265265

266266
# Empty each thread's list
@@ -302,34 +302,46 @@ function update_grid!(neighborhood_search::GridNeighborhoodSearch{NDIMS}, y::Abs
302302
return neighborhood_search
303303
end
304304

305-
# The type annotation is to make Julia specialize on the type of the function.
306-
# Otherwise, unspecialized code will cause a lot of allocations and heavily impact performance.
307-
# See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing
308305
@inline function mark_changed_cells!(neighborhood_search::GridNeighborhoodSearch{<:Any,
309306
SemiParallelUpdate},
310-
y, parallelization_backend) where {T}
311-
(; cell_list) = neighborhood_search
307+
y, parallelization_backend)
308+
(; cell_list, update_buffer) = neighborhood_search
312309

313310
# `each_cell_index(cell_list)` might return a `KeySet`, which has to be `collect`ed
314-
# first to be able to be used in a threaded loop. This function takes care of that.
315-
@threaded parallelization_backend for cell_index in
316-
each_cell_index_threadable(cell_list)
317-
mark_changed_cell!(neighborhood_search, cell_index, y)
311+
# first to support indexing.
312+
eachcell = each_cell_index_threadable(cell_list)
313+
314+
# Use chunks (usually one per thread) to index into the update buffer.
315+
# We cannot use `Iterators.partition` here, since the resulting iterator does not
316+
# support indexing and therefore cannot be used in a threaded loop.
317+
chunk_length = div(length(eachcell), length(update_buffer), RoundUp)
318+
319+
@threaded parallelization_backend for chunk_id in 1:length(update_buffer)
320+
# Manual partitioning of `eachcell`
321+
start = (chunk_length * (chunk_id - 1)) + 1
322+
end_ = min(chunk_length * chunk_id, length(eachcell))
323+
324+
for i in start:end_
325+
cell_index = eachcell[i]
326+
327+
mark_changed_cell!(neighborhood_search, cell_index, y, chunk_id)
328+
end
318329
end
319330
end
320331

321332
@inline function mark_changed_cells!(neighborhood_search::GridNeighborhoodSearch{<:Any,
322333
SerialIncrementalUpdate},
323-
y, _) where {T}
334+
y, _)
324335
(; cell_list) = neighborhood_search
325336

326337
# Ignore the parallelization backend here for `SerialIncrementalUpdate`.
327338
for cell_index in each_cell_index(cell_list)
328-
mark_changed_cell!(neighborhood_search, cell_index, y)
339+
# `chunk_id` is always `1` for `SerialIncrementalUpdate`
340+
mark_changed_cell!(neighborhood_search, cell_index, y, 1)
329341
end
330342
end
331343

332-
@inline function mark_changed_cell!(neighborhood_search, cell_index, y)
344+
@inline function mark_changed_cell!(neighborhood_search, cell_index, y, chunk_id)
333345
(; cell_list, update_buffer) = neighborhood_search
334346

335347
for point in cell_list[cell_index]
@@ -341,7 +353,7 @@ end
341353
# These can be identical (see `DictionaryCellList`).
342354
if !is_correct_cell(cell_list, cell, cell_index)
343355
# Mark this cell and continue with the next one
344-
pushat!(update_buffer, Threads.threadid(), cell_index)
356+
pushat!(update_buffer, chunk_id, cell_index)
345357
break
346358
end
347359
end

0 commit comments

Comments
 (0)