Skip to content

Commit d881e19

Browse files
committed
Add continuous output type, fix unit test, add integration test
1 parent 7469f6b commit d881e19

File tree

7 files changed

+118
-19
lines changed

7 files changed

+118
-19
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "ode_solvers"
3-
version = "0.6.0"
3+
version = "0.6.1"
44
authors = ["Sylvain Renevey <[email protected]>"]
55
description = "Numerical methods to solve ordinary differential equations (ODEs) in Rust."
66
edition = "2021"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ To start using the crate in a project, the following dependency must be added in
1414

1515
```rust
1616
[dependencies]
17-
ode_solvers = "0.6.0"
17+
ode_solvers = "0.6.1"
1818
```
1919

2020
Then, in the main file, add

src/continuous_output_model.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,6 @@ where
7575
coefficients: Vec<OVector<T, D>>,
7676
step_size: T,
7777
) {
78-
debug_assert!(
79-
self.breakpoints
80-
.last()
81-
.map_or(true, |&last| breakpoint > last),
82-
"Breakpoints must be added in ascending order"
83-
);
8478
self.breakpoints.push(breakpoint);
8579
self.intervals.push(Interval {
8680
coefficients,

src/dop_shared.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ impl<T, V> Default for SolverResult<T, V> {
113113
pub enum OutputType {
114114
Dense,
115115
Sparse,
116+
Continuous,
116117
}
117118

118119
/// Enumeration of the errors that may arise during integration.

src/dopri5.rs

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,15 @@ where
246246
&mut self,
247247
continuous_output: &mut ContinuousOutputModel<T, OVector<T, D>>,
248248
) -> Result<Stats, IntegrationError> {
249+
self.out_type = OutputType::Continuous;
249250
self.integrate_core(Some(continuous_output))
250251
}
251252

252253
/// Integrates the system.
253254
pub fn integrate(&mut self) -> Result<Stats, IntegrationError> {
255+
if self.out_type == OutputType::Continuous {
256+
panic!("Please use `integrate_with_continuous_output_model` to compute a continuous output model.");
257+
}
254258
self.integrate_core(None)
255259
}
256260

@@ -334,8 +338,8 @@ where
334338
k[1] = k[6].clone();
335339
self.stats.num_eval += 6;
336340

337-
// Prepare dense output
338-
if self.out_type == OutputType::Dense {
341+
// Prepare dense/continuous output
342+
if self.out_type == OutputType::Dense || self.out_type == OutputType::Continuous {
339343
self.rcont[4] = (&k[0] * dopri54::d::<T>(1)
340344
+ &k[2] * dopri54::d::<T>(3)
341345
+ &k[3] * dopri54::d::<T>(4)
@@ -396,8 +400,8 @@ where
396400
}
397401
}
398402

399-
// Prepare dense output
400-
if self.out_type == OutputType::Dense {
403+
// Prepare dense/continuous output
404+
if self.out_type == OutputType::Dense || self.out_type == OutputType::Continuous {
401405
let h = self.h;
402406

403407
let ydiff = &y_next - &self.y;
@@ -414,7 +418,7 @@ where
414418
self.x += self.h;
415419
self.h_old = self.h;
416420

417-
self.solution_output(y_next, &mut continuous_output_model);
421+
self.solution_output(y_next, &mut continuous_output_model, &k[0]);
418422

419423
if self
420424
.f
@@ -443,6 +447,7 @@ where
443447
&mut self,
444448
y_next: OVector<T, D>,
445449
continuous_output_model: &mut Option<&mut ContinuousOutputModel<T, OVector<T, D>>>,
450+
dy: &OVector<T, D>,
446451
) {
447452
if self.out_type == OutputType::Dense {
448453
while self.xd.abs() <= self.x.abs() {
@@ -453,6 +458,13 @@ where
453458
self.results.push(self.xd, y_out);
454459
self.xd += self.dx;
455460
}
461+
462+
if self
463+
.f
464+
.solout(self.xd, self.results.get().1.last().unwrap(), dy)
465+
{
466+
break;
467+
}
456468
}
457469

458470
// Ensure the last point is added if it's within floating point error of x_end.
@@ -514,7 +526,7 @@ mod tests {
514526
use crate::{OVector, System, Vector1};
515527
use nalgebra::{allocator::Allocator, DefaultAllocator, Dim};
516528

517-
// Same as Test3 from rk4.rs, but aborts after x is greater/equal than 0.5
529+
// Same as Test3 from rk4.rs, but aborts after x is equal to or greater than 0.5
518530
struct Test1 {}
519531
impl<D: Dim> System<f64, OVector<f64, D>> for Test1
520532
where
@@ -536,9 +548,34 @@ mod tests {
536548
let _ = stepper.integrate();
537549

538550
let x = stepper.x_out();
539-
assert!((*x.last().unwrap() - 0.5).abs() < 1.0E-9); //
551+
assert!((*x.last().unwrap() - 0.5).abs() < 1.0E-9);
540552

541553
let out = stepper.y_out();
542-
assert!((&out[5][0] - 0.913059243).abs() < 1.0E-9);
554+
assert!((&out[5][0] - 0.9130611474392001).abs() < 1.0E-9);
555+
}
556+
557+
#[test]
558+
#[should_panic]
559+
fn test_integrate_when_continuous_output_type_panic() {
560+
let system = Test1 {};
561+
let mut stepper = Dopri5::from_param(
562+
system,
563+
0.,
564+
1.,
565+
0.1,
566+
Vector1::new(1.),
567+
1e-12,
568+
1e-6,
569+
0.1,
570+
0.2,
571+
0.3,
572+
2.0,
573+
5.0,
574+
0.1,
575+
10000,
576+
1000,
577+
OutputType::Continuous,
578+
);
579+
let _ = stepper.integrate();
543580
}
544581
}

tests/dopri5.rs

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ use nalgebra::Vector1;
33
use ode_solvers::continuous_output_model::ContinuousOutputModel;
44
use ode_solvers::{Dopri5, System};
55

6+
const INTEGRATION_TOLERANCE: f64 = 1e-10;
7+
const STEP_SIZE: f64 = 0.1;
8+
69
type State = Vector1<f64>;
710
type Time = f64;
811

@@ -15,13 +18,33 @@ impl System<f64, State> for Ode {
1518
}
1619
}
1720

21+
struct OdeWithStoppingCondition;
22+
23+
impl System<f64, State> for OdeWithStoppingCondition {
24+
fn system(&self, t: Time, y: &State, dy: &mut State) {
25+
dy[0] = -3.0 * t * t * y[0].exp();
26+
}
27+
28+
fn solout(&mut self, x: f64, _y: &State, _dy: &State) -> bool {
29+
x >= 4.1
30+
}
31+
}
32+
1833
#[test]
1934
fn test_dopri5() {
2035
let f = Ode {};
2136
let initial_value = State::new(0.0);
2237
let x_start = 0.0;
2338
let x_end = 6.0;
24-
let mut stepper = Dopri5::new(f, x_start, x_end, 0.1, initial_value, 1e-10, 1e-10);
39+
let mut stepper = Dopri5::new(
40+
f,
41+
x_start,
42+
x_end,
43+
STEP_SIZE,
44+
initial_value,
45+
INTEGRATION_TOLERANCE,
46+
INTEGRATION_TOLERANCE,
47+
);
2548
let _ = stepper.integrate();
2649

2750
let (times_s, states) = stepper.results().get();
@@ -43,7 +66,15 @@ fn test_dopri5_with_continuous_output_model() {
4366
let initial_value = State::new(0.0);
4467
let x_start = 0.0;
4568
let x_end = 6.0;
46-
let mut stepper = Dopri5::new(f, x_start, x_end, 0.1, initial_value, 1e-10, 1e-10);
69+
let mut stepper = Dopri5::new(
70+
f,
71+
x_start,
72+
x_end,
73+
STEP_SIZE,
74+
initial_value,
75+
INTEGRATION_TOLERANCE,
76+
INTEGRATION_TOLERANCE,
77+
);
4778
let mut continuous_output_model = ContinuousOutputModel::default();
4879
let _ = stepper.integrate_with_continuous_output_model(&mut continuous_output_model);
4980

@@ -63,6 +94,42 @@ fn test_dopri5_with_continuous_output_model() {
6394
}
6495
}
6596

97+
#[test]
98+
fn test_dopri5_with_continuous_output_model_and_stopping_condition() {
99+
let f = OdeWithStoppingCondition {};
100+
let initial_value = State::new(0.0);
101+
let x_start = 0.0;
102+
let x_end = 6.0;
103+
let mut stepper = Dopri5::new(
104+
f,
105+
x_start,
106+
x_end,
107+
STEP_SIZE,
108+
initial_value,
109+
INTEGRATION_TOLERANCE,
110+
INTEGRATION_TOLERANCE,
111+
);
112+
let mut continuous_output_model = ContinuousOutputModel::default();
113+
let _ = stepper.integrate_with_continuous_output_model(&mut continuous_output_model);
114+
115+
assert!(continuous_output_model.evaluate(-0.0001).is_none());
116+
assert!(continuous_output_model.evaluate(0.0).is_some());
117+
assert!(continuous_output_model.evaluate(4.1).is_some());
118+
assert!(continuous_output_model.evaluate(4.1 + STEP_SIZE).is_none());
119+
assert!(continuous_output_model.bounds().1 >= 4.1);
120+
assert!(continuous_output_model.bounds().1 < 4.1 + STEP_SIZE);
121+
122+
let num_samples = 100;
123+
let step_size = (x_end - 4.1) / num_samples as f64;
124+
for i in 0..num_samples {
125+
let t = i as f64 * step_size + x_start;
126+
let continuous_output = continuous_output_model.evaluate(t);
127+
let analytic = analytic_solution(t);
128+
assert!(continuous_output.is_some());
129+
assert_relative_eq!(continuous_output.unwrap(), analytic, epsilon = 1e-9);
130+
}
131+
}
132+
66133
/// Evaluates the analytic solution of the ODE at the given time.
67134
fn analytic_solution(t: Time) -> State {
68135
State::new(-(t.powi(3) + 1.0).ln())

0 commit comments

Comments
 (0)