44"""
55
66# Created by moritz wolter, 14.04.20
7- from typing import List , Optional , Sequence , Tuple , Union
7+ from typing import List , Optional , Sequence , Tuple , Union , cast
88
99import pywt
1010import torch
1818 _pad_symmetric ,
1919 _unfold_axes ,
2020)
21+ from .constants import BoundaryMode
2122
2223
2324def _create_tensor (
@@ -106,7 +107,7 @@ def _get_pad(data_len: int, filt_len: int) -> Tuple[int, int]:
106107 return padr , padl
107108
108109
109- def _translate_boundary_strings (pywt_mode : str ) -> str :
110+ def _translate_boundary_strings (pywt_mode : BoundaryMode ) -> str :
110111 """Translate pywt mode strings to PyTorch mode strings.
111112
112113 We support constant, zero, reflect, and periodic.
@@ -118,24 +119,25 @@ def _translate_boundary_strings(pywt_mode: str) -> str:
118119
119120 """
120121 if pywt_mode == "constant" :
121- pt_mode = "replicate"
122+ return "replicate"
122123 elif pywt_mode == "zero" :
123- pt_mode = "constant"
124+ return "constant"
124125 elif pywt_mode == "reflect" :
125- pt_mode = pywt_mode
126+ return pywt_mode
126127 elif pywt_mode == "periodic" :
127- pt_mode = "circular"
128+ return "circular"
128129 elif pywt_mode == "symmetric" :
129130 # pytorch does not support symmetric mode,
130131 # we have our own implementation.
131- pt_mode = pywt_mode
132- else :
133- raise ValueError ("Padding mode not supported." )
134- return pt_mode
132+ return pywt_mode
133+ raise ValueError (f"Padding mode not supported: { pywt_mode } " )
135134
136135
137136def _fwt_pad (
138- data : torch .Tensor , wavelet : Union [Wavelet , str ], mode : str = "reflect"
137+ data : torch .Tensor ,
138+ wavelet : Union [Wavelet , str ],
139+ * ,
140+ mode : Optional [BoundaryMode ] = None ,
139141) -> torch .Tensor :
140142 """Pad the input signal to make the fwt matrix work.
141143
@@ -145,29 +147,26 @@ def _fwt_pad(
145147 data (torch.Tensor): Input data ``[batch_size, 1, time]``
146148 wavelet (Wavelet or str): A pywt wavelet compatible object or
147149 the name of a pywt wavelet.
148- mode (str): The desired way to pad. The following methods are supported::
149-
150- "reflect", "zero", "constant", "periodic", "symmetric".
151-
152- Refection padding mirrors samples along the border.
153- Zero padding pads zeros.
154- Constant padding replicates border values.
155- Periodic padding cyclically repeats samples.
156- This function defaults to reflect.
150+ mode :
151+ The desired padding mode for extending the signal along the edges.
152+ Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
157153
158154 Returns:
159155 torch.Tensor: A PyTorch tensor with the padded input data
160156
161157 """
162158 wavelet = _as_wavelet (wavelet )
159+
163160 # convert pywt to pytorch convention.
164- mode = _translate_boundary_strings (mode )
161+ if mode is None :
162+ mode = cast (BoundaryMode , "reflect" )
163+ pytorch_mode = _translate_boundary_strings (mode )
165164
166165 padr , padl = _get_pad (data .shape [- 1 ], _get_len (wavelet ))
167- if mode == "symmetric" :
166+ if pytorch_mode == "symmetric" :
168167 data_pad = _pad_symmetric (data , [(padl , padr )])
169168 else :
170- data_pad = torch .nn .functional .pad (data , [padl , padr ], mode = mode )
169+ data_pad = torch .nn .functional .pad (data , [padl , padr ], mode = pytorch_mode )
171170 return data_pad
172171
173172
@@ -263,7 +262,8 @@ def _preprocess_result_list_rec1d(
263262def wavedec (
264263 data : torch .Tensor ,
265264 wavelet : Union [Wavelet , str ],
266- mode : str = "reflect" ,
265+ * ,
266+ mode : BoundaryMode = "reflect" ,
267267 level : Optional [int ] = None ,
268268 axis : int = - 1 ,
269269) -> List [torch .Tensor ]:
@@ -276,18 +276,9 @@ def wavedec(
276276 the name of a pywt wavelet.
277277 Please consider the output from ``pywt.wavelist(kind='discrete')``
278278 for possible choices.
279- mode (str): The desired padding mode. Padding extends the signal along
280- the edges. Supported methods are::
281-
282- "reflect", "zero", "constant", "periodic", "symmetric".
283-
284- Defaults to "reflect".
285-
286- Symmetric padding mirrors samples along the border.
287- Refection padding reflects samples along the border.
288- Zero padding pads zeros.
289- Constant padding replicates border values.
290- Periodic padding cyclically repeats samples.
279+ mode :
280+ The desired padding mode for extending the signal along the edges.
281+ Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
291282 level (int): The scale level to be computed.
292283 Defaults to None.
293284 axis (int): Compute the transform over this axis instead of the
0 commit comments