diff --git a/src/neighborhood_search.jl b/src/neighborhood_search.jl index 9fb6dcbb..36427896 100644 --- a/src/neighborhood_search.jl +++ b/src/neighborhood_search.jl @@ -35,7 +35,7 @@ See also [`update!`](@ref). """ @inline function initialize!(search::AbstractNeighborhoodSearch, x, y; parallelization_backend = default_backend(x), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) return search end @@ -73,7 +73,7 @@ See also [`initialize!`](@ref). @inline function update!(search::AbstractNeighborhoodSearch, x, y; points_moving = (true, true), parallelization_backend = default_backend(x), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) return search end @@ -171,7 +171,8 @@ See also [`initialize!`](@ref), [`update!`](@ref). """ function foreach_point_neighbor(f::T, system_coords, neighbor_coords, neighborhood_search; parallelization_backend::ParallelizationBackend = default_backend(system_coords), - points = axes(system_coords, 2)) where {T} + points = axes(system_coords, 2), + points_active = nothing) where {T} # The type annotation above is to make Julia specialize on the type of the function. # Otherwise, unspecialized code will cause a lot of allocations # and heavily impact performance. @@ -181,14 +182,20 @@ function foreach_point_neighbor(f::T, system_coords, neighbor_coords, neighborho @boundscheck checkbounds(system_coords, ndims(neighborhood_search), points) @threaded parallelization_backend for point in points - # Now we can safely assume that `point` is inbounds - @inbounds foreach_neighbor(f, system_coords, neighbor_coords, - neighborhood_search, point) + # This check is optimized away when `points_active` is `nothing` + if point_active(points_active, point) + # After the explicit boundscheck, we can safely assume that `point` is inbounds + @inbounds foreach_neighbor(f, system_coords, neighbor_coords, + neighborhood_search, point) + end end return nothing end +@inline point_active(::Nothing, _) = true +@inline point_active(points_active, point) = points_active[point] == true + @propagate_inbounds function foreach_neighbor(f, system_coords, neighbor_system_coords, neighborhood_search::AbstractNeighborhoodSearch, point; diff --git a/src/nhs_grid.jl b/src/nhs_grid.jl index 08139ba9..ed878830 100644 --- a/src/nhs_grid.jl +++ b/src/nhs_grid.jl @@ -195,13 +195,14 @@ end function initialize!(neighborhood_search::GridNeighborhoodSearch, x::AbstractMatrix, y::AbstractMatrix; parallelization_backend = default_backend(x), - eachindex_y = axes(y, 2)) - initialize_grid!(neighborhood_search, y; parallelization_backend, eachindex_y) + eachindex_y = axes(y, 2), points_active = nothing) + initialize_grid!(neighborhood_search, y; parallelization_backend, + eachindex_y, points_active) end function initialize_grid!(neighborhood_search::GridNeighborhoodSearch, y::AbstractMatrix; parallelization_backend = default_backend(y), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) (; cell_list) = neighborhood_search empty!(cell_list) @@ -216,12 +217,16 @@ function initialize_grid!(neighborhood_search::GridNeighborhoodSearch, y::Abstra # Ignore the parallelization backend here. This cannot be parallelized. for point in eachindex_y - # Get cell index of the point's cell - point_coords = @inbounds extract_svector(y, Val(ndims(neighborhood_search)), point) - cell = cell_coords(point_coords, neighborhood_search) - - # Add point to corresponding cell - push_cell!(cell_list, cell, point) + # This check is optimized away when `points_active` is `nothing` + if point_active(points_active, point) + # Get cell index of the point's cell + point_coords = @inbounds extract_svector(y, Val(ndims(neighborhood_search)), + point) + cell = cell_coords(point_coords, neighborhood_search) + + # Add point to corresponding cell + push_cell!(cell_list, cell, point) + end end return neighborhood_search @@ -230,7 +235,7 @@ end function initialize_grid!(neighborhood_search::GridNeighborhoodSearch{<:Any, ParallelUpdate}, y::AbstractMatrix; parallelization_backend = default_backend(y), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) (; cell_list) = neighborhood_search empty!(cell_list) @@ -244,12 +249,16 @@ function initialize_grid!(neighborhood_search::GridNeighborhoodSearch{<:Any, @boundscheck checkbounds(y, eachindex_y) @threaded parallelization_backend for point in eachindex_y - # Get cell index of the point's cell - point_coords = @inbounds extract_svector(y, Val(ndims(neighborhood_search)), point) - cell = cell_coords(point_coords, neighborhood_search) - - # Add point to corresponding cell - push_cell_atomic!(cell_list, cell, point) + # This check is optimized away when `points_active` is `nothing` + if point_active(points_active, point) + # Get cell index of the point's cell + point_coords = @inbounds extract_svector(y, Val(ndims(neighborhood_search)), + point) + cell = cell_coords(point_coords, neighborhood_search) + + # Add point to corresponding cell + push_cell_atomic!(cell_list, cell, point) + end end return neighborhood_search @@ -258,12 +267,13 @@ end function update!(neighborhood_search::GridNeighborhoodSearch, x::AbstractMatrix, y::AbstractMatrix; points_moving = (true, true), parallelization_backend = default_backend(x), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) # The coordinates of the first set of points are irrelevant for this NHS. # Only update when the second set is moving. points_moving[2] || return neighborhood_search - update_grid!(neighborhood_search, y; eachindex_y, parallelization_backend) + update_grid!(neighborhood_search, y; eachindex_y, points_active, + parallelization_backend) end # Update only with neighbor coordinates @@ -273,10 +283,10 @@ function update_grid!(neighborhood_search::Union{GridNeighborhoodSearch{<:Any, SemiParallelUpdate}}, y::AbstractMatrix; parallelization_backend = default_backend(y), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) (; cell_list, update_buffer) = neighborhood_search - if eachindex_y != axes(y, 2) + if eachindex_y != axes(y, 2) || points_active !== nothing # Incremental update doesn't support inactive points error("this neighborhood search/update strategy does not support inactive points") end @@ -381,10 +391,10 @@ end function update_grid!(neighborhood_search::GridNeighborhoodSearch{<:Any, ParallelIncrementalUpdate}, y::AbstractMatrix; parallelization_backend = default_backend(y), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) (; cell_list, update_buffer) = neighborhood_search - if eachindex_y != axes(y, 2) + if eachindex_y != axes(y, 2) || points_active !== nothing # Incremental update doesn't support inactive points error("this neighborhood search/update strategy does not support inactive points") end @@ -445,8 +455,9 @@ function update_grid!(neighborhood_search::Union{GridNeighborhoodSearch{<:Any, GridNeighborhoodSearch{<:Any, SerialUpdate}}, y::AbstractMatrix; parallelization_backend = default_backend(y), - eachindex_y = axes(y, 2)) - initialize_grid!(neighborhood_search, y; parallelization_backend, eachindex_y) + eachindex_y = axes(y, 2), points_active = nothing) + initialize_grid!(neighborhood_search, y; parallelization_backend, + eachindex_y, points_active) end # Specialized version of the function in `neighborhood_search.jl`, which is faster diff --git a/src/nhs_precomputed.jl b/src/nhs_precomputed.jl index 69749f1c..1f3acf73 100644 --- a/src/nhs_precomputed.jl +++ b/src/nhs_precomputed.jl @@ -54,14 +54,14 @@ end function initialize!(search::PrecomputedNeighborhoodSearch, x::AbstractMatrix, y::AbstractMatrix; parallelization_backend = default_backend(x), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) (; neighborhood_search, neighbor_lists) = search # Initialize grid NHS initialize!(neighborhood_search, x, y; eachindex_y, parallelization_backend) initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y, - parallelization_backend, eachindex_y) + parallelization_backend, eachindex_y, points_active) end # WARNING! Experimental feature: @@ -72,21 +72,22 @@ end function update!(search::PrecomputedNeighborhoodSearch, x::AbstractMatrix, y::AbstractMatrix; points_moving = (true, true), parallelization_backend = default_backend(x), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) (; neighborhood_search, neighbor_lists) = search # Update grid NHS - update!(neighborhood_search, x, y; eachindex_y, points_moving, parallelization_backend) + update!(neighborhood_search, x, y; eachindex_y, points_moving, points_active, + parallelization_backend) # Skip update if both point sets are static if any(points_moving) initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y, - parallelization_backend, eachindex_y) + parallelization_backend, eachindex_y, points_active) end end function initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y, - parallelization_backend, eachindex_y) + parallelization_backend, eachindex_y, points_active) # Initialize neighbor lists empty!(neighbor_lists) resize!(neighbor_lists, size(x, 2)) @@ -95,8 +96,9 @@ function initialize_neighbor_lists!(neighbor_lists, neighborhood_search, x, y, end # Fill neighbor lists - foreach_point_neighbor(x, y, neighborhood_search; parallelization_backend, - points = eachindex_y) do point, neighbor, _, _ + foreach_point_neighbor(x, y, neighborhood_search; + parallelization_backend, points = eachindex_y, + points_active) do point, neighbor, _, _ push!(neighbor_lists[point], neighbor) end end diff --git a/src/nhs_trivial.jl b/src/nhs_trivial.jl index 814ac03a..584564ad 100644 --- a/src/nhs_trivial.jl +++ b/src/nhs_trivial.jl @@ -34,14 +34,14 @@ end @inline function initialize!(search::TrivialNeighborhoodSearch, x, y; parallelization_backend = default_backend(x), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) return search end @inline function update!(search::TrivialNeighborhoodSearch, x, y; points_moving = (true, true), parallelization_backend = default_backend(x), - eachindex_y = axes(y, 2)) + eachindex_y = axes(y, 2), points_active = nothing) return search end