88import pywt
99import torch
1010
11- from ._util import Wavelet , _as_wavelet , _get_len , _is_dtype_supported
11+ from ._util import (
12+ Wavelet ,
13+ _as_wavelet ,
14+ _fold_channels ,
15+ _get_len ,
16+ _is_dtype_supported ,
17+ _pad_symmetric ,
18+ _unfold_channels ,
19+ )
1220
1321
1422def _create_tensor (
@@ -26,7 +34,7 @@ def _create_tensor(
2634 return torch .tensor (filter , device = device , dtype = dtype ).unsqueeze (0 )
2735
2836
29- def get_filter_tensors (
37+ def _get_filter_tensors (
3038 wavelet : Union [Wavelet , str ],
3139 flip : bool ,
3240 device : Union [torch .device , str ],
@@ -116,6 +124,10 @@ def _translate_boundary_strings(pywt_mode: str) -> str:
116124 pt_mode = pywt_mode
117125 elif pywt_mode == "periodic" :
118126 pt_mode = "circular"
127+ elif pywt_mode == "symmetric" :
128+ # pytorch does not support symmetric mode,
129+ # we have our own implementation.
130+ pt_mode = pywt_mode
119131 else :
120132 raise ValueError ("Padding mode not supported." )
121133 return pt_mode
@@ -129,12 +141,12 @@ def _fwt_pad(
129141 The padding assumes a future step will transform the last axis.
130142
131143 Args:
132- data (torch.Tensor): Input data [batch_size, 1, time]
144+ data (torch.Tensor): Input data `` [batch_size, 1, time]``
133145 wavelet (Wavelet or str): A pywt wavelet compatible object or
134146 the name of a pywt wavelet.
135147 mode (str): The desired way to pad. The following methods are supported::
136148
137- "reflect", "zero", "constant", "periodic".
149+ "reflect", "zero", "constant", "periodic", "symmetric" .
138150
139151 Refection padding mirrors samples along the border.
140152 Zero padding pads zeros.
@@ -151,7 +163,10 @@ def _fwt_pad(
151163 mode = _translate_boundary_strings (mode )
152164
153165 padr , padl = _get_pad (data .shape [- 1 ], _get_len (wavelet ))
154- data_pad = torch .nn .functional .pad (data , [padl , padr ], mode = mode )
166+ if mode == "symmetric" :
167+ data_pad = _pad_symmetric (data , [(padl , padr )])
168+ else :
169+ data_pad = torch .nn .functional .pad (data , [padl , padr ], mode = mode )
155170 return data_pad
156171
157172
@@ -196,10 +211,40 @@ def _adjust_padding_at_reconstruction(
196211 elif next_size == pred_size - 1 :
197212 pad_end += 1
198213 else :
199- raise AssertionError ("padding error, please open an issue on github" )
214+ raise AssertionError (
215+ "padding error, please check if dec and rec wavelets are identical."
216+ )
200217 return pad_end , pad_start
201218
202219
220+ def _wavedec_fold_channels_1d (data : torch .Tensor ) -> Tuple [torch .Tensor , List [int ]]:
221+ data = data .unsqueeze (- 2 )
222+ ds = data .shape
223+ data = _fold_channels (data )
224+ return data , list (ds )
225+
226+
227+ def _wavedec_unfold_channels_1d_list (
228+ result_list : List [torch .Tensor ], ds : List [int ]
229+ ) -> List [torch .Tensor ]:
230+ unfold_res = []
231+ for res_coeff in result_list :
232+ unfold_res .append (
233+ _unfold_channels (res_coeff .unsqueeze (1 ), list (ds )).squeeze (- 2 )
234+ )
235+ return unfold_res
236+
237+
238+ def _waverec_fold_channels_1d_list (
239+ coeff_list : List [torch .Tensor ],
240+ ) -> Tuple [List [torch .Tensor ], List [int ]]:
241+ folded = []
242+ ds = coeff_list [0 ].unsqueeze (- 2 ).shape
243+ for to_fold_coeff in coeff_list :
244+ folded .append (_fold_channels (to_fold_coeff .unsqueeze (- 2 )).squeeze (- 2 ))
245+ return folded , list (ds )
246+
247+
203248def wavedec (
204249 data : torch .Tensor ,
205250 wavelet : Union [Wavelet , str ],
@@ -209,19 +254,27 @@ def wavedec(
209254 """Compute the analysis (forward) 1d fast wavelet transform.
210255
211256 Args:
212- data (torch.Tensor): Input time series of shape [batch_size, 1, time]
213- 1d inputs are interpreted as [time],
214- 2d inputs are interpreted as [batch_size, time].
257+ data (torch.Tensor): The input time series,
258+ 1d inputs are interpreted as ``[time]``,
259+ 2d inputs as ``[batch_size, time]``,
260+ and 3d inputs as ``[batch_size, channels, time]``.
215261 wavelet (Wavelet or str): A pywt wavelet compatible object or
216262 the name of a pywt wavelet.
217263 Please consider the output from ``pywt.wavelist(kind='discrete')``
218264 for possible choices.
219265 mode (str): The desired padding mode. Padding extends the signal along
220266 the edges. Supported methods are::
221267
222- "reflect", "zero", "constant", "periodic".
268+ "reflect", "zero", "constant", "periodic", "symmetric" .
223269
224270 Defaults to "reflect".
271+
272+ Symmetric padding mirrors samples along the border.
273+ Refection padding reflects samples along the border.
274+ Zero padding pads zeros.
275+ Constant padding replicates border values.
276+ Periodic padding cyclically repeats samples.
277+
225278 level (int): The scale level to be computed.
226279 Defaults to None.
227280
@@ -246,19 +299,23 @@ def wavedec(
246299 >>> # compute the forward fwt coefficients
247300 >>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
248301 >>> mode='zero', level=2)
249-
250302 """
303+ fold = False
251304 if data .dim () == 1 :
252305 # assume time series
253306 data = data .unsqueeze (0 ).unsqueeze (0 )
254307 elif data .dim () == 2 :
255308 # assume batched time series
256309 data = data .unsqueeze (1 )
310+ elif data .dim () == 3 :
311+ # assume batch, channels, time -> fold channels
312+ fold = True
313+ data , ds = _wavedec_fold_channels_1d (data )
257314
258315 if not _is_dtype_supported (data .dtype ):
259316 raise ValueError (f"Input dtype { data .dtype } not supported" )
260317
261- dec_lo , dec_hi , _ , _ = get_filter_tensors (
318+ dec_lo , dec_hi , _ , _ = _get_filter_tensors (
262319 wavelet , flip = True , device = data .device , dtype = data .dtype
263320 )
264321 filt_len = dec_lo .shape [- 1 ]
@@ -267,15 +324,20 @@ def wavedec(
267324 if level is None :
268325 level = pywt .dwt_max_level (data .shape [- 1 ], filt_len )
269326
270- result_lst = []
327+ result_list = []
271328 res_lo = data
272329 for _ in range (level ):
273330 res_lo = _fwt_pad (res_lo , wavelet , mode = mode )
274331 res = torch .nn .functional .conv1d (res_lo , filt , stride = 2 )
275332 res_lo , res_hi = torch .split (res , 1 , 1 )
276- result_lst .append (res_hi .squeeze (1 ))
277- result_lst .append (res_lo .squeeze (1 ))
278- return result_lst [::- 1 ]
333+ result_list .append (res_hi .squeeze (1 ))
334+ result_list .append (res_lo .squeeze (1 ))
335+
336+ # unfold if necessary
337+ if fold :
338+ result_list = _wavedec_unfold_channels_1d_list (result_list , ds )
339+
340+ return result_list [::- 1 ]
279341
280342
281343def waverec (coeffs : List [torch .Tensor ], wavelet : Union [Wavelet , str ]) -> torch .Tensor :
@@ -317,7 +379,13 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
317379 elif torch_dtype != coeff .dtype :
318380 raise ValueError ("coefficients must have the same dtype" )
319381
320- _ , _ , rec_lo , rec_hi = get_filter_tensors (
382+ # fold channels, if necessary.
383+ fold = False
384+ if coeffs [0 ].dim () == 3 :
385+ fold = True
386+ coeffs , ds = _waverec_fold_channels_1d_list (coeffs )
387+
388+ _ , _ , rec_lo , rec_hi = _get_filter_tensors (
321389 wavelet , flip = False , device = torch_device , dtype = torch_dtype
322390 )
323391 filt_len = rec_lo .shape [- 1 ]
@@ -339,4 +407,8 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
339407 res_lo = res_lo [..., padl :]
340408 if padr > 0 :
341409 res_lo = res_lo [..., :- padr ]
410+
411+ if fold :
412+ res_lo = _unfold_channels (res_lo .unsqueeze (- 2 ), list (ds )).squeeze (- 2 )
413+
342414 return res_lo
0 commit comments