@@ -267,6 +267,9 @@ where
267
267
& mut self ,
268
268
rhs : & mut ArrayBase < D , Ix1 > ,
269
269
) -> Result < LeastSquaresResult < E , Ix1 > > {
270
+ if self . shape ( ) [ 0 ] != rhs. shape ( ) [ 0 ] {
271
+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) . into ( ) ) ;
272
+ }
270
273
let ( m, n) = ( self . shape ( ) [ 0 ] , self . shape ( ) [ 1 ] ) ;
271
274
if n > m {
272
275
// we need a new rhs b/c it will be overwritten with the solution
@@ -285,15 +288,15 @@ fn compute_least_squares_srhs<E, D1, D2>(
285
288
rhs : & mut ArrayBase < D2 , Ix1 > ,
286
289
) -> Result < LeastSquaresResult < E , Ix1 > >
287
290
where
288
- E : Scalar + Lapack + LeastSquaresSvdDivideConquer_ ,
291
+ E : Scalar + Lapack ,
289
292
D1 : DataMut < Elem = E > ,
290
293
D2 : DataMut < Elem = E > ,
291
294
{
292
295
let LeastSquaresOutput :: < E > {
293
296
singular_values,
294
297
rank,
295
298
} = unsafe {
296
- < E as LeastSquaresSvdDivideConquer_ > :: least_squares (
299
+ E :: least_squares (
297
300
a. layout ( ) ?,
298
301
a. as_allocated_mut ( ) ?,
299
302
rhs. as_slice_memory_order_mut ( )
@@ -348,6 +351,9 @@ where
348
351
& mut self ,
349
352
rhs : & mut ArrayBase < D , Ix2 > ,
350
353
) -> Result < LeastSquaresResult < E , Ix2 > > {
354
+ if self . shape ( ) [ 0 ] != rhs. shape ( ) [ 0 ] {
355
+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) . into ( ) ) ;
356
+ }
351
357
let ( m, n) = ( self . shape ( ) [ 0 ] , self . shape ( ) [ 1 ] ) ;
352
358
if n > m {
353
359
// we need a new rhs b/c it will be overwritten with the solution
@@ -550,28 +556,13 @@ mod tests {
550
556
//
551
557
// Testing error cases
552
558
//
553
-
554
559
#[ test]
555
560
fn incompatible_shape_error_on_mismatching_num_rows ( ) {
556
561
let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
557
562
let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
558
- let res = a. least_squares ( & b) ;
559
- match res {
560
- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax:: error:: Error :: InvalidShape ) => { }
561
- _ => panic ! ( "Expected Err()" ) ,
562
- }
563
- }
564
-
565
- #[ test]
566
- fn incompatible_shape_error_on_mismatching_layout ( ) {
567
- let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
568
- let b = array ! [ [ 1. ] , [ 2. ] ] . t ( ) . to_owned ( ) ;
569
- assert_eq ! ( b. layout( ) . unwrap( ) , MatrixLayout :: F { col: 2 , lda: 1 } ) ;
570
-
571
- let res = a. least_squares ( & b) ;
572
- match res {
573
- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax:: error:: Error :: InvalidShape ) => { }
574
- _ => panic ! ( "Expected Err()" ) ,
563
+ match a. least_squares ( & b) {
564
+ Err ( LinalgError :: Shape ( e) ) if e. kind ( ) == ErrorKind :: IncompatibleShape => { }
565
+ _ => panic ! ( "Should be raise IncompatibleShape" ) ,
575
566
}
576
567
}
577
568
}
0 commit comments