@@ -8,18 +8,18 @@ use numpy::{ndarray::Array, ndarray::Array1, ndarray::Array2, ndarray::Array3, n
88
99
1010// ---------- internal helper, NOT exposed to Python ----------
11- fn daubechies8 ( ) -> ( Array3 < f32 > , Vec < ( Array1 < f32 > , Array1 < f32 > ) > ) {
12- let hpdf: [ f32 ; 16 ] = [
11+ fn daubechies8 ( ) -> ( Array3 < f64 > , Vec < ( Array1 < f64 > , Array1 < f64 > ) > ) {
12+ let hpdf: [ f64 ; 16 ] = [
1313 -0.0544158422 , 0.3128715909 , -0.6756307363 , 0.5853546837 ,
1414 0.0158291053 , -0.2840155430 , -0.0004724846 , 0.1287474266 ,
1515 0.0173693010 , -0.0440882539 , -0.0139810279 , 0.0087460940 ,
1616 0.0048703530 , -0.0003917404 , -0.0006754494 , -0.0001174768
1717 ] ;
1818
1919 // build lpdf
20- let mut lpdf = [ 0f32 ; 16 ] ;
20+ let mut lpdf = [ 0f64 ; 16 ] ;
2121 for i in 0 ..16 {
22- lpdf[ i] = ( ( -1f32 ) . powi ( i as i32 ) ) * hpdf[ 15 - i] ;
22+ lpdf[ i] = ( ( -1f64 ) . powi ( i as i32 ) ) * hpdf[ 15 - i] ;
2323 }
2424
2525 let h = Array :: from_shape_vec ( ( 16 , 1 ) , hpdf. to_vec ( ) ) . unwrap ( ) ;
@@ -49,9 +49,9 @@ fn reflect_index(i: isize, n: isize) -> isize {
4949}
5050
5151/// Symmetric pad a 2D array
52- fn pad_symmetric ( input : & Array2 < f32 > , pad_v : usize , pad_h : usize ) -> Array2 < f32 > {
52+ fn pad_symmetric ( input : & Array2 < f64 > , pad_v : usize , pad_h : usize ) -> Array2 < f64 > {
5353 let ( h, w) = input. dim ( ) ;
54- let mut output = Array2 :: < f32 > :: zeros ( ( h + 2 * pad_v, w + 2 * pad_h) ) ;
54+ let mut output = Array2 :: < f64 > :: zeros ( ( h + 2 * pad_v, w + 2 * pad_h) ) ;
5555
5656 for i in 0 ..output. nrows ( ) {
5757 for j in 0 ..output. ncols ( ) {
@@ -64,18 +64,18 @@ fn pad_symmetric(input: &Array2<f32>, pad_v: usize, pad_h: usize) -> Array2<f32>
6464}
6565
6666/// 2D convolution with symmetric padding and mode='same'
67- fn convolve2d ( input : & Array2 < f32 > , kernel : & Array2 < f32 > ) -> Array2 < f32 > {
67+ fn convolve2d ( input : & Array2 < f64 > , kernel : & Array2 < f64 > ) -> Array2 < f64 > {
6868 let ( h, w) = input. dim ( ) ;
6969 let ( kh, kw) = kernel. dim ( ) ;
7070 let pad_h = kh / 2 ;
7171 let pad_w = kw / 2 ;
7272 let pad = pad_h. max ( pad_w) ;
7373 let input_pad = pad_symmetric ( input, pad_h, pad_w) ;
7474
75- let mut output = Array2 :: < f32 > :: zeros ( ( h, w) ) ;
75+ let mut output = Array2 :: < f64 > :: zeros ( ( h, w) ) ;
7676 for i in 0 ..h {
7777 for j in 0 ..w {
78- let mut sum = 0.0f32 ;
78+ let mut sum = 0.0f64 ;
7979 for u in 0 ..kh {
8080 for v in 0 ..kw {
8181 let x = i + u;
@@ -92,13 +92,13 @@ fn convolve2d(input: &Array2<f32>, kernel: &Array2<f32>) -> Array2<f32> {
9292
9393
9494
95- fn convolve1d_horizontal ( input : & Array2 < f32 > , kernel : & [ f32 ] ) -> Array2 < f32 > {
95+ fn convolve1d_horizontal ( input : & Array2 < f64 > , kernel : & [ f64 ] ) -> Array2 < f64 > {
9696 let ( h, w) = input. dim ( ) ;
9797 let k = kernel. len ( ) ;
9898 let pad = k / 2 ;
9999 let input_pad = pad_symmetric ( input, 0 , pad) ;
100100
101- let mut out = Array2 :: < f32 > :: zeros ( ( h, w) ) ;
101+ let mut out = Array2 :: < f64 > :: zeros ( ( h, w) ) ;
102102
103103 for i in 0 ..h {
104104 for j in 0 ..w {
@@ -113,7 +113,7 @@ fn convolve1d_horizontal(input: &Array2<f32>, kernel: &[f32]) -> Array2<f32> {
113113 out
114114}
115115
116- fn convolve1d_vertical ( input : & Array2 < f32 > , kernel : & [ f32 ] ) -> Array2 < f32 > {
116+ fn convolve1d_vertical ( input : & Array2 < f64 > , kernel : & [ f64 ] ) -> Array2 < f64 > {
117117 // transpose the input
118118 let input_t = input. t ( ) ;
119119 let mut tmp = convolve1d_horizontal ( & input_t. to_owned ( ) , kernel) ;
@@ -135,10 +135,10 @@ fn convolve1d_vertical(input: &Array2<f32>, kernel: &[f32]) -> Array2<f32> {
135135// #[pyo3(signature = (x0))]
136136#[ pyfunction]
137137#[ pyo3( signature = ( x0, p = -1.0 ) ) ]
138- fn compute_cost < ' py > ( py : Python < ' py > , x0 : PyReadonlyArray2 < ' py , u8 > , p : f32 )
139- -> PyResult < Py < PyArray2 < f32 > > > {
138+ fn compute_cost < ' py > ( py : Python < ' py > , x0 : PyReadonlyArray2 < ' py , u8 > , p : f64 )
139+ -> PyResult < Py < PyArray2 < f64 > > > {
140140
141- let input = x0. as_array ( ) . mapv ( |v| v as f32 ) ;
141+ let input = x0. as_array ( ) . mapv ( |v| v as f64 ) ;
142142 let ( h, w) = input. dim ( ) ;
143143 let mut x0_pad = pad_symmetric ( & input, 16 as usize , 16 as usize ) ;
144144
@@ -191,14 +191,14 @@ fn compute_cost<'py>(py: Python<'py>, x0: PyReadonlyArray2<'py, u8>, p: f32)
191191 xi. push ( x_crop) ;
192192 }
193193
194- // convert xi Vec<Array2<f32 >> into a single Array3<f32 > of shape (3, h, w)
194+ // convert xi Vec<Array2<f64 >> into a single Array3<f64 > of shape (3, h, w)
195195 let xi_3d = Array3 :: from_shape_vec (
196196 ( 3 , h, w) ,
197197 xi. into_iter ( ) . flat_map ( |arr| arr. into_raw_vec ( ) ) . collect ( )
198198 ) . unwrap ( ) ;
199199
200200 // compute sum over channels of xi_i^p
201- let rho = xi_3d. mapv ( |v| v. max ( f32 :: EPSILON ) ) . mapv ( |v| v. powf ( p) ) . sum_axis ( Axis ( 0 ) ) . mapv ( |v| v. powf ( -1.0f32 / p) ) ;
201+ let rho = xi_3d. mapv ( |v| v. max ( f64 :: EPSILON ) ) . mapv ( |v| v. powf ( p) ) . sum_axis ( Axis ( 0 ) ) . mapv ( |v| v. powf ( -1.0f64 / p) ) ;
202202 Ok ( PyArray2 :: from_owned_array ( py, rho) . into ( ) )
203203}
204204
0 commit comments