Skip to content

Commit 17f9fb8

Browse files
committed
Split definition of tridiagonal matrix
1 parent ad19250 commit 17f9fb8

File tree

2 files changed

+110
-103
lines changed

2 files changed

+110
-103
lines changed

lax/src/tridiagonal/matrix.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use crate::layout::*;
2+
use cauchy::*;
3+
use num_traits::Zero;
4+
use std::ops::{Index, IndexMut};
5+
6+
/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
7+
///
8+
/// ```text
9+
/// [d0, u1, 0, ..., 0,
10+
/// l1, d1, u2, ...,
11+
/// 0, l2, d2,
12+
/// ... ..., u{n-1},
13+
/// 0, ..., l{n-1}, d{n-1},]
14+
/// ```
15+
#[derive(Clone, PartialEq, Eq)]
16+
pub struct Tridiagonal<A: Scalar> {
17+
/// layout of raw matrix
18+
pub l: MatrixLayout,
19+
/// (n-1) sub-diagonal elements of matrix.
20+
pub dl: Vec<A>,
21+
/// (n) diagonal elements of matrix.
22+
pub d: Vec<A>,
23+
/// (n-1) super-diagonal elements of matrix.
24+
pub du: Vec<A>,
25+
}
26+
27+
impl<A: Scalar> Tridiagonal<A> {
28+
pub fn opnorm_one(&self) -> A::Real {
29+
let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
30+
for i in 0..col_sum.len() {
31+
if i < self.dl.len() {
32+
col_sum[i] += self.dl[i].abs();
33+
}
34+
if i > 0 {
35+
col_sum[i] += self.du[i - 1].abs();
36+
}
37+
}
38+
let mut max = A::Real::zero();
39+
for &val in &col_sum {
40+
if max < val {
41+
max = val;
42+
}
43+
}
44+
max
45+
}
46+
}
47+
48+
impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
49+
type Output = A;
50+
#[inline]
51+
fn index(&self, (row, col): (i32, i32)) -> &A {
52+
let (n, _) = self.l.size();
53+
assert!(
54+
std::cmp::max(row, col) < n,
55+
"ndarray: index {:?} is out of bounds for array of shape {}",
56+
[row, col],
57+
n
58+
);
59+
match row - col {
60+
0 => &self.d[row as usize],
61+
1 => &self.dl[col as usize],
62+
-1 => &self.du[row as usize],
63+
_ => panic!(
64+
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
65+
[row, col]
66+
),
67+
}
68+
}
69+
}
70+
71+
impl<A: Scalar> Index<[i32; 2]> for Tridiagonal<A> {
72+
type Output = A;
73+
#[inline]
74+
fn index(&self, [row, col]: [i32; 2]) -> &A {
75+
&self[(row, col)]
76+
}
77+
}
78+
79+
impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
80+
#[inline]
81+
fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A {
82+
let (n, _) = self.l.size();
83+
assert!(
84+
std::cmp::max(row, col) < n,
85+
"ndarray: index {:?} is out of bounds for array of shape {}",
86+
[row, col],
87+
n
88+
);
89+
match row - col {
90+
0 => &mut self.d[row as usize],
91+
1 => &mut self.dl[col as usize],
92+
-1 => &mut self.du[row as usize],
93+
_ => panic!(
94+
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
95+
[row, col]
96+
),
97+
}
98+
}
99+
}
100+
101+
impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
102+
#[inline]
103+
fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A {
104+
&mut self[(row, col)]
105+
}
106+
}

lax/src/tridiagonal.rs renamed to lax/src/tridiagonal/mod.rs

