- 
                Notifications
    You must be signed in to change notification settings 
- Fork 8
Implement neighborhood search based on CellListMap.jl #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 15 commits
82350dd
              f7f48e8
              3002d46
              357ebd2
              c8f1ce2
              1d1c078
              d06e48f
              67d5422
              e3a9637
              b874b5f
              0ba9dad
              aad3e13
              e6e374f
              cbf25c5
              7779a7e
              4e4ec1d
              39ac331
              fbf2e99
              d7000c8
              1326b52
              d183c14
              24f0c85
              b5aca7d
              58f5b57
              fabafbb
              2cc18a8
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,187 @@ | ||||||
| module PointNeighborsCellListMapExt | ||||||
|  | ||||||
| using PointNeighbors | ||||||
| using CellListMap: CellListMap, CellList, CellListPair | ||||||
|  | ||||||
| """ | ||||||
| CellListMapNeighborhoodSearch(NDIMS; search_radius = 1.0, points_equal_neighbors = false) | ||||||
|  | ||||||
| Neighborhood search based on the package [CellListMap.jl](https://github.com/m3g/CellListMap.jl). | ||||||
| This package provides a similar implementation to the [`GridNeighborhoodSearch`](@ref) | ||||||
| with [`FullGridCellList`](@ref), but with better support for periodic boundaries. | ||||||
| This is just a wrapper to use CellListMap.jl with the PointNeighbors.jl API. | ||||||
| Note that periodic boundaries are not yet supported. | ||||||
|  | ||||||
| # Arguments | ||||||
| - `NDIMS`: Number of dimensions. | ||||||
|  | ||||||
| # Keywords | ||||||
| - `search_radius = 1.0`: The fixed search radius. The default of `1.0` is useful together | ||||||
| with [`copy_neighborhood_search`](@ref). | ||||||
| - `points_equal_neighbors = false`: If `true`, a `CellListMap.CellList` is used instead of | ||||||
| a `CellListMap.CellListPair`. This requires that `x === y` | ||||||
| in [`initialize!`](@ref) and [`update!`](@ref). | ||||||
| This option exists only for benchmarking purposes. | ||||||
| It makes the main loop awkward because CellListMap.jl | ||||||
| only computes pairs with `i < j` and PointNeighbors.jl | ||||||
| computes all pairs, so we have to manually use symmetry | ||||||
| to add the missing pairs. | ||||||
|  | ||||||
| !!! warning "Experimental implementation" | ||||||
| This is an experimental feature and may change in future releases. | ||||||
| """ | ||||||
| mutable struct CellListMapNeighborhoodSearch{CL, B} | ||||||
| cell_list :: CL | ||||||
| box :: B | ||||||
|  | ||||||
| # Add dispatch on `NDIMS` to avoid method overwriting of the function in PointNeighbors.jl | ||||||
| function PointNeighbors.CellListMapNeighborhoodSearch(NDIMS::Integer; | ||||||
| search_radius = 1.0, | ||||||
| points_equal_neighbors = false) | ||||||
| # Create a cell list with only one point and resize it later | ||||||
| x = zeros(NDIMS, 1) | ||||||
| box = CellListMap.Box(CellListMap.limits(x, x), search_radius) | ||||||
|  | ||||||
| if points_equal_neighbors | ||||||
| cell_list = CellListMap.CellList(x, box) | ||||||
| else | ||||||
| cell_list = CellListMap.CellList(x, x, box) | ||||||
| end | ||||||
|  | ||||||
| return new{typeof(cell_list), typeof(box)}(cell_list, box) | ||||||
| end | ||||||
| end | ||||||
|  | ||||||
| function PointNeighbors.search_radius(neighborhood_search::CellListMapNeighborhoodSearch) | ||||||
| return neighborhood_search.box.cutoff | ||||||
| end | ||||||
|  | ||||||
| function Base.ndims(neighborhood_search::CellListMapNeighborhoodSearch) | ||||||
| return length(neighborhood_search.box.cell_size) | ||||||
| end | ||||||
|  | ||||||
| function PointNeighbors.initialize!(neighborhood_search::CellListMapNeighborhoodSearch, | ||||||
| x::AbstractMatrix, y::AbstractMatrix) | ||||||
| PointNeighbors.update!(neighborhood_search, x, y) | ||||||
| end | ||||||
|  | ||||||
| # When `x !== y`, a `CellListPair` must be used | ||||||
| function PointNeighbors.update!(neighborhood_search::CellListMapNeighborhoodSearch{<:CellListPair}, | ||||||
| x::AbstractMatrix, y::AbstractMatrix; | ||||||
| points_moving = (true, true)) | ||||||
| (; cell_list) = neighborhood_search | ||||||
|  | ||||||
| # Resize box | ||||||
| box = CellListMap.Box(CellListMap.limits(x, y), neighborhood_search.box.cutoff) | ||||||
| neighborhood_search.box = box | ||||||
|  | ||||||
| # Resize and update cell list | ||||||
| CellListMap.UpdateCellList!(x, y, box, cell_list) | ||||||
|  | ||||||
| # Recalculate number of batches for multithreading | ||||||
| CellListMap.set_number_of_batches!(cell_list) | ||||||
|  | ||||||
| return neighborhood_search | ||||||
| end | ||||||
|  | ||||||
| # When `points_equal_neighbors == true`, a `CellList` is used and `x` must be equal to `y` | ||||||
| function PointNeighbors.update!(neighborhood_search::CellListMapNeighborhoodSearch{<:CellList}, | ||||||
| x::AbstractMatrix, y::AbstractMatrix; | ||||||
| points_moving = (true, true)) | ||||||
| (; cell_list) = neighborhood_search | ||||||
|  | ||||||
| @assert x===y "When `points_equal_neighbors == true`, `x` must be equal to `y`" | ||||||
|  | ||||||
| # Resize box | ||||||
| box = CellListMap.Box(CellListMap.limits(x), neighborhood_search.box.cutoff) | ||||||
| neighborhood_search.box = box | ||||||
|  | ||||||
| # Resize and update cell list | ||||||
| CellListMap.UpdateCellList!(x, box, cell_list) | ||||||
|  | ||||||
| # Recalculate number of batches for multithreading | ||||||
| CellListMap.set_number_of_batches!(cell_list) | ||||||
|  | ||||||
| # Due to https://github.com/m3g/CellListMap.jl/issues/106, we have to update again | ||||||
| CellListMap.UpdateCellList!(x, box, cell_list) | ||||||
|  | ||||||
| return neighborhood_search | ||||||
| end | ||||||
|  | ||||||
| # The type annotation is to make Julia specialize on the type of the function. | ||||||
| # Otherwise, unspecialized code will cause a lot of allocations | ||||||
| # and heavily impact performance. | ||||||
| # See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing | ||||||
| function PointNeighbors.foreach_point_neighbor(f::T, system_coords, neighbor_coords, | ||||||
| neighborhood_search::CellListMapNeighborhoodSearch{<:CellListPair}; | ||||||
| points = axes(system_coords, 2), | ||||||
| parallel = true) where {T} | ||||||
| (; cell_list, box) = neighborhood_search | ||||||
|  | ||||||
| # `0` is the returned output, which we don't use. | ||||||
| # Note that `parallel !== false` is `true` when `parallel` is a PointNeighbors backend. | ||||||
| CellListMap.map_pairwise!(0, box, cell_list, | ||||||
| parallel = parallel !== false) do x, y, i, j, d2, output | ||||||
|         
                  LasNikas marked this conversation as resolved.
              Show resolved
            Hide resolved | ||||||
