|
| 1 | +use crate::dop_shared::FloatNumber; |
| 2 | +use nalgebra::allocator::Allocator; |
| 3 | +use nalgebra::{DefaultAllocator, Dim, OVector}; |
| 4 | +use serde::{Deserialize, Serialize}; |
| 5 | + |
| 6 | +/// Stores the coefficients used to compute the dense output. |
| 7 | +#[derive(Debug, Clone, Default, Serialize, Deserialize)] |
| 8 | +pub struct ContinuousOutputModel<T, V> { |
| 9 | + lower_bound: T, |
| 10 | + breakpoints: Vec<T>, |
| 11 | + intervals: Vec<Interval<T, V>>, |
| 12 | +} |
| 13 | + |
| 14 | +impl<T, D: Dim> ContinuousOutputModel<T, OVector<T, D>> |
| 15 | +where |
| 16 | + f64: From<T>, |
| 17 | + T: FloatNumber, |
| 18 | + OVector<T, D>: std::ops::Mul<T, Output = OVector<T, D>>, |
| 19 | + DefaultAllocator: Allocator<D>, |
| 20 | +{ |
| 21 | + /// Creates a new empty [`ContinuousOutputModel`]. |
| 22 | + pub fn new() -> Self { |
| 23 | + Self { |
| 24 | + lower_bound: T::zero(), |
| 25 | + breakpoints: Vec::new(), |
| 26 | + intervals: Vec::new(), |
| 27 | + } |
| 28 | + } |
| 29 | + |
| 30 | + /// Evaluates the continuous output at the given value. |
| 31 | + pub fn evaluate(&self, x: T) -> Option<OVector<T, D>> { |
| 32 | + self.get_interval_index(x).map(|index| { |
| 33 | + let lower_bound = if index == 0 { |
| 34 | + self.lower_bound |
| 35 | + } else { |
| 36 | + self.breakpoints[index - 1] |
| 37 | + }; |
| 38 | + let interval = &self.intervals[index]; |
| 39 | + let theta = (x - lower_bound) / interval.step_size; |
| 40 | + let theta1 = T::one() - theta; |
| 41 | + |
| 42 | + let coefficients = &interval.coefficients; |
| 43 | + let mut result = coefficients[coefficients.len() - 1].clone(); |
| 44 | + for i in (0..coefficients.len() - 1).rev() { |
| 45 | + let multiplier = if i == coefficients.len() - 1 { |
| 46 | + T::one() |
| 47 | + } else if i % 2 == 0 { |
| 48 | + theta |
| 49 | + } else { |
| 50 | + theta1 |
| 51 | + }; |
| 52 | + result = &coefficients[i] + result * multiplier; |
| 53 | + } |
| 54 | + Some(result) |
| 55 | + })? |
| 56 | + } |
| 57 | + |
| 58 | + /// Returns the lower and upper bounds of the continuous output model validity range. |
| 59 | + pub fn bounds(&self) -> (T, T) { |
| 60 | + ( |
| 61 | + self.lower_bound, |
| 62 | + self.breakpoints[self.breakpoints.len() - 1], |
| 63 | + ) |
| 64 | + } |
| 65 | + |
| 66 | + /// Sets the lower bound of the continuous output. |
| 67 | + pub(crate) fn set_lower_bound(&mut self, lower_bound: T) { |
| 68 | + self.lower_bound = lower_bound; |
| 69 | + } |
| 70 | + |
| 71 | + /// Adds the coefficients valid up to a given breakpoint to the continuous output model. |
| 72 | + pub(crate) fn add_interval( |
| 73 | + &mut self, |
| 74 | + breakpoint: T, |
| 75 | + coefficients: Vec<OVector<T, D>>, |
| 76 | + step_size: T, |
| 77 | + ) { |
| 78 | + debug_assert!( |
| 79 | + self.breakpoints |
| 80 | + .last() |
| 81 | + .map_or(true, |&last| breakpoint > last), |
| 82 | + "Breakpoints must be added in ascending order" |
| 83 | + ); |
| 84 | + self.breakpoints.push(breakpoint); |
| 85 | + self.intervals.push(Interval { |
| 86 | + coefficients, |
| 87 | + step_size, |
| 88 | + }); |
| 89 | + } |
| 90 | + |
| 91 | + /// Returns the index of the interval containing x. |
| 92 | + fn get_interval_index(&self, x: T) -> Option<usize> { |
| 93 | + if x < self.lower_bound { |
| 94 | + return None; |
| 95 | + } |
| 96 | + |
| 97 | + match self |
| 98 | + .breakpoints |
| 99 | + .binary_search_by(|probe| probe.partial_cmp(&x).unwrap()) |
| 100 | + { |
| 101 | + Ok(index) => Some(index), |
| 102 | + Err(index) => { |
| 103 | + if index < self.intervals.len() { |
| 104 | + Some(index) |
| 105 | + } else { |
| 106 | + None |
| 107 | + } |
| 108 | + } |
| 109 | + } |
| 110 | + } |
| 111 | +} |
| 112 | + |
| 113 | +#[derive(Debug, Clone, Serialize, Deserialize)] |
| 114 | +struct Interval<T, V> { |
| 115 | + coefficients: Vec<V>, |
| 116 | + step_size: T, |
| 117 | +} |
| 118 | + |
| 119 | +#[cfg(test)] |
| 120 | +mod tests { |
| 121 | + use crate::continuous_output_model::ContinuousOutputModel; |
| 122 | + use approx::assert_relative_eq; |
| 123 | + use nalgebra::Vector1; |
| 124 | + |
| 125 | + type State = Vector1<f64>; |
| 126 | + |
| 127 | + #[test] |
| 128 | + fn test_evaluate_with_odd_number_of_coefficients() { |
| 129 | + let first_breakpoint = 2.0; |
| 130 | + let coefficients_first_interval = vec![State::new(0.2), State::new(1.0), State::new(-2.5)]; |
| 131 | + let first_step_size = 0.1; |
| 132 | + let second_breakpoint = 5.0; |
| 133 | + let coefficients_second_interval = vec![State::new(-1.5), State::new(0.4), State::new(1.2)]; |
| 134 | + let second_step_size = 0.2; |
| 135 | + |
| 136 | + let mut continuous_output_model = ContinuousOutputModel::default(); |
| 137 | + continuous_output_model.add_interval( |
| 138 | + first_breakpoint, |
| 139 | + coefficients_first_interval, |
| 140 | + first_step_size, |
| 141 | + ); |
| 142 | + continuous_output_model.add_interval( |
| 143 | + second_breakpoint, |
| 144 | + coefficients_second_interval, |
| 145 | + second_step_size, |
| 146 | + ); |
| 147 | + |
| 148 | + assert_eq!(continuous_output_model.evaluate(-0.1), None); |
| 149 | + assert_eq!(continuous_output_model.evaluate(5.1), None); |
| 150 | + |
| 151 | + assert_relative_eq!( |
| 152 | + continuous_output_model.evaluate(0.0).unwrap(), |
| 153 | + State::new(0.2), |
| 154 | + epsilon = 1e-9 |
| 155 | + ); |
| 156 | + assert_relative_eq!( |
| 157 | + continuous_output_model.evaluate(1.2).unwrap(), |
| 158 | + State::new(342.2), |
| 159 | + epsilon = 1e-9 |
| 160 | + ); |
| 161 | + assert_relative_eq!( |
| 162 | + continuous_output_model.evaluate(2.0).unwrap(), |
| 163 | + State::new(970.2), |
| 164 | + epsilon = 1e-9 |
| 165 | + ); |
| 166 | + assert_relative_eq!( |
| 167 | + continuous_output_model.evaluate(2.5).unwrap(), |
| 168 | + State::new(-5.0), |
| 169 | + epsilon = 1e-9 |
| 170 | + ); |
| 171 | + assert_relative_eq!( |
| 172 | + continuous_output_model.evaluate(5.0).unwrap(), |
| 173 | + State::new(-247.5), |
| 174 | + epsilon = 1e-9 |
| 175 | + ); |
| 176 | + } |
| 177 | + |
| 178 | + #[test] |
| 179 | + fn test_evaluate_with_even_number_of_coefficients() { |
| 180 | + let first_breakpoint = 2.0; |
| 181 | + let coefficients_first_interval = vec![ |
| 182 | + State::new(0.2), |
| 183 | + State::new(1.0), |
| 184 | + State::new(-2.5), |
| 185 | + State::new(0.3), |
| 186 | + ]; |
| 187 | + let first_step_size = 0.1; |
| 188 | + let second_breakpoint = 5.0; |
| 189 | + let coefficients_second_interval = vec![ |
| 190 | + State::new(-1.5), |
| 191 | + State::new(0.4), |
| 192 | + State::new(1.2), |
| 193 | + State::new(2.7), |
| 194 | + ]; |
| 195 | + let second_step_size = 0.2; |
| 196 | + |
| 197 | + let mut continuous_output_model = ContinuousOutputModel::default(); |
| 198 | + continuous_output_model.add_interval( |
| 199 | + first_breakpoint, |
| 200 | + coefficients_first_interval, |
| 201 | + first_step_size, |
| 202 | + ); |
| 203 | + continuous_output_model.add_interval( |
| 204 | + second_breakpoint, |
| 205 | + coefficients_second_interval, |
| 206 | + second_step_size, |
| 207 | + ); |
| 208 | + |
| 209 | + assert_relative_eq!( |
| 210 | + continuous_output_model.evaluate(0.0).unwrap(), |
| 211 | + State::new(0.2), |
| 212 | + epsilon = 1e-9 |
| 213 | + ); |
| 214 | + assert_relative_eq!( |
| 215 | + continuous_output_model.evaluate(1.2).unwrap(), |
| 216 | + State::new(-133.0), |
| 217 | + epsilon = 1e-9 |
| 218 | + ); |
| 219 | + assert_relative_eq!( |
| 220 | + continuous_output_model.evaluate(2.0).unwrap(), |
| 221 | + State::new(-1309.8), |
| 222 | + epsilon = 1e-9 |
| 223 | + ); |
| 224 | + assert_relative_eq!( |
| 225 | + continuous_output_model.evaluate(2.5).unwrap(), |
| 226 | + State::new(-30.3125), |
| 227 | + epsilon = 1e-9 |
| 228 | + ); |
| 229 | + assert_relative_eq!( |
| 230 | + continuous_output_model.evaluate(5.0).unwrap(), |
| 231 | + State::new(-8752.5), |
| 232 | + epsilon = 1e-9 |
| 233 | + ); |
| 234 | + } |
| 235 | + |
| 236 | + #[test] |
| 237 | + fn test_evaluate_with_no_coefficients() { |
| 238 | + let continuous_output_model: ContinuousOutputModel<f64, State> = |
| 239 | + ContinuousOutputModel::default(); |
| 240 | + assert_eq!(continuous_output_model.evaluate(0.0), None); |
| 241 | + assert_eq!(continuous_output_model.evaluate(3.0), None); |
| 242 | + } |
| 243 | + |
| 244 | + #[test] |
| 245 | + fn test_get_interval_index() { |
| 246 | + let mut continuous_output_model: ContinuousOutputModel<f64, State> = |
| 247 | + ContinuousOutputModel::default(); |
| 248 | + continuous_output_model.add_interval(2.0, vec![], 0.1); |
| 249 | + continuous_output_model.add_interval(5.0, vec![], 0.1); |
| 250 | + |
| 251 | + assert_eq!(continuous_output_model.get_interval_index(-0.001), None); |
| 252 | + assert_eq!(continuous_output_model.get_interval_index(0.0), Some(0)); |
| 253 | + assert_eq!(continuous_output_model.get_interval_index(1.3), Some(0)); |
| 254 | + assert_eq!(continuous_output_model.get_interval_index(2.0), Some(0)); |
| 255 | + assert_eq!(continuous_output_model.get_interval_index(3.2), Some(1)); |
| 256 | + assert_eq!(continuous_output_model.get_interval_index(5.0), Some(1)); |
| 257 | + assert_eq!(continuous_output_model.get_interval_index(5.0001), None); |
| 258 | + } |
| 259 | +} |
0 commit comments