Skip to content

Commit 2ce0be8

Browse files
committed
Move check_square() to AllocatedArray::ensure_square()
1 parent 3464fe0 commit 2ce0be8

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

src/layout.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ pub trait AllocatedArray {
7575
type Elem;
7676
fn layout(&self) -> Result<MatrixLayout>;
7777
fn square_layout(&self) -> Result<MatrixLayout>;
78+
/// Returns Ok iff the matrix is square (without computing the layout).
79+
fn ensure_square(&self) -> Result<()>;
7880
fn as_allocated(&self) -> Result<&[Self::Elem]>;
7981
}
8082

@@ -110,6 +112,14 @@ where
110112
}
111113
}
112114

115+
fn ensure_square(&self) -> Result<()> {
116+
if self.is_square() {
117+
Ok(())
118+
} else {
119+
Err(NotSquareError::new(self.rows() as i32, self.cols() as i32).into())
120+
}
121+
}
122+
113123
fn as_allocated(&self) -> Result<&[A]> {
114124
Ok(self.as_slice_memory_order().ok_or(MemoryContError::new())?)
115125
}

src/solve.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -366,21 +366,13 @@ where
366366
pivot_sign * upper_sign * ln_det.exp()
367367
}
368368

369-
fn check_square<S: Data>(a: &ArrayBase<S, Ix2>) -> Result<()> {
370-
if a.is_square() {
371-
Ok(())
372-
} else {
373-
Err(NotSquareError::new(a.rows() as i32, a.cols() as i32).into())
374-
}
375-
}
376-
377369
impl<A, S> Determinant<A> for LUFactorized<S>
378370
where
379371
A: Scalar,
380372
S: Data<Elem = A>,
381373
{
382374
fn det(&self) -> Result<A> {
383-
check_square(&self.a)?;
375+
self.a.ensure_square()?;
384376
Ok(lu_det(self.ipiv.iter().cloned(), self.a.diag().iter()))
385377
}
386378
}
@@ -391,7 +383,7 @@ where
391383
S: Data<Elem = A>,
392384
{
393385
fn det_into(self) -> Result<A> {
394-
check_square(&self.a)?;
386+
self.a.ensure_square()?;
395387
Ok(lu_det(self.ipiv.into_iter(), self.a.into_diag().iter()))
396388
}
397389
}
@@ -402,7 +394,7 @@ where
402394
S: Data<Elem = A>,
403395
{
404396
fn det(&self) -> Result<A> {
405-
check_square(&self)?;
397+
self.ensure_square()?;
406398
match self.factorize() {
407399
Ok(fac) => fac.det(),
408400
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),
@@ -417,7 +409,7 @@ where
417409
S: DataMut<Elem = A>,
418410
{
419411
fn det_into(self) -> Result<A> {
420-
check_square(&self)?;
412+
self.ensure_square()?;
421413
match self.factorize_into() {
422414
Ok(fac) => fac.det_into(),
423415
Err(LinalgError::Lapack(LapackError { return_code })) if return_code > 0 => Ok(A::zero()),

0 commit comments

Comments
 (0)