| # Skip all indices not in `points` | ||||||
| i in points || return output | ||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 Doesn't this produce a notable overhead which affects the benchmark? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried. It's negligible. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just out of interest. Am I missing something? julia> f(i, points) = i in points
f (generic function with 1 method)
julia> points = rand(Int, 1_000_000);
julia> @benchmark f($5, $points)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  221.584 μs … 380.625 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     221.834 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   223.171 μs ±   4.460 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%
  █▅▂▄▁   ▅  ▁  ▁                                               ▁
  █████▇▆██▇██▆██▇▆▆▆▇▅▆▆▅▅▆▅▇▅▄▄▃▅▄▄▅▃▃▄▃▄▂▂▃▃▃▃▂▃▂▂▄▅▇▆▄▃▅▂▃▇ █
  222 μs        Histogram: log(frequency) by time        243 μs <
 Memory estimate: 0 bytes, allocs estimate: 0.There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The difference is that this is about 1-2% of the actual computation when you have a real example where you are doing actual work per particle. | ||||||
|  | ||||||
| pos_diff = x - y | ||||||
| distance = sqrt(d2) | ||||||
|  | ||||||
| @inline f(i, j, pos_diff, distance) | ||||||
|  | ||||||
| return output | ||||||
| end | ||||||
|  | ||||||
| return nothing | ||||||
| end | ||||||
|  | ||||||
| function PointNeighbors.foreach_point_neighbor(f::T, system_coords, neighbor_coords, | ||||||
| neighborhood_search::CellListMapNeighborhoodSearch{<:CellList}; | ||||||
| points = axes(system_coords, 2), | ||||||
| parallel = true) where {T} | ||||||
| (; cell_list, box) = neighborhood_search | ||||||
|  | ||||||
| # `0` is the returned output, which we don't use. | ||||||
| # Note that `parallel !== false` is `true` when `parallel` is a PointNeighbors backend. | ||||||
| CellListMap.map_pairwise!(0, box, cell_list, | ||||||
| parallel = parallel !== false) do x, y, i, j, d2, output | ||||||
| # Skip all indices not in `points` | ||||||
| i in points || return output | ||||||
|  | ||||||
| pos_diff = x - y | ||||||
| distance = sqrt(d2) | ||||||
|  | ||||||
| # When `points_equal_neighbors == true`, a `CellList` is used. | ||||||
| # With a `CellList`, we only see each pair once and have to use symmetry manually. | ||||||
| @inline f(i, j, pos_diff, distance) | ||||||
| @inline f(j, i, -pos_diff, distance) | ||||||
|  | ||||||
| return output | ||||||
| end | ||||||
|  | ||||||
| # With a `CellList`, interaction of a particle with itself is not included | ||||||
|         
                  efaulhaber marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||||
