Skip to content

Commit 8bbc437

Browse files
committed
Add final point of the integration interval if within rounding error
- Updated the `solution_output` function to ensure that the last point is added if it's within floating point error of x_end. - Moved the `sign` function to dop_shared. - Fixed typos.
1 parent 20301ed commit 8bbc437

File tree

2 files changed

+42
-35
lines changed

2 files changed

+42
-35
lines changed

src/dop853.rs

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::butcher_tableau::dopri853;
44
use crate::controller::Controller;
55
use crate::dop_shared::*;
66

7-
use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, OVector};
7+
use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, MatrixSum, OVector, U1};
88

99
trait DefaultController<T: FloatNumber> {
1010
fn default(x: T, x_end: T) -> Self;
@@ -123,9 +123,9 @@ where
123123
/// * `fac_min` - Minimum factor between two successive steps. Default is 0.333
124124
/// * `fac_max` - Maximum factor between two successive steps. Default is 6.0
125125
/// * `h_max` - Maximum step size. Default is `x_end-x
126-
/// * `h` - Initial value of the step size. If h = 0.0, the intial value of h is computed automatically
126+
/// * `h` - Initial value of the step size. If h = 0.0, the initial value of h is computed automatically
127127
/// * `n_max` - Maximum number of iterations. Default is 100000
128-
/// * `n_stiff` - Stifness is tested when the number of iterations is a multiple of n_stiff. Default is 1000
128+
/// * `n_stiff` - Stiffness is tested when the number of iterations is a multiple of n_stiff. Default is 1000
129129
/// * `out_type` - Type of the output. Must be a variant of the OutputType enum. Default is Dense
130130
///
131131
#[allow(clippy::too_many_arguments)]
@@ -190,7 +190,7 @@ where
190190
}
191191
}
192192

193-
/// Compute the initial stepsize
193+
/// Computes the initial step size.
194194
fn hinit(&self) -> T {
195195
let (rows, cols) = self.y.shape_generic();
196196
let mut f0 = OVector::zeros_generic(rows, cols);
@@ -367,7 +367,7 @@ where
367367
self.f.system(self.x + self.h, &y_tmp, &mut k[3]);
368368
self.stats.num_eval += 1;
369369

370-
// Stifness detection
370+
// Stiffness detection
371371
if self.stats.accepted_steps % self.n_stiff == 0 || iasti > 0 {
372372
let num = T::from((&k[3] - &k[2]).dot(&(&k[3] - &k[2]))).unwrap();
373373
let den = T::from((&k[4] - &y_next).dot(&(&k[4] - &y_next))).unwrap();
@@ -521,30 +521,42 @@ where
521521
if self.x_old.abs() <= self.xd.abs() && self.x.abs() >= self.xd.abs() {
522522
let theta = (self.xd - self.x_old) / self.h_old;
523523
let theta1 = T::one() - theta;
524-
525-
let y_out = &self.rcont[0]
526-
+ (&self.rcont[1]
527-
+ (&self.rcont[2]
528-
+ (&self.rcont[3]
529-
+ (&self.rcont[4]
530-
+ (&self.rcont[5]
531-
+ (&self.rcont[6] + &self.rcont[7] * theta)
532-
* theta1)
533-
* theta)
534-
* theta1)
535-
* theta)
536-
* theta1)
537-
* theta;
524+
let y_out = self.compute_y_out(theta, theta1);
538525
self.results.push(self.xd, y_out);
539526
self.xd += self.dx;
540527
}
541528
}
529+
530+
// Ensure the last point is added if it's within floating point error of x_end.
531+
if (self.xd - self.x_end).abs() < T::from(1e-9).unwrap() {
532+
let theta = (self.x_end - self.x_old) / self.h_old;
533+
let theta1 = T::one() - theta;
534+
let y_out = self.compute_y_out(theta, theta1);
535+
self.results.push(self.x_end, y_out);
536+
self.xd += self.dx;
537+
}
542538
}
543539
} else {
544-
self.results.push(self.x0, y_next);
540+
self.results.push(self.x0, y_next.clone());
545541
}
546542
}
547543

544+
/// Computes the value of y for given theta and theta1 values.
545+
fn compute_y_out(&mut self, theta: T, theta1: T) -> MatrixSum<T, D, U1, D, U1> {
546+
&self.rcont[0]
547+
+ (&self.rcont[1]
548+
+ (&self.rcont[2]
549+
+ (&self.rcont[3]
550+
+ (&self.rcont[4]
551+
+ (&self.rcont[5]
552+
+ (&self.rcont[6] + &self.rcont[7] * theta) * theta1)
553+
* theta)
554+
* theta1)
555+
* theta)
556+
* theta1)
557+
* theta
558+
}
559+
548560
/// Getter for the independent variable's output.
549561
pub fn x_out(&self) -> &Vec<T> {
550562
self.results.get().0
@@ -572,14 +584,6 @@ where
572584
}
573585
}
574586

575-
fn sign<T: FloatNumber>(a: T, b: T) -> T {
576-
if b > T::zero() {
577-
a.abs()
578-
} else {
579-
-a.abs()
580-
}
581-
}
582-
583587
#[cfg(test)]
584588
mod tests {
585589
use super::*;

src/dop_shared.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use thiserror::Error;
1414
/// The type parameter T should be either `f32` or `f64`, the trait [FloatNumber] is used
1515
/// internally to allow generic code.
1616
///
17-
/// The type parameter V is a state vector. To have an easy start it is recommend to use [nalgebra] vectors.
17+
/// The type parameter V is a state vector. To have an easy start it is recommended to use [nalgebra] vectors.
1818
///
1919
/// ```
2020
/// use ode_solvers::{System, SVector, Vector3};
@@ -34,6 +34,7 @@ where
3434
{
3535
/// System of ordinary differential equations.
3636
fn system(&self, x: T, y: &V, dy: &mut V);
37+
3738
/// Stop function called at every successful integration step. The integration is stopped when this function returns true.
3839
fn solout(&mut self, _x: T, _y: &V, _dy: &V) -> bool {
3940
false
@@ -141,12 +142,6 @@ impl Stats {
141142
rejected_steps: 0,
142143
}
143144
}
144-
145-
/// Prints some statistics related to the integration process.
146-
#[deprecated(since = "0.2.0", note = "Use std::fmt::Display instead")]
147-
pub fn print(&self) {
148-
println!("{}", self);
149-
}
150145
}
151146

152147
impl fmt::Display for Stats {
@@ -156,3 +151,11 @@ impl fmt::Display for Stats {
156151
write!(f, "Number of rejected steps: {}", self.rejected_steps)
157152
}
158153
}
154+
155+
pub(crate) fn sign<T: FloatNumber>(a: T, b: T) -> T {
156+
if b > T::zero() {
157+
a.abs()
158+
} else {
159+
-a.abs()
160+
}
161+
}

0 commit comments

Comments
 (0)