Skip to content

Commit da1176f

Browse files
author
Martin Benes
committed
fixed wow now maybe?
1 parent cb53c09 commit da1176f

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

src/wow.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@ use numpy::{ndarray::Array, ndarray::Array1, ndarray::Array2, ndarray::Array3, n
88

99

1010
// ---------- internal helper, NOT exposed to Python ----------
11-
fn daubechies8() -> (Array3<f32>, Vec<(Array1<f32>, Array1<f32>)>) {
12-
let hpdf: [f32; 16] = [
11+
fn daubechies8() -> (Array3<f64>, Vec<(Array1<f64>, Array1<f64>)>) {
12+
let hpdf: [f64; 16] = [
1313
-0.0544158422, 0.3128715909, -0.6756307363, 0.5853546837,
1414
0.0158291053, -0.2840155430, -0.0004724846, 0.1287474266,
1515
0.0173693010, -0.0440882539, -0.0139810279, 0.0087460940,
1616
0.0048703530, -0.0003917404, -0.0006754494, -0.0001174768
1717
];
1818

1919
// build lpdf
20-
let mut lpdf = [0f32; 16];
20+
let mut lpdf = [0f64; 16];
2121
for i in 0..16 {
22-
lpdf[i] = ((-1f32).powi(i as i32)) * hpdf[15 - i];
22+
lpdf[i] = ((-1f64).powi(i as i32)) * hpdf[15 - i];
2323
}
2424

2525
let h = Array::from_shape_vec((16, 1), hpdf.to_vec()).unwrap();
@@ -49,9 +49,9 @@ fn reflect_index(i: isize, n: isize) -> isize {
4949
}
5050

5151
/// Symmetric pad a 2D array
52-
fn pad_symmetric(input: &Array2<f32>, pad_v: usize, pad_h: usize) -> Array2<f32> {
52+
fn pad_symmetric(input: &Array2<f64>, pad_v: usize, pad_h: usize) -> Array2<f64> {
5353
let (h, w) = input.dim();
54-
let mut output = Array2::<f32>::zeros((h + 2*pad_v, w + 2*pad_h));
54+
let mut output = Array2::<f64>::zeros((h + 2*pad_v, w + 2*pad_h));
5555

5656
for i in 0..output.nrows() {
5757
for j in 0..output.ncols() {
@@ -64,18 +64,18 @@ fn pad_symmetric(input: &Array2<f32>, pad_v: usize, pad_h: usize) -> Array2<f32>
6464
}
6565

6666
/// 2D convolution with symmetric padding and mode='same'
67-
fn convolve2d(input: &Array2<f32>, kernel: &Array2<f32>) -> Array2<f32> {
67+
fn convolve2d(input: &Array2<f64>, kernel: &Array2<f64>) -> Array2<f64> {
6868
let (h, w) = input.dim();
6969
let (kh, kw) = kernel.dim();
7070
let pad_h = kh / 2;
7171
let pad_w = kw / 2;
7272
let pad = pad_h.max(pad_w);
7373
let input_pad = pad_symmetric(input, pad_h, pad_w);
7474

75-
let mut output = Array2::<f32>::zeros((h, w));
75+
let mut output = Array2::<f64>::zeros((h, w));
7676
for i in 0..h {
7777
for j in 0..w {
78-
let mut sum = 0.0f32;
78+
let mut sum = 0.0f64;
7979
for u in 0..kh {
8080
for v in 0..kw {
8181
let x = i + u;
@@ -92,13 +92,13 @@ fn convolve2d(input: &Array2<f32>, kernel: &Array2<f32>) -> Array2<f32> {
9292

9393

9494

95-
fn convolve1d_horizontal(input: &Array2<f32>, kernel: &[f32]) -> Array2<f32> {
95+
fn convolve1d_horizontal(input: &Array2<f64>, kernel: &[f64]) -> Array2<f64> {
9696
let (h, w) = input.dim();
9797
let k = kernel.len();
9898
let pad = k / 2;
9999
let input_pad = pad_symmetric(input, 0, pad);
100100

101-
let mut out = Array2::<f32>::zeros((h, w));
101+
let mut out = Array2::<f64>::zeros((h, w));
102102

103103
for i in 0..h {
104104
for j in 0..w {
@@ -113,7 +113,7 @@ fn convolve1d_horizontal(input: &Array2<f32>, kernel: &[f32]) -> Array2<f32> {
113113
out
114114
}
115115

116-
fn convolve1d_vertical(input: &Array2<f32>, kernel: &[f32]) -> Array2<f32> {
116+
fn convolve1d_vertical(input: &Array2<f64>, kernel: &[f64]) -> Array2<f64> {
117117
// transpose the input
118118
let input_t = input.t();
119119
let mut tmp = convolve1d_horizontal(&input_t.to_owned(), kernel);
@@ -135,10 +135,10 @@ fn convolve1d_vertical(input: &Array2<f32>, kernel: &[f32]) -> Array2<f32> {
135135
// #[pyo3(signature = (x0))]
136136
#[pyfunction]
137137
#[pyo3(signature = (x0, p = -1.0))]
138-
fn compute_cost<'py>(py: Python<'py>, x0: PyReadonlyArray2<'py, u8>, p: f32)
139-
-> PyResult<Py<PyArray2<f32>>> {
138+
fn compute_cost<'py>(py: Python<'py>, x0: PyReadonlyArray2<'py, u8>, p: f64)
139+
-> PyResult<Py<PyArray2<f64>>> {
140140

141-
let input = x0.as_array().mapv(|v| v as f32);
141+
let input = x0.as_array().mapv(|v| v as f64);
142142
let (h, w) = input.dim();
143143
let mut x0_pad = pad_symmetric(&input, 16 as usize, 16 as usize);
144144

@@ -191,14 +191,14 @@ fn compute_cost<'py>(py: Python<'py>, x0: PyReadonlyArray2<'py, u8>, p: f32)
191191
xi.push(x_crop);
192192
}
193193

194-
// convert xi Vec<Array2<f32>> into a single Array3<f32> of shape (3, h, w)
194+
// convert xi Vec<Array2<f64>> into a single Array3<f64> of shape (3, h, w)
195195
let xi_3d = Array3::from_shape_vec(
196196
(3, h, w),
197197
xi.into_iter().flat_map(|arr| arr.into_raw_vec()).collect()
198198
).unwrap();
199199

200200
// compute sum over channels of xi_i^p
201-
let rho = xi_3d.mapv(|v| v.max(f32::EPSILON)).mapv(|v| v.powf(p)).sum_axis(Axis(0)).mapv(|v| v.powf(-1.0f32 / p));
201+
let rho = xi_3d.mapv(|v| v.max(f64::EPSILON)).mapv(|v| v.powf(p)).sum_axis(Axis(0)).mapv(|v| v.powf(-1.0f64 / p));
202202
Ok(PyArray2::from_owned_array(py, rho).into())
203203
}
204204

test/test_wow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def test_simulate(self, f):
6767
x1_ref = np.array(Image.open(STEGO_DIR / f'{f}.png'))
6868
np.testing.assert_array_equal(x1, x1_ref)
6969

70-
7170
@parameterized.expand([[f] for f in defs.TEST_IMAGES])
7271
def test_rust(self, fname: str):
7372
self._logger.info(f'TestWOW.test_rust({fname=})')

0 commit comments

Comments
 (0)