| PointNeighbors.@threaded system_coords for point in points | ||||||
| zero_pos_diff = zero(PointNeighbors.SVector{ndims(neighborhood_search), | ||||||
| eltype(system_coords)}) | ||||||
| @inline f(point, point, zero_pos_diff, zero(eltype(system_coords))) | ||||||
| end | ||||||
|  | ||||||
| return nothing | ||||||
| end | ||||||
|  | ||||||
| function PointNeighbors.copy_neighborhood_search(nhs::CellListMapNeighborhoodSearch{<:CellListPair}, | ||||||
| search_radius, n_points; | ||||||
| eachpoint = 1:n_points) | ||||||
| return PointNeighbors.CellListMapNeighborhoodSearch(ndims(nhs); search_radius, | ||||||
| points_equal_neighbors = false) | ||||||
| end | ||||||
|  | ||||||
| function PointNeighbors.copy_neighborhood_search(nhs::CellListMapNeighborhoodSearch{<:CellList}, | ||||||
| search_radius, n_points; | ||||||
| eachpoint = 1:n_points) | ||||||
| return PointNeighbors.CellListMapNeighborhoodSearch(ndims(nhs); search_radius, | ||||||
| points_equal_neighbors = true) | ||||||
| end | ||||||
|  | ||||||
| end | ||||||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,12 +1,14 @@ | ||
| [deps] | ||
| BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
| CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864" | ||
| Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
| Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
| Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
| TrixiParticles = "66699cd8-9c01-4e9d-a059-b96c86d16b3a" | ||
|  | ||
| [compat] | ||
| BenchmarkTools = "1" | ||
| CellListMap = "0.9" | ||
| Plots = "1" | ||
| Test = "1" | ||
| TrixiParticles = "0.2" | 
Uh oh!
There was an error while loading. Please reload this page.