Skip to content

Commit 269704d

Browse files
authored
SYCL Track Finding in the Example Executables, main branch (2025.01.09.) (acts-project#812)
* Introduced a SYCL measurement sorting algorithm. * Taught the SYCL sequence example about track finding. * Taught the SYCL full chain algorithm about track finding. * Update the README to the current code status.
1 parent e7948f3 commit 269704d

File tree

7 files changed

+415
-147
lines changed

7 files changed

+415
-147
lines changed

README.md

Lines changed: 131 additions & 120 deletions
Large diffs are not rendered by default.

device/sycl/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ traccc_add_library( traccc_sycl sycl TYPE SHARED
1616
# Clusterization algorithm(s).
1717
"include/traccc/sycl/clusterization/clusterization_algorithm.hpp"
1818
"src/clusterization/clusterization_algorithm.sycl"
19+
"include/traccc/sycl/clusterization/measurement_sorting_algorithm.hpp"
20+
"src/clusterization/measurement_sorting_algorithm.sycl"
1921
# Seeding algorithm(s).
2022
"include/traccc/sycl/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
2123
"src/seeding/silicon_pixel_spacepoint_formation_algorithm.cpp"
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Local include(s).
11+
#include "traccc/sycl/utils/queue_wrapper.hpp"
12+
13+
// Project include(s).
14+
#include "traccc/edm/measurement.hpp"
15+
#include "traccc/utils/algorithm.hpp"
16+
17+
// VecMem include(s).
18+
#include <vecmem/utils/copy.hpp>
19+
20+
// System include(s).
21+
#include <functional>
22+
23+
namespace traccc::sycl {
24+
25+
/// Algorithm sorting the reconstructed measurements in their container
26+
///
27+
/// The track finding algorithm expects measurements belonging to a single
28+
/// detector module to be consecutive in memory. But
29+
/// @c traccc::sycl::clusterization_algorithm does not (currently) produce the
30+
/// measurements in such an ordered state. This is where this algorithm comes
31+
/// to the rescue.
32+
///
33+
class measurement_sorting_algorithm
34+
: public algorithm<measurement_collection_types::view(
35+
const measurement_collection_types::view&)> {
36+
37+
public:
38+
/// Constructor for the algorithm
39+
///
40+
/// @param copy The copy object to use in the algorithm
41+
/// @param queue Wrapper for the for the SYCL queue for kernel invocation
42+
///
43+
measurement_sorting_algorithm(vecmem::copy& copy, queue_wrapper& queue);
44+
45+
/// Callable operator performing the sorting on a container
46+
///
47+
/// @param measurements The measurements to sort
48+
///
49+
output_type operator()(const measurement_collection_types::view&
50+
measurements_view) const override;
51+
52+
private:
53+
/// Copy object to use in the algorithm
54+
std::reference_wrapper<vecmem::copy> m_copy;
55+
/// The SYCL queue to use
56+
std::reference_wrapper<queue_wrapper> m_queue;
57+
58+
}; // class measurement_sorting_algorithm
59+
60+
} // namespace traccc::sycl
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
// Local include(s).
9+
#include "../utils/get_queue.hpp"
10+
#include "traccc/sycl/clusterization/measurement_sorting_algorithm.hpp"
11+
12+
// oneDPL include(s).
13+
#include <oneapi/dpl/algorithm>
14+
#include <oneapi/dpl/execution>
15+
16+
// SYCL include(s).
17+
#include <sycl/sycl.hpp>
18+
19+
namespace traccc::sycl {
20+
21+
measurement_sorting_algorithm::measurement_sorting_algorithm(
22+
vecmem::copy& copy, queue_wrapper& queue)
23+
: m_copy{copy}, m_queue{queue} {}
24+
25+
measurement_sorting_algorithm::output_type
26+
measurement_sorting_algorithm::operator()(
27+
const measurement_collection_types::view& measurements_view) const {
28+
29+
// Get the SYCL queue to use for the algorithm.
30+
::sycl::queue& queue = details::get_queue(m_queue.get());
31+
32+
// oneDPL policy to use, forcing execution onto the same device that the
33+
// hand-written kernels would run on.
34+
auto policy = oneapi::dpl::execution::device_policy{queue};
35+
36+
// Get the number of measurements. This is necessary because the input
37+
// container may not be fixed sized. And we can't give invalid pointers /
38+
// iterators to oneDPL.
39+
const measurement_collection_types::view::size_type n_measurements =
40+
m_copy.get().get_size(measurements_view);
41+
42+
// Sort the measurements in place
43+
oneapi::dpl::sort(policy, measurements_view.ptr(),
44+
measurements_view.ptr() + n_measurements,
45+
measurement_sort_comp());
46+
queue.wait_and_throw();
47+
48+
// Return the view of the sorted measurements.
49+
return measurements_view;
50+
}
51+
52+
} // namespace traccc::sycl

examples/run/sycl/full_chain_algorithm.hpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** TRACCC library, part of the ACTS project (R&D line)
22
*
3-
* (c) 2022-2024 CERN for the benefit of the ACTS project
3+
* (c) 2022-2025 CERN for the benefit of the ACTS project
44
*
55
* Mozilla Public License Version 2.0
66
*/
@@ -9,11 +9,12 @@
99

1010
// Project include(s).
1111
#include "traccc/edm/silicon_cell_collection.hpp"
12-
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
1312
#include "traccc/fitting/kalman_fitting_algorithm.hpp"
1413
#include "traccc/geometry/detector.hpp"
1514
#include "traccc/geometry/silicon_detector_description.hpp"
1615
#include "traccc/sycl/clusterization/clusterization_algorithm.hpp"
16+
#include "traccc/sycl/clusterization/measurement_sorting_algorithm.hpp"
17+
#include "traccc/sycl/finding/combinatorial_kalman_filter_algorithm.hpp"
1718
#include "traccc/sycl/seeding/seeding_algorithm.hpp"
1819
#include "traccc/sycl/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
1920
#include "traccc/sycl/seeding/track_params_estimation.hpp"
@@ -72,7 +73,7 @@ class full_chain_algorithm
7273
using clustering_algorithm = clusterization_algorithm;
7374
/// Track finding algorithm type
7475
using finding_algorithm =
75-
traccc::host::combinatorial_kalman_filter_algorithm;
76+
traccc::sycl::combinatorial_kalman_filter_algorithm;
7677
/// Track fitting algorithm type
7778
using fitting_algorithm = traccc::host::kalman_fitting_algorithm;
7879

@@ -126,6 +127,11 @@ class full_chain_algorithm
126127
/// Memory copy object
127128
mutable vecmem::sycl::async_copy m_copy;
128129

130+
/// Constant B field for the (seed) track parameter estimation
131+
traccc::vector3 m_field_vec;
132+
/// Constant B field for the track finding and fitting
133+
detray::bfield::const_field_t m_field;
134+
129135
/// Detector description
130136
std::reference_wrapper<const silicon_detector_description::host>
131137
m_det_descr;
@@ -140,21 +146,39 @@ class full_chain_algorithm
140146

141147
/// @name Sub-algorithms used by this full-chain algorithm
142148
/// @{
149+
143150
/// Clusterization algorithm
144151
clusterization_algorithm m_clusterization;
152+
/// Measurement sorting algorithm
153+
measurement_sorting_algorithm m_measurement_sorting;
145154
/// Spacepoint formation algorithm
146155
spacepoint_formation_algorithm m_spacepoint_formation;
147156
/// Seeding algorithm
148157
seeding_algorithm m_seeding;
149158
/// Track parameter estimation algorithm
150159
track_params_estimation m_track_parameter_estimation;
160+
/// Track finding algorithm
161+
finding_algorithm m_finding;
151162

152-
/// Configs
163+
/// @}
164+
165+
/// @}
166+
167+
/// @name Algorithm configurations
168+
/// @{
169+
170+
/// Configuration for clustering
153171
clustering_config m_clustering_config;
172+
/// Configuration for the seed finding
154173
seedfinder_config m_finder_config;
174+
/// Configuration for the spacepoint grid formation
155175
spacepoint_grid_config m_grid_config;
176+
/// Configuration for the seed filtering
156177
seedfilter_config m_filter_config;
157178

179+
/// Configuration for the track finding
180+
finding_algorithm::config_type m_finding_config;
181+
158182
/// @}
159183

160184
}; // class full_chain_algorithm

examples/run/sycl/full_chain_algorithm.sycl

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ full_chain_algorithm::full_chain_algorithm(
5454
const seedfinder_config& finder_config,
5555
const spacepoint_grid_config& grid_config,
5656
const seedfilter_config& filter_config,
57-
const finding_algorithm::config_type&,
57+
const finding_algorithm::config_type& finding_config,
5858
const fitting_algorithm::config_type&,
5959
const silicon_detector_description::host& det_descr,
6060
host_detector_type* detector)
@@ -64,6 +64,8 @@ full_chain_algorithm::full_chain_algorithm(
6464
m_device_mr{&(m_data->m_queue)},
6565
m_cached_device_mr{m_device_mr},
6666
m_copy{&(m_data->m_queue)},
67+
m_field_vec{0.f, 0.f, finder_config.bFieldInZ},
68+
m_field{detray::bfield::create_const_field(m_field_vec)},
6769
m_det_descr(det_descr),
6870
m_device_det_descr{
6971
static_cast<silicon_detector_description::buffer::size_type>(
@@ -73,6 +75,7 @@ full_chain_algorithm::full_chain_algorithm(
7375
m_device_detector{},
7476
m_clusterization{memory_resource{m_cached_device_mr, &(m_host_mr.get())},
7577
m_copy, m_data->m_queue_wrapper, clustering_config},
78+
m_measurement_sorting(m_copy, m_data->m_queue_wrapper),
7679
m_spacepoint_formation{
7780
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
7881
m_data->m_queue_wrapper},
@@ -85,10 +88,14 @@ full_chain_algorithm::full_chain_algorithm(
8588
m_track_parameter_estimation{
8689
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
8790
m_data->m_queue_wrapper},
91+
m_finding{finding_config,
92+
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
93+
m_data->m_queue_wrapper},
8894
m_clustering_config(clustering_config),
8995
m_finder_config(finder_config),
9096
m_grid_config(grid_config),
91-
m_filter_config(filter_config) {
97+
m_filter_config(filter_config),
98+
m_finding_config(finding_config) {
9299

93100
// Tell the user what device is being used.
94101
std::cout
@@ -112,6 +119,8 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
112119
m_device_mr{&(m_data->m_queue)},
113120
m_cached_device_mr{m_device_mr},
114121
m_copy{&(m_data->m_queue)},
122+
m_field_vec{parent.m_field_vec},
123+
m_field{parent.m_field},
115124
m_det_descr(parent.m_det_descr),
116125
m_device_det_descr{
117126
static_cast<silicon_detector_description::buffer::size_type>(
@@ -122,6 +131,7 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
122131
m_clusterization{memory_resource{m_cached_device_mr, &(m_host_mr.get())},
123132
m_copy, m_data->m_queue_wrapper,
124133
parent.m_clustering_config},
134+
m_measurement_sorting(m_copy, m_data->m_queue_wrapper),
125135
m_spacepoint_formation{
126136
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
127137
m_data->m_queue_wrapper},
@@ -134,10 +144,14 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
134144
m_track_parameter_estimation{
135145
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
136146
m_data->m_queue_wrapper},
147+
m_finding{parent.m_finding_config,
148+
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
149+
m_data->m_queue_wrapper},
137150
m_clustering_config(parent.m_clustering_config),
138151
m_finder_config(parent.m_finder_config),
139152
m_grid_config(parent.m_grid_config),
140-
m_filter_config(parent.m_filter_config) {
153+
m_filter_config(parent.m_filter_config),
154+
m_finding_config(parent.m_finding_config) {
141155

142156
// Copy the detector (description) to the device.
143157
m_copy(vecmem::get_data(m_det_descr.get()), m_device_det_descr)->wait();
@@ -161,21 +175,27 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
161175
// Execute the algorithms.
162176
const clusterization_algorithm::output_type measurements =
163177
m_clusterization(cells_buffer, m_device_det_descr);
178+
m_measurement_sorting(measurements);
164179

165180
// If we have a Detray detector, run the seeding, track
166181
// finding and fitting.
167182
if (m_detector != nullptr) {
168183

184+
// Run the seed-finding.
169185
const spacepoint_formation_algorithm::output_type spacepoints =
170186
m_spacepoint_formation(m_device_detector_view, measurements);
171187
const track_params_estimation::output_type track_params =
172188
m_track_parameter_estimation(spacepoints, m_seeding(spacepoints),
173189
{0.f, 0.f, m_finder_config.bFieldInZ});
174190

191+
// Run the track finding.
192+
const finding_algorithm::output_type track_candidates = m_finding(
193+
m_device_detector_view, m_field, measurements, track_params);
194+
175195
// Get the final data back to the host.
176196
bound_track_parameters_collection_types::host result(
177197
&(m_host_mr.get()));
178-
m_copy(track_params, result)->wait();
198+
m_copy(track_candidates.headers, result)->wait();
179199

180200
// Return the host container.
181201
return result;

0 commit comments

Comments
 (0)