Skip to content

Commit 6f2f0d2

Browse files
committed
Add continuous output model construct
- Implemented structure to store the coefficients to evaluate the output at any value and added method in Dopri5 to use it. - Added integration tests. - Updated the `solution_output` function to ensure that the last point is added if it's within floating point error of x_end.
1 parent 8bbc437 commit 6f2f0d2

File tree

7 files changed

+403
-27
lines changed

7 files changed

+403
-27
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "ode_solvers"
3-
version = "0.5.0"
3+
version = "0.6.0"
44
authors = ["Sylvain Renevey <syl.renevey@gmail.com>"]
55
description = "Numerical methods to solve ordinary differential equations (ODEs) in Rust."
66
edition = "2021"

README.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ To start using the crate in a project, the following dependency must be added in
1414

1515
```rust
1616
[dependencies]
17-
ode_solvers = "0.5.0"
17+
ode_solvers = "0.6.0"
1818
```
1919

2020
Then, in the main file, add
@@ -98,4 +98,17 @@ let x_out = stepper.x_out();
9898
let y_out = stepper.y_out();
9999
```
100100

101+
## Continuous Output Model
102+
A continuous output model can be built when solving a system with Dopri5
103+
```rust
104+
let mut stepper = Dopri5::new(system, x0, x_end, dx, y0, rtol, atol);
105+
let mut continuous_output_model = ContinuousOutputModel::default();
106+
let res = stepper.integrate_with_continuous_output_model(&mut continuous_output_model);
107+
```
108+
The continuous output model can then be used to evaluate the solution at any point in the integration interval by calling the `evaluate(x: T)` method
109+
```rust
110+
let value = continuous_output_model.evaluate(1.2);
111+
```
112+
This method returns an `Option<State>` which is `None` if the interrogation point is outside the integration interval. The continuous output model is serializable which allows it to be saved and loaded independently of the system definition and integration process.
113+
101114
See the [homepage](https://srenevey.github.io/ode-solvers/) for more details.

src/continuous_output_model.rs

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
use crate::dop_shared::FloatNumber;
2+
use nalgebra::allocator::Allocator;
3+
use nalgebra::{DefaultAllocator, Dim, OVector};
4+
use serde::{Deserialize, Serialize};
5+
6+
/// Stores the coefficients used to compute the dense output.
7+
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
8+
pub struct ContinuousOutputModel<T, V> {
9+
lower_bound: T,
10+
breakpoints: Vec<T>,
11+
intervals: Vec<Interval<T, V>>,
12+
}
13+
14+
impl<T, D: Dim> ContinuousOutputModel<T, OVector<T, D>>
15+
where
16+
f64: From<T>,
17+
T: FloatNumber,
18+
OVector<T, D>: std::ops::Mul<T, Output = OVector<T, D>>,
19+
DefaultAllocator: Allocator<D>,
20+
{
21+
/// Creates a new empty [`ContinuousOutputModel`].
22+
pub fn new() -> Self {
23+
Self {
24+
lower_bound: T::zero(),
25+
breakpoints: Vec::new(),
26+
intervals: Vec::new(),
27+
}
28+
}
29+
30+
/// Evaluates the continuous output at the given value.
31+
pub fn evaluate(&self, x: T) -> Option<OVector<T, D>> {
32+
self.get_interval_index(x).map(|index| {
33+
let lower_bound = if index == 0 {
34+
self.lower_bound
35+
} else {
36+
self.breakpoints[index - 1]
37+
};
38+
let interval = &self.intervals[index];
39+
let theta = (x - lower_bound) / interval.step_size;
40+
let theta1 = T::one() - theta;
41+
42+
let coefficients = &interval.coefficients;
43+
let mut result = coefficients[coefficients.len() - 1].clone();
44+
for i in (0..coefficients.len() - 1).rev() {
45+
let multiplier = if i == coefficients.len() - 1 {
46+
T::one()
47+
} else if i % 2 == 0 {
48+
theta
49+
} else {
50+
theta1
51+
};
52+
result = &coefficients[i] + result * multiplier;
53+
}
54+
Some(result)
55+
})?
56+
}
57+
58+
/// Returns the lower and upper bounds of the continuous output model validity range.
59+
pub fn bounds(&self) -> (T, T) {
60+
(
61+
self.lower_bound,
62+
self.breakpoints[self.breakpoints.len() - 1],
63+
)
64+
}
65+
66+
/// Sets the lower bound of the continuous output.
67+
pub(crate) fn set_lower_bound(&mut self, lower_bound: T) {
68+
self.lower_bound = lower_bound;
69+
}
70+
71+
/// Adds the coefficients valid up to a given breakpoint to the continuous output model.
72+
pub(crate) fn add_interval(
73+
&mut self,
74+
breakpoint: T,
75+
coefficients: Vec<OVector<T, D>>,
76+
step_size: T,
77+
) {
78+
debug_assert!(
79+
self.breakpoints
80+
.last()
81+
.map_or(true, |&last| breakpoint > last),
82+
"Breakpoints must be added in ascending order"
83+
);
84+
self.breakpoints.push(breakpoint);
85+
self.intervals.push(Interval {
86+
coefficients,
87+
step_size,
88+
});
89+
}
90+
91+
/// Returns the index of the interval containing x.
92+
fn get_interval_index(&self, x: T) -> Option<usize> {
93+
if x < self.lower_bound {
94+
return None;
95+
}
96+
97+
match self
98+
.breakpoints
99+
.binary_search_by(|probe| probe.partial_cmp(&x).unwrap())
100+
{
101+
Ok(index) => Some(index),
102+
Err(index) => {
103+
if index < self.intervals.len() {
104+
Some(index)
105+
} else {
106+
None
107+
}
108+
}
109+
}
110+
}
111+
}
112+
113+
#[derive(Debug, Clone, Serialize, Deserialize)]
114+
struct Interval<T, V> {
115+
coefficients: Vec<V>,
116+
step_size: T,
117+
}
118+
119+
#[cfg(test)]
120+
mod tests {
121+
use crate::continuous_output_model::ContinuousOutputModel;
122+
use approx::assert_relative_eq;
123+
use nalgebra::Vector1;
124+
125+
type State = Vector1<f64>;
126+
127+
#[test]
128+
fn test_evaluate_with_odd_number_of_coefficients() {
129+
let first_breakpoint = 2.0;
130+
let coefficients_first_interval = vec![State::new(0.2), State::new(1.0), State::new(-2.5)];
131+
let first_step_size = 0.1;
132+
let second_breakpoint = 5.0;
133+
let coefficients_second_interval = vec![State::new(-1.5), State::new(0.4), State::new(1.2)];
134+
let second_step_size = 0.2;
135+
136+
let mut continuous_output_model = ContinuousOutputModel::default();
137+
continuous_output_model.add_interval(
138+
first_breakpoint,
139+
coefficients_first_interval,
140+
first_step_size,
141+
);
142+
continuous_output_model.add_interval(
143+
second_breakpoint,
144+
coefficients_second_interval,
145+
second_step_size,
146+
);
147+
148+
assert_eq!(continuous_output_model.evaluate(-0.1), None);
149+
assert_eq!(continuous_output_model.evaluate(5.1), None);
150+
151+
assert_relative_eq!(
152+
continuous_output_model.evaluate(0.0).unwrap(),
153+
State::new(0.2),
154+
epsilon = 1e-9
155+
);
156+
assert_relative_eq!(
157+
continuous_output_model.evaluate(1.2).unwrap(),
158+
State::new(342.2),
159+
epsilon = 1e-9
160+
);
161+
assert_relative_eq!(
162+
continuous_output_model.evaluate(2.0).unwrap(),
163+
State::new(970.2),
164+
epsilon = 1e-9
165+
);
166+
assert_relative_eq!(
167+
continuous_output_model.evaluate(2.5).unwrap(),
168+
State::new(-5.0),
169+
epsilon = 1e-9
170+
);
171+
assert_relative_eq!(
172+
continuous_output_model.evaluate(5.0).unwrap(),
173+
State::new(-247.5),
174+
epsilon = 1e-9
175+
);
176+
}
177+
178+
#[test]
179+
fn test_evaluate_with_even_number_of_coefficients() {
180+
let first_breakpoint = 2.0;
181+
let coefficients_first_interval = vec![
182+
State::new(0.2),
183+
State::new(1.0),
184+
State::new(-2.5),
185+
State::new(0.3),
186+
];
187+
let first_step_size = 0.1;
188+
let second_breakpoint = 5.0;
189+
let coefficients_second_interval = vec![
190+
State::new(-1.5),
191+
State::new(0.4),
192+
State::new(1.2),
193+
State::new(2.7),
194+
];
195+
let second_step_size = 0.2;
196+
197+
let mut continuous_output_model = ContinuousOutputModel::default();
198+
continuous_output_model.add_interval(
199+
first_breakpoint,
200+
coefficients_first_interval,
201+
first_step_size,
202+
);
203+
continuous_output_model.add_interval(
204+
second_breakpoint,
205+
coefficients_second_interval,
206+
second_step_size,
207+
);
208+
209+
assert_relative_eq!(
210+
continuous_output_model.evaluate(0.0).unwrap(),
211+
State::new(0.2),
212+
epsilon = 1e-9
213+
);
214+
assert_relative_eq!(
215+
continuous_output_model.evaluate(1.2).unwrap(),
216+
State::new(-133.0),
217+
epsilon = 1e-9
218+
);
219+
assert_relative_eq!(
220+
continuous_output_model.evaluate(2.0).unwrap(),
221+
State::new(-1309.8),
222+
epsilon = 1e-9
223+
);
224+
assert_relative_eq!(
225+
continuous_output_model.evaluate(2.5).unwrap(),
226+
State::new(-30.3125),
227+
epsilon = 1e-9
228+
);
229+
assert_relative_eq!(
230+
continuous_output_model.evaluate(5.0).unwrap(),
231+
State::new(-8752.5),
232+
epsilon = 1e-9
233+
);
234+
}
235+
236+
#[test]
237+
fn test_evaluate_with_no_coefficients() {
238+
let continuous_output_model: ContinuousOutputModel<f64, State> =
239+
ContinuousOutputModel::default();
240+
assert_eq!(continuous_output_model.evaluate(0.0), None);
241+
assert_eq!(continuous_output_model.evaluate(3.0), None);
242+
}
243+
244+
#[test]
245+
fn test_get_interval_index() {
246+
let mut continuous_output_model: ContinuousOutputModel<f64, State> =
247+
ContinuousOutputModel::default();
248+
continuous_output_model.add_interval(2.0, vec![], 0.1);
249+
continuous_output_model.add_interval(5.0, vec![], 0.1);
250+
251+
assert_eq!(continuous_output_model.get_interval_index(-0.001), None);
252+
assert_eq!(continuous_output_model.get_interval_index(0.0), Some(0));
253+
assert_eq!(continuous_output_model.get_interval_index(1.3), Some(0));
254+
assert_eq!(continuous_output_model.get_interval_index(2.0), Some(0));
255+
assert_eq!(continuous_output_model.get_interval_index(3.2), Some(1));
256+
assert_eq!(continuous_output_model.get_interval_index(5.0), Some(1));
257+
assert_eq!(continuous_output_model.get_interval_index(5.0001), None);
258+
}
259+
}

0 commit comments

Comments
 (0)