Lines changed: 4 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,13 @@
11
//! Implement linear solver using LU decomposition
22
//! for tridiagonal matrix
33
4+
mod matrix;
5+
6+
pub use matrix::*;
7+
48
use crate::{error::*, layout::*, *};
59
use cauchy::*;
610
use num_traits::Zero;
7-
use std::ops::{Index, IndexMut};
8-
9-
/// Represents a tridiagonal matrix as 3 one-dimensional vectors.
10-
///
11-
/// ```text
12-
/// [d0, u1, 0, ..., 0,
13-
/// l1, d1, u2, ...,
14-
/// 0, l2, d2,
15-
/// ... ..., u{n-1},
16-
/// 0, ..., l{n-1}, d{n-1},]
17-
/// ```
18-
#[derive(Clone, PartialEq, Eq)]
19-
pub struct Tridiagonal<A: Scalar> {
20-
/// layout of raw matrix
21-
pub l: MatrixLayout,
22-
/// (n-1) sub-diagonal elements of matrix.
23-
pub dl: Vec<A>,
24-
/// (n) diagonal elements of matrix.
25-
pub d: Vec<A>,
26-
/// (n-1) super-diagonal elements of matrix.
27-
pub du: Vec<A>,
28-
}
29-
30-
impl<A: Scalar> Tridiagonal<A> {
31-
fn opnorm_one(&self) -> A::Real {
32-
let mut col_sum: Vec<A::Real> = self.d.iter().map(|val| val.abs()).collect();
33-
for i in 0..col_sum.len() {
34-
if i < self.dl.len() {
35-
col_sum[i] += self.dl[i].abs();
36-
}
37-
if i > 0 {
38-
col_sum[i] += self.du[i - 1].abs();
39-
}
40-
}
41-
let mut max = A::Real::zero();
42-
for &val in &col_sum {
43-
if max < val {
44-
max = val;
45-
}
46-
}
47-
max
48-
}
49-
}
5011

5112
/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
5213
#[derive(Clone, PartialEq)]
@@ -65,66 +26,6 @@ pub struct LUFactorizedTridiagonal<A: Scalar> {
6526
a_opnorm_one: A::Real,
6627
}
6728

68-
impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
69-
type Output = A;
70-
#[inline]
71-
fn index(&self, (row, col): (i32, i32)) -> &A {
72-
let (n, _) = self.l.size();
73-
assert!(
74-
std::cmp::max(row, col) < n,
75-
"ndarray: index {:?} is out of bounds for array of shape {}",
76-
[row, col],
77-
n
78-
);
79-
match row - col {
80-
0 => &self.d[row as usize],
81-
1 => &self.dl[col as usize],
82-
-1 => &self.du[row as usize],
83-
_ => panic!(
84-
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
85-
[row, col]
86-
),
87-
}
88-
}
89-
}
90-
91-
impl<A: Scalar> Index<[i32; 2]> for Tridiagonal<A> {
92-
type Output = A;
93-
#[inline]
94-
fn index(&self, [row, col]: [i32; 2]) -> &A {
95-
&self[(row, col)]
96-
}
97-
}
98-
99-
impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
100-
#[inline]
101-
fn index_mut(&mut self, (row, col): (i32, i32)) -> &mut A {
102-
let (n, _) = self.l.size();
103-
assert!(
104-
std::cmp::max(row, col) < n,
105-
"ndarray: index {:?} is out of bounds for array of shape {}",
106-
[row, col],
107-
n
108-
);
109-
match row - col {
110-
0 => &mut self.d[row as usize],
111-
1 => &mut self.dl[col as usize],
112-
-1 => &mut self.du[row as usize],
113-
_ => panic!(
114-
"ndarray-linalg::tridiagonal: index {:?} is not tridiagonal element",
115-
[row, col]
116-
),
117-
}
118-
}
119-
}
120-
121-
impl<A: Scalar> IndexMut<[i32; 2]> for Tridiagonal<A> {
122-
#[inline]
123-
fn index_mut(&mut self, [row, col]: [i32; 2]) -> &mut A {
124-
&mut self[(row, col)]
125-
}
126-
}
127-
12829
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
12930
pub trait Tridiagonal_: Scalar + Sized {
13031
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using

0 commit comments

Comments
 (0)