@@ -4,7 +4,7 @@ use crate::butcher_tableau::dopri853;
44use crate :: controller:: Controller ;
55use crate :: dop_shared:: * ;
66
7- use nalgebra:: { allocator:: Allocator , DefaultAllocator , Dim , OVector } ;
7+ use nalgebra:: { allocator:: Allocator , DefaultAllocator , Dim , MatrixSum , OVector , U1 } ;
88
99trait DefaultController < T : FloatNumber > {
1010 fn default ( x : T , x_end : T ) -> Self ;
@@ -123,9 +123,9 @@ where
123123 /// * `fac_min` - Minimum factor between two successive steps. Default is 0.333
124124 /// * `fac_max` - Maximum factor between two successive steps. Default is 6.0
125125 /// * `h_max` - Maximum step size. Default is `x_end-x
126- /// * `h` - Initial value of the step size. If h = 0.0, the intial value of h is computed automatically
126+ /// * `h` - Initial value of the step size. If h = 0.0, the initial value of h is computed automatically
127127 /// * `n_max` - Maximum number of iterations. Default is 100000
128- /// * `n_stiff` - Stifness is tested when the number of iterations is a multiple of n_stiff. Default is 1000
128+ /// * `n_stiff` - Stiffness is tested when the number of iterations is a multiple of n_stiff. Default is 1000
129129 /// * `out_type` - Type of the output. Must be a variant of the OutputType enum. Default is Dense
130130 ///
131131 #[ allow( clippy:: too_many_arguments) ]
@@ -190,7 +190,7 @@ where
190190 }
191191 }
192192
193- /// Compute the initial stepsize
193+ /// Computes the initial step size.
194194 fn hinit ( & self ) -> T {
195195 let ( rows, cols) = self . y . shape_generic ( ) ;
196196 let mut f0 = OVector :: zeros_generic ( rows, cols) ;
@@ -367,7 +367,7 @@ where
367367 self . f . system ( self . x + self . h , & y_tmp, & mut k[ 3 ] ) ;
368368 self . stats . num_eval += 1 ;
369369
370- // Stifness detection
370+ // Stiffness detection
371371 if self . stats . accepted_steps % self . n_stiff == 0 || iasti > 0 {
372372 let num = T :: from ( ( & k[ 3 ] - & k[ 2 ] ) . dot ( & ( & k[ 3 ] - & k[ 2 ] ) ) ) . unwrap ( ) ;
373373 let den = T :: from ( ( & k[ 4 ] - & y_next) . dot ( & ( & k[ 4 ] - & y_next) ) ) . unwrap ( ) ;
@@ -521,30 +521,42 @@ where
521521 if self . x_old . abs ( ) <= self . xd . abs ( ) && self . x . abs ( ) >= self . xd . abs ( ) {
522522 let theta = ( self . xd - self . x_old ) / self . h_old ;
523523 let theta1 = T :: one ( ) - theta;
524-
525- let y_out = & self . rcont [ 0 ]
526- + ( & self . rcont [ 1 ]
527- + ( & self . rcont [ 2 ]
528- + ( & self . rcont [ 3 ]
529- + ( & self . rcont [ 4 ]
530- + ( & self . rcont [ 5 ]
531- + ( & self . rcont [ 6 ] + & self . rcont [ 7 ] * theta)
532- * theta1)
533- * theta)
534- * theta1)
535- * theta)
536- * theta1)
537- * theta;
524+ let y_out = self . compute_y_out ( theta, theta1) ;
538525 self . results . push ( self . xd , y_out) ;
539526 self . xd += self . dx ;
540527 }
541528 }
529+
530+ // Ensure the last point is added if it's within floating point error of x_end.
531+ if ( self . xd - self . x_end ) . abs ( ) < T :: from ( 1e-9 ) . unwrap ( ) {
532+ let theta = ( self . x_end - self . x_old ) / self . h_old ;
533+ let theta1 = T :: one ( ) - theta;
534+ let y_out = self . compute_y_out ( theta, theta1) ;
535+ self . results . push ( self . x_end , y_out) ;
536+ self . xd += self . dx ;
537+ }
542538 }
543539 } else {
544- self . results . push ( self . x0 , y_next) ;
540+ self . results . push ( self . x0 , y_next. clone ( ) ) ;
545541 }
546542 }
547543
544+ /// Computes the value of y for given theta and theta1 values.
545+ fn compute_y_out ( & mut self , theta : T , theta1 : T ) -> MatrixSum < T , D , U1 , D , U1 > {
546+ & self . rcont [ 0 ]
547+ + ( & self . rcont [ 1 ]
548+ + ( & self . rcont [ 2 ]
549+ + ( & self . rcont [ 3 ]
550+ + ( & self . rcont [ 4 ]
551+ + ( & self . rcont [ 5 ]
552+ + ( & self . rcont [ 6 ] + & self . rcont [ 7 ] * theta) * theta1)
553+ * theta)
554+ * theta1)
555+ * theta)
556+ * theta1)
557+ * theta
558+ }
559+
548560 /// Getter for the independent variable's output.
549561 pub fn x_out ( & self ) -> & Vec < T > {
550562 self . results . get ( ) . 0
@@ -572,14 +584,6 @@ where
572584 }
573585}
574586
575- fn sign < T : FloatNumber > ( a : T , b : T ) -> T {
576- if b > T :: zero ( ) {
577- a. abs ( )
578- } else {
579- -a. abs ( )
580- }
581- }
582-
583587#[ cfg( test) ]
584588mod tests {
585589 use super :: * ;
0 commit comments