Skip to content

Commit 726978a

Browse files
authored
Sort tracks per measurement for fast search (acts-project#1110)
1 parent 5c0f1ea commit 726978a

File tree

6 files changed

+135
-3
lines changed

6 files changed

+135
-3
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Project include(s).
11+
#include "traccc/definitions/primitives.hpp"
12+
13+
// VecMem include(s).
14+
#include <vecmem/containers/data/jagged_vector_view.hpp>
15+
#include <vecmem/containers/data/vector_view.hpp>
16+
17+
// System include(s).
18+
#include <cstddef>
19+
20+
namespace traccc::device {
21+
22+
/// (Event Data) Payload for the @c traccc::device::sort_tracks_per_measurement
23+
/// function
24+
struct sort_tracks_per_measurement_payload {
25+
26+
/**
 * @brief View object to the jagged "tracks per measurement" vector
 *
 * The sorting kernel picks one inner vector per CUDA block
 * (via @c blockIdx.x) and writes the sorted values back into it, so the
 * view has to refer to memory that the kernel is allowed to modify.
 * Presumably the outer index is the measurement and the inner values
 * are track indices -- TODO confirm against the code filling the vector.
 */
29+
vecmem::data::jagged_vector_view<unsigned int> tracks_per_measurement_view;
30+
};
31+
32+
} // namespace traccc::device

device/cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ traccc_add_library( traccc_cuda cuda TYPE SHARED
113113
"src/ambiguity_resolution/kernels/reset_status.cuh"
114114
"src/ambiguity_resolution/kernels/scan_block_offsets.cu"
115115
"src/ambiguity_resolution/kernels/scan_block_offsets.cuh"
116+
"src/ambiguity_resolution/kernels/sort_tracks_per_measurement.cu"
117+
"src/ambiguity_resolution/kernels/sort_tracks_per_measurement.cuh"
116118
"src/ambiguity_resolution/kernels/sort_updated_tracks.cu"
117119
"src/ambiguity_resolution/kernels/sort_updated_tracks.cuh"
118120
"src/ambiguity_resolution/kernels/remove_tracks.cu"

device/cuda/src/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.cu

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "./kernels/remove_tracks.cuh"
2424
#include "./kernels/reset_status.cuh"
2525
#include "./kernels/scan_block_offsets.cuh"
26+
#include "./kernels/sort_tracks_per_measurement.cuh"
2627
#include "./kernels/sort_updated_tracks.cuh"
2728
#include "traccc/cuda/ambiguity_resolution/greedy_ambiguity_resolution_algorithm.hpp"
2829

@@ -296,6 +297,22 @@ greedy_ambiguity_resolution_algorithm::operator()(
296297
m_stream.get().synchronize();
297298
}
298299

300+
// Sort tracks per measurement vector
301+
// @TODO: For the case where the measurement is shared by more than 1024
302+
// tracks, the tracks need to be sorted again using thrust::sort
303+
{
304+
const unsigned int nThreads = 1024;
305+
const unsigned int nBlocks = meas_count;
306+
307+
kernels::sort_tracks_per_measurement<<<nBlocks, nThreads, 0, stream>>>(
308+
device::sort_tracks_per_measurement_payload{
309+
.tracks_per_measurement_view = tracks_per_measurement_buffer,
310+
});
311+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
312+
313+
m_stream.get().synchronize();
314+
}
315+
299316
// Make shared number of measurements vector
300317
vecmem::data::vector_buffer<unsigned int> n_shared_buffer{n_tracks,
301318
m_mr.main};

device/cuda/src/ambiguity_resolution/kernels/remove_tracks.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ __launch_bounds__(512) __global__
106106
sorted_ids.at(n_accepted_prev - 1 - sh_threads[threadIndex]);
107107

108108
unsigned int worst_idx =
109-
thrust::find(thrust::seq, tracks.begin(), tracks.end(), trk_id) -
109+
thrust::lower_bound(thrust::seq, tracks.begin(), tracks.end(),
110+
trk_id) -
110111
tracks.begin();
111112

112113
track_status[worst_idx] = 0;
@@ -120,8 +121,8 @@ __launch_bounds__(512) __global__
120121

121122
trk_id = sorted_ids[n_accepted_prev - 1 - sh_threads[i]];
122123

123-
worst_idx = thrust::find(thrust::seq, tracks.begin(),
124-
tracks.end(), trk_id) -
124+
worst_idx = thrust::lower_bound(thrust::seq, tracks.begin(),
125+
tracks.end(), trk_id) -
125126
tracks.begin();
126127

127128
track_status[worst_idx] = 0;
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
// Local include(s).
9+
#include "../../utils/global_index.hpp"
10+
#include "sort_tracks_per_measurement.cuh"
11+
12+
// VecMem include(s).
13+
#include <vecmem/containers/jagged_device_vector.hpp>
14+
15+
namespace traccc::cuda::kernels {
16+
17+
__global__ void sort_tracks_per_measurement(
18+
device::sort_tracks_per_measurement_payload payload) {
19+
20+
__shared__ unsigned int sh_trk_ids[1024];
21+
22+
vecmem::jagged_device_vector<unsigned int> tracks_per_measurement(
23+
payload.tracks_per_measurement_view);
24+
25+
auto tracks = tracks_per_measurement.at(blockIdx.x);
26+
const unsigned int tid = threadIdx.x;
27+
const unsigned int n_tracks = tracks.size();
28+
29+
sh_trk_ids[tid] = std::numeric_limits<unsigned int>::max();
30+
31+
if (tid < n_tracks) {
32+
sh_trk_ids[tid] = tracks[tid];
33+
}
34+
35+
// Bitonic sort
36+
const unsigned int N = 1 << (32 - __clz(n_tracks - 1));
37+
for (int k = 2; k <= N; k <<= 1) {
38+
39+
bool ascending = ((tid & k) == 0);
40+
41+
for (int j = k >> 1; j > 0; j >>= 1) {
42+
int ixj = tid ^ j;
43+
44+
if (ixj > tid && ixj < N && tid < N) {
45+
auto trk_i = sh_trk_ids[tid];
46+
auto trk_j = sh_trk_ids[ixj];
47+
48+
bool should_swap = (trk_i > trk_j) == ascending;
49+
50+
if (should_swap) {
51+
sh_trk_ids[tid] = trk_j;
52+
sh_trk_ids[ixj] = trk_i;
53+
}
54+
}
55+
__syncthreads();
56+
}
57+
}
58+
59+
if (tid < n_tracks) {
60+
tracks[tid] = sh_trk_ids[tid];
61+
}
62+
}
63+
} // namespace traccc::cuda::kernels
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Project include(s).
11+
#include "traccc/ambiguity_resolution/device/sort_tracks_per_measurement.hpp"
12+
13+
namespace traccc::cuda::kernels {
14+
15+
/// Kernel sorting the entries of every inner vector of the "tracks per
/// measurement" jagged vector in ascending order, in place.
///
/// Each block handles the inner vector selected by @c blockIdx.x, so the grid
/// is expected to have one block per inner vector.
///
/// @param payload The (event data) payload holding the jagged vector view
__global__ void sort_tracks_per_measurement(
16+
device::sort_tracks_per_measurement_payload payload);
17+
}

0 commit comments

Comments
 (0)