Skip to content

Commit 62e40d1

Browse files
committed
Replace FlagSVD by UVTFlag
1 parent 4912a11 commit 62e40d1

File tree

2 files changed

+18
-31
lines changed

2 files changed

+18
-31
lines changed

lax/src/flags.rs

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -90,29 +90,6 @@ impl JobEv {
9090
}
9191
}
9292

93-
#[repr(u8)]
94-
#[derive(Debug, Copy, Clone)]
95-
pub enum FlagSVD {
96-
All = b'A',
97-
// OverWrite = b'O',
98-
// Separately = b'S',
99-
No = b'N',
100-
}
101-
102-
impl FlagSVD {
103-
pub fn from_bool(calc_uv: bool) -> Self {
104-
if calc_uv {
105-
FlagSVD::All
106-
} else {
107-
FlagSVD::No
108-
}
109-
}
110-
111-
pub fn as_ptr(&self) -> *const i8 {
112-
self as *const FlagSVD as *const i8
113-
}
114-
}
115-
11693
/// Specifies how many of the columns of *U* and rows of *V*ᵀ are computed and returned.
11794
///
11895
/// For an input array of shape *m*×*n*, the following are computed:
@@ -128,6 +105,14 @@ pub enum UVTFlag {
128105
}
129106

130107
impl UVTFlag {
108+
pub fn from_bool(calc_uv: bool) -> Self {
109+
if calc_uv {
110+
UVTFlag::Full
111+
} else {
112+
UVTFlag::None
113+
}
114+
}
115+
131116
pub fn as_ptr(&self) -> *const i8 {
132117
self as *const UVTFlag as *const i8
133118
}

lax/src/svd.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,26 @@ macro_rules! impl_svd {
3232
impl SVD_ for $scalar {
3333
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result<SVDOutput<Self>> {
3434
let ju = match l {
35-
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_u),
36-
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_vt),
35+
MatrixLayout::F { .. } => UVTFlag::from_bool(calc_u),
36+
MatrixLayout::C { .. } => UVTFlag::from_bool(calc_vt),
3737
};
3838
let jvt = match l {
39-
MatrixLayout::F { .. } => FlagSVD::from_bool(calc_vt),
40-
MatrixLayout::C { .. } => FlagSVD::from_bool(calc_u),
39+
MatrixLayout::F { .. } => UVTFlag::from_bool(calc_vt),
40+
MatrixLayout::C { .. } => UVTFlag::from_bool(calc_u),
4141
};
4242

4343
let m = l.lda();
4444
let mut u = match ju {
45-
FlagSVD::All => Some(unsafe { vec_uninit( (m * m) as usize) }),
46-
FlagSVD::No => None,
45+
UVTFlag::Full => Some(unsafe { vec_uninit( (m * m) as usize) }),
46+
UVTFlag::None => None,
47+
_ => unimplemented!("SVD with partial vector output is not supported yet")
4748
};
4849

4950
let n = l.len();
5051
let mut vt = match jvt {
51-
FlagSVD::All => Some(unsafe { vec_uninit( (n * n) as usize) }),
52-
FlagSVD::No => None,
52+
UVTFlag::Full => Some(unsafe { vec_uninit( (n * n) as usize) }),
53+
UVTFlag::None => None,
54+
_ => unimplemented!("SVD with partial vector output is not supported yet")
5355
};
5456

5557
let k = std::cmp::min(m, n);

0 commit comments

Comments
 (0)