Skip to content

Commit 00b044b

Browse files
committed
Fix error handling
1 parent 5ad4343 commit 00b044b

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

ndarray-linalg/src/least_squares.rs

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,9 @@ where
267267
&mut self,
268268
rhs: &mut ArrayBase<D, Ix1>,
269269
) -> Result<LeastSquaresResult<E, Ix1>> {
270+
if self.shape()[0] != rhs.shape()[0] {
271+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
272+
}
270273
let (m, n) = (self.shape()[0], self.shape()[1]);
271274
if n > m {
272275
// we need a new rhs b/c it will be overwritten with the solution
@@ -285,15 +288,15 @@ fn compute_least_squares_srhs<E, D1, D2>(
285288
rhs: &mut ArrayBase<D2, Ix1>,
286289
) -> Result<LeastSquaresResult<E, Ix1>>
287290
where
288-
E: Scalar + Lapack + LeastSquaresSvdDivideConquer_,
291+
E: Scalar + Lapack,
289292
D1: DataMut<Elem = E>,
290293
D2: DataMut<Elem = E>,
291294
{
292295
let LeastSquaresOutput::<E> {
293296
singular_values,
294297
rank,
295298
} = unsafe {
296-
<E as LeastSquaresSvdDivideConquer_>::least_squares(
299+
E::least_squares(
297300
a.layout()?,
298301
a.as_allocated_mut()?,
299302
rhs.as_slice_memory_order_mut()
@@ -348,6 +351,9 @@ where
348351
&mut self,
349352
rhs: &mut ArrayBase<D, Ix2>,
350353
) -> Result<LeastSquaresResult<E, Ix2>> {
354+
if self.shape()[0] != rhs.shape()[0] {
355+
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
356+
}
351357
let (m, n) = (self.shape()[0], self.shape()[1]);
352358
if n > m {
353359
// we need a new rhs b/c it will be overwritten with the solution
@@ -550,28 +556,13 @@ mod tests {
550556
//
551557
// Testing error cases
552558
//
553-
554559
#[test]
555560
fn incompatible_shape_error_on_mismatching_num_rows() {
556561
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
557562
let b: Array1<f64> = array![1., 2.];
558-
let res = a.least_squares(&b);
559-
match res {
560-
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
561-
_ => panic!("Expected Err()"),
562-
}
563-
}
564-
565-
#[test]
566-
fn incompatible_shape_error_on_mismatching_layout() {
567-
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
568-
let b = array![[1.], [2.]].t().to_owned();
569-
assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 });
570-
571-
let res = a.least_squares(&b);
572-
match res {
573-
Err(LinalgError::Lapack(err)) if matches!(err, lax::error::Error::InvalidShape) => {}
574-
_ => panic!("Expected Err()"),
563+
match a.least_squares(&b) {
564+
Err(LinalgError::Shape(e)) if e.kind() == ErrorKind::IncompatibleShape => {}
565+
_ => panic!("Should be raise IncompatibleShape"),
575566
}
576567
}
577568
}

0 commit comments

Comments
 (0)