Skip to content

Commit 8620769

Browse files
authored
Add the two-filters method for Kalman Smoothing (acts-project#788)
1 parent c06d483 commit 8620769

16 files changed

+453
-78
lines changed

core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ traccc_add_library( traccc_core core TYPE SHARED
7474
"include/traccc/fitting/kalman_filter/kalman_fitter.hpp"
7575
"include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
7676
"include/traccc/fitting/kalman_filter/statistics_updater.hpp"
77+
"include/traccc/fitting/kalman_filter/two_filters_smoother.hpp"
7778
"include/traccc/fitting/details/fit_tracks.hpp"
7879
"include/traccc/fitting/kalman_fitting_algorithm.hpp"
7980
"src/fitting/kalman_fitting_algorithm.cpp"

core/include/traccc/edm/track_parameters.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,20 @@ inline void wrap_phi(bound_track_parameters& param) {
4848
param.set_phi(phi);
4949
}
5050

51+
/// Covariance inflation used for track fitting
52+
TRACCC_HOST_DEVICE
53+
inline void inflate_covariance(bound_track_parameters& param,
54+
const traccc::scalar inf_fac) {
55+
auto& cov = param.covariance();
56+
for (unsigned int i = 0; i < e_bound_size; i++) {
57+
for (unsigned int j = 0; j < e_bound_size; j++) {
58+
if (i == j) {
59+
getter::element(cov, i, i) *= inf_fac;
60+
} else {
61+
getter::element(cov, i, j) = 0.f;
62+
}
63+
}
64+
}
65+
}
66+
5167
} // namespace traccc

core/include/traccc/edm/track_state.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ struct fitting_result {
3737
// The number of holes (The number of sensitive surfaces which do not have a
3838
// measurement for the track pattern)
3939
unsigned int n_holes{0u};
40+
41+
/// Reset the statistics
42+
TRACCC_HOST_DEVICE
43+
void reset_statistics() {
44+
ndf = 0.f;
45+
chi2 = 0.f;
46+
n_holes = 0u;
47+
}
4048
};
4149

4250
/// Fitting result per measurement
@@ -160,6 +168,14 @@ struct track_state {
160168
TRACCC_HOST_DEVICE
161169
inline const scalar_type& filtered_chi2() const { return m_filtered_chi2; }
162170

171+
/// @return the non-const chi square of backward filter
172+
TRACCC_HOST_DEVICE
173+
inline scalar_type& backward_chi2() { return m_backward_chi2; }
174+
175+
/// @return the const chi square of backward filter
176+
TRACCC_HOST_DEVICE
177+
inline scalar_type backward_chi2() const { return m_backward_chi2; }
178+
163179
/// @return the non-const filtered parameter
164180
TRACCC_HOST_DEVICE
165181
inline bound_track_parameters_type& filtered() { return m_filtered; }
@@ -200,6 +216,7 @@ struct track_state {
200216
bound_track_parameters_type m_filtered;
201217
scalar_type m_smoothed_chi2 = 0.f;
202218
bound_track_parameters_type m_smoothed;
219+
scalar_type m_backward_chi2 = 0.f;
203220
};
204221

205222
/// Declare all track_state collection types

core/include/traccc/fitting/fitting_config.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ struct fitting_config {
2525
/// Particle hypothesis
2626
detray::pdg_particle<traccc::scalar> ptc_hypothesis =
2727
detray::muon<traccc::scalar>();
28+
29+
/// Smoothing with backward filter
30+
bool use_backward_filter = false;
31+
traccc::scalar covariance_inflation_factor = 1e3f;
2832
};
2933

3034
} // namespace traccc

core/include/traccc/fitting/kalman_filter/kalman_actor.hpp

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "traccc/definitions/qualifiers.hpp"
1212
#include "traccc/edm/track_state.hpp"
1313
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
14+
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
1415
#include "traccc/utils/particle.hpp"
1516

1617
// detray include(s).
@@ -33,31 +34,50 @@ struct kalman_actor : detray::actor {
3334
state(vector_t<track_state_type>&& track_states)
3435
: m_track_states(std::move(track_states)) {
3536
m_it = m_track_states.begin();
37+
m_it_rev = m_track_states.rbegin();
3638
}
3739

3840
/// Constructor with the vector of track states
3941
TRACCC_HOST_DEVICE
4042
state(const vector_t<track_state_type>& track_states)
4143
: m_track_states(track_states) {
4244
m_it = m_track_states.begin();
45+
m_it_rev = m_track_states.rbegin();
4346
}
4447

4548
/// @return the reference of track state pointed by the iterator
4649
TRACCC_HOST_DEVICE
47-
track_state_type& operator()() { return *m_it; }
50+
track_state_type& operator()() {
51+
if (!backward_mode) {
52+
return *m_it;
53+
} else {
54+
return *m_it_rev;
55+
}
56+
}
4857

4958
/// Reset the iterator
5059
TRACCC_HOST_DEVICE
51-
void reset() { m_it = m_track_states.begin(); }
60+
void reset() {
61+
m_it = m_track_states.begin();
62+
m_it_rev = m_track_states.rbegin();
63+
}
5264

5365
/// Advance the iterator
5466
TRACCC_HOST_DEVICE
55-
void next() { m_it++; }
67+
void next() {
68+
if (!backward_mode) {
69+
m_it++;
70+
} else {
71+
m_it_rev++;
72+
}
73+
}
5674

5775
/// @return true if the iterator reaches the end of vector
5876
TRACCC_HOST_DEVICE
59-
bool is_complete() const {
60-
if (m_it == m_track_states.end()) {
77+
bool is_complete() {
78+
if (!backward_mode && m_it == m_track_states.end()) {
79+
return true;
80+
} else if (backward_mode && m_it_rev == m_track_states.rend()) {
6181
return true;
6282
}
6383
return false;
@@ -69,9 +89,15 @@ struct kalman_actor : detray::actor {
6989
// iterator for forward filtering
7090
typename vector_t<track_state_type>::iterator m_it;
7191

92+
// iterator for backward filtering
93+
typename vector_t<track_state_type>::reverse_iterator m_it_rev;
94+
7295
// The number of holes (The number of sensitive surfaces which do not
7396
// have a measurement for the track pattern)
7497
unsigned int n_holes{0u};
98+
99+
// Run back filtering for smoothing, if true
100+
bool backward_mode = false;
75101
};
76102

77103
/// Actor operation to perform the Kalman filtering
@@ -99,32 +125,44 @@ struct kalman_actor : detray::actor {
99125
// Increase the hole counts if the propagator fails to find the next
100126
// measurement
101127
if (navigation.barcode() != trk_state.surface_link()) {
102-
actor_state.n_holes++;
128+
if (!actor_state.backward_mode) {
129+
actor_state.n_holes++;
130+
}
103131
return;
104132
}
105133

106134
// This track state is not a hole
107-
trk_state.is_hole = false;
135+
if (!actor_state.backward_mode) {
136+
trk_state.is_hole = false;
137+
}
108138

109139
// Run Kalman Gain Updater
110140
const auto sf = navigation.get_surface();
111141

112-
const bool res =
113-
sf.template visit_mask<gain_matrix_updater<algebra_t>>(
142+
bool res = false;
143+
144+
if (!actor_state.backward_mode) {
145+
// Forward filter
146+
res = sf.template visit_mask<gain_matrix_updater<algebra_t>>(
147+
trk_state, propagation._stepping.bound_params());
148+
149+
// Update the propagation flow
150+
stepping.bound_params() = trk_state.filtered();
151+
152+
// Set full jacobian
153+
trk_state.jacobian() = stepping.full_jacobian();
154+
} else {
155+
// Backward filter for smoothing
156+
res = sf.template visit_mask<two_filters_smoother<algebra_t>>(
114157
trk_state, propagation._stepping.bound_params());
158+
}
115159

116160
// Abort if the Kalman update fails
117161
if (!res) {
118162
propagation._heartbeat &= navigation.abort();
119163
return;
120164
}
121165

122-
// Update the propagation flow
123-
stepping.bound_params() = trk_state.filtered();
124-
125-
// Set full jacobian
126-
trk_state.jacobian() = stepping.full_jacobian();
127-
128166
// Change the charge of hypothesized particles when the sign of qop
129167
// is changed (This rarely happens when qop is set with a poor seed
130168
// resolution)

core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp

Lines changed: 69 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "traccc/fitting/kalman_filter/kalman_actor.hpp"
1818
#include "traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
1919
#include "traccc/fitting/kalman_filter/statistics_updater.hpp"
20+
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
2021
#include "traccc/utils/particle.hpp"
2122

2223
// detray include(s).
@@ -67,10 +68,17 @@ class kalman_fitter {
6768
detray::actor_chain<detray::dtuple, aborter, transporter, interactor,
6869
fit_actor, resetter, kalman_step_aborter>;
6970

71+
using backward_actor_chain_type =
72+
detray::actor_chain<detray::dtuple, aborter, transporter, fit_actor,
73+
interactor, resetter, kalman_step_aborter>;
74+
7075
// Propagator type
7176
using propagator_type =
7277
detray::propagator<stepper_t, navigator_t, actor_chain_type>;
7378

79+
using backward_propagator_type =
80+
detray::propagator<stepper_t, navigator_t, backward_actor_chain_type>;
81+
7482
/// Constructor with a detector
7583
///
7684
/// @param det the detector object
@@ -104,6 +112,14 @@ class kalman_fitter {
104112
m_resetter_state, m_step_aborter_state);
105113
}
106114

115+
/// @return the actor chain state
116+
TRACCC_HOST_DEVICE
117+
typename backward_actor_chain_type::state backward_actor_state() {
118+
return detray::tie(m_aborter_state, m_transporter_state,
119+
m_fit_actor_state, m_interactor_state,
120+
m_resetter_state, m_step_aborter_state);
121+
}
122+
107123
/// Individual actor states
108124
typename aborter::state m_aborter_state{};
109125
typename transporter::state m_transporter_state{};
@@ -132,17 +148,15 @@ class kalman_fitter {
132148
// Reset the iterator of kalman actor
133149
fitter_state.m_fit_actor_state.reset();
134150

135-
if (i == 0) {
136-
filter(seed_params, fitter_state);
137-
}
138-
// From the second iteration, seed parameter is the smoothed track
139-
// parameter at the first surface
140-
else {
141-
const auto& new_seed_params =
142-
fitter_state.m_fit_actor_state.m_track_states[0].smoothed();
151+
auto seed_params_cpy =
152+
(i == 0) ? seed_params
153+
: fitter_state.m_fit_actor_state.m_track_states[0]
154+
.smoothed();
143155

144-
filter(new_seed_params, fitter_state);
145-
}
156+
inflate_covariance(seed_params_cpy,
157+
m_cfg.covariance_inflation_factor);
158+
159+
filter(seed_params_cpy, fitter_state);
146160
}
147161
}
148162

@@ -178,6 +192,9 @@ class kalman_fitter {
178192
.template set_constraint<detray::step::constraint::e_accuracy>(
179193
m_cfg.propagation.stepping.step_constraint);
180194

195+
// Reset fitter statistics
196+
fitter_state.m_fit_res.reset_statistics();
197+
181198
// Run forward filtering
182199
propagator.propagate(propagation, fitter_state());
183200

@@ -194,14 +211,10 @@ class kalman_fitter {
194211
/// track and vertex fitting", R.Frühwirth, NIM A.
195212
///
196213
/// @param fitter_state the state of kalman fitter
197-
TRACCC_HOST_DEVICE
198-
void smooth(state& fitter_state) {
214+
TRACCC_HOST_DEVICE void smooth(state& fitter_state) {
215+
199216
auto& track_states = fitter_state.m_fit_actor_state.m_track_states;
200217

201-
// The smoothing algorithm requires the following:
202-
// (1) the filtered track parameter of the current surface
203-
// (2) the smoothed track parameter of the next surface
204-
//
205218
// Since the smoothed track parameter of the last surface can be
206219
// considered to be the filtered one, we can reversly iterate the
207220
// algorithm to obtain the smoothed parameter of other surfaces
@@ -210,14 +223,45 @@ class kalman_fitter {
210223
last.smoothed().set_covariance(last.filtered().covariance());
211224
last.smoothed_chi2() = last.filtered_chi2();
212225

213-
for (typename vector_type<track_state<algebra_type>>::reverse_iterator
214-
it = track_states.rbegin() + 1;
215-
it != track_states.rend(); ++it) {
226+
if (m_cfg.use_backward_filter) {
227+
// Backward propagator for the two-filters method
228+
backward_propagator_type propagator(m_cfg.propagation);
229+
230+
// Set path limit
231+
fitter_state.m_aborter_state.set_path_limit(
232+
m_cfg.propagation.stepping.path_limit);
233+
234+
typename backward_propagator_type::state propagation(
235+
last.smoothed(), m_field, m_detector);
236+
237+
inflate_covariance(propagation._stepping.bound_params(),
238+
m_cfg.covariance_inflation_factor);
239+
240+
propagation._navigation.set_volume(
241+
last.smoothed().surface_link().volume());
216242

217-
// Run kalman smoother
218-
const detray::tracking_surface sf{m_detector, it->surface_link()};
219-
sf.template visit_mask<gain_matrix_smoother<algebra_type>>(
220-
*it, *(it - 1));
243+
propagation._navigation.set_direction(
244+
detray::navigation::direction::e_backward);
245+
fitter_state.m_fit_actor_state.backward_mode = true;
246+
247+
propagator.propagate(propagation,
248+
fitter_state.backward_actor_state());
249+
250+
// Reset the backward mode to false
251+
fitter_state.m_fit_actor_state.backward_mode = false;
252+
253+
} else {
254+
// Run the Rauch–Tung–Striebel (RTS) smoother
255+
for (typename vector_type<
256+
track_state<algebra_type>>::reverse_iterator it =
257+
track_states.rbegin() + 1;
258+
it != track_states.rend(); ++it) {
259+
260+
const detray::tracking_surface sf{m_detector,
261+
it->surface_link()};
262+
sf.template visit_mask<gain_matrix_smoother<algebra_type>>(
263+
*it, *(it - 1));
264+
}
221265
}
222266
}
223267

@@ -233,8 +277,8 @@ class kalman_fitter {
233277

234278
const detray::tracking_surface sf{m_detector,
235279
trk_state.surface_link()};
236-
sf.template visit_mask<statistics_updater<algebra_type>>(fit_res,
237-
trk_state);
280+
sf.template visit_mask<statistics_updater<algebra_type>>(
281+
fit_res, trk_state, m_cfg.use_backward_filter);
238282
}
239283

240284
// Subtract the NDoF with the degree of freedom of the bound track (=5)

core/include/traccc/fitting/kalman_filter/statistics_updater.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ struct statistics_updater {
3030
TRACCC_HOST_DEVICE inline void operator()(
3131
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
3232
fitting_result<algebra_t>& fit_res,
33-
const track_state<algebra_t>& trk_state) {
33+
const track_state<algebra_t>& trk_state,
34+
const bool use_backward_filter) {
3435

3536
if (!trk_state.is_hole) {
3637

@@ -41,7 +42,11 @@ struct statistics_updater {
4142
fit_res.ndf += static_cast<scalar_type>(D);
4243

4344
// total_chi2 = total_chi2 + chi2
44-
fit_res.chi2 += trk_state.smoothed_chi2();
45+
if (use_backward_filter) {
46+
fit_res.chi2 += trk_state.backward_chi2();
47+
} else {
48+
fit_res.chi2 += trk_state.filtered_chi2();
49+
}
4550
}
4651
}
4752
};

0 commit comments

Comments
 (0)