1111from pywt ._functions import scale2frequency
1212from torch .fft import fft , ifft
1313
14+ __all__ = ["cwt" ]
15+
1416
1517def _next_fast_len (n : int ) -> int :
1618 """Round up size to the nearest power of two.
@@ -24,10 +26,10 @@ def _next_fast_len(n: int) -> int:
2426
2527def cwt (
2628 data : torch .Tensor ,
27- scales : Union [np .ndarray , torch .Tensor ], # type: ignore
29+ scales : Union [np .ndarray , torch .Tensor ],
2830 wavelet : Union [ContinuousWavelet , str ],
2931 sampling_period : float = 1.0 ,
30- ) -> tuple [torch .Tensor , np .ndarray ]: # type: ignore
32+ ) -> tuple [torch .Tensor , np .ndarray ]:
3133 """Compute the single-dimensional continuous wavelet transform.
3234
3335 This function is a PyTorch port of pywt.cwt as found at:
@@ -185,11 +187,11 @@ def _integrate_wavelet(
185187 """
186188
187189 def _integrate (
188- arr : Union [np .ndarray , torch .Tensor ], # type: ignore
189- step : Union [np .ndarray , torch .Tensor ], # type: ignore
190- ) -> Union [np .ndarray , torch .Tensor ]: # type: ignore
190+ arr : Union [np .ndarray , torch .Tensor ],
191+ step : Union [np .ndarray , torch .Tensor ],
192+ ) -> Union [np .ndarray , torch .Tensor ]:
191193 if type (arr ) is np .ndarray :
192- integral = np .cumsum (arr )
194+ integral : Any = np .cumsum (arr )
193195 elif type (arr ) is torch .Tensor :
194196 integral = torch .cumsum (arr , - 1 )
195197 else :
@@ -212,12 +214,12 @@ def _integrate(
212214 return _integrate (psi , step ), x
213215
214216 elif len (functions_approximations ) == 3 : # orthogonal wavelet
215- _ , psi , x = functions_approximations
217+ _ , psi , x = functions_approximations # type: ignore
216218 step = x [1 ] - x [0 ]
217219 return _integrate (psi , step ), x
218220
219221 else : # biorthogonal wavelet
220- _ , psi_d , _ , psi_r , x = functions_approximations
222+ _ , psi_d , _ , psi_r , x = functions_approximations # type: ignore
221223 step = x [1 ] - x [0 ]
222224 return _integrate (psi_d , step ), _integrate (psi_r , step ), x
223225
@@ -248,7 +250,11 @@ def __init__(self, name: str):
248250 )
249251
250252 def __call__ (self , grid_values : torch .Tensor ) -> torch .Tensor :
251- """Return numerical values for the wavelet on a grid."""
253+ """Return numerical values for the wavelet on a grid.
254+
255+ Raises:
256+ NotImplementedError: If this call is not overwritten by a child.
257+ """
252258 raise NotImplementedError
253259
254260 @property
0 commit comments