1111from ._util import (
1212 Wavelet ,
1313 _as_wavelet ,
14- _fold_channels ,
14+ _fold_axes ,
1515 _get_len ,
1616 _is_dtype_supported ,
1717 _pad_symmetric ,
18- _unfold_channels ,
18+ _unfold_axes ,
1919)
2020
2121
@@ -217,47 +217,66 @@ def _adjust_padding_at_reconstruction(
217217 return pad_end , pad_start
218218
219219
220- def _wavedec_fold_channels_1d (data : torch .Tensor ) -> Tuple [torch .Tensor , List [int ]]:
221- data = data .unsqueeze (- 2 )
222- ds = data .shape
223- data = _fold_channels (data )
224- return data , list (ds )
220+ def _preprocess_tensor_dec1d (
221+ data : torch .Tensor ,
222+ ) -> Tuple [torch .Tensor , Union [List [int ], None ]]:
223+ """Preprocess input tensor dimensions.
225224
225+ Args:
226+ data (torch.Tensor): An input tensor of any shape.
226227
227- def _wavedec_unfold_channels_1d_list (
228- result_list : List [torch .Tensor ], ds : List [int ]
228+ Returns:
229+ Tuple[torch.Tensor, Union[List[int], None]]:
230+ A data tensor of shape [new_batch, 1, to_process]
231+ and the original shape, if the shape has changed.
232+ """
233+ ds = None
234+ if len (data .shape ) == 1 :
235+ # assume time series
236+ data = data .unsqueeze (0 ).unsqueeze (0 )
237+ elif len (data .shape ) == 2 :
238+ # assume batched time series
239+ data = data .unsqueeze (1 )
240+ else :
241+ data , ds = _fold_axes (data , 1 )
242+ data = data .unsqueeze (1 )
243+ return data , ds
244+
245+
246+ def _postprocess_result_list_dec1d (
247+ result_lst : List [torch .Tensor ], ds : List [int ]
229248) -> List [torch .Tensor ]:
230- unfold_res = []
231- for res_coeff in result_list :
232- unfold_res .append (
233- _unfold_channels (res_coeff .unsqueeze (1 ), list (ds )).squeeze (- 2 )
234- )
235- return unfold_res
249+ # Unfold axes for the wavelets
250+ unfold_list = []
251+ for fres in result_lst :
252+ unfold_list .append (_unfold_axes (fres , ds , 1 ))
253+ return unfold_list
236254
237255
238- def _waverec_fold_channels_1d_list (
239- coeff_list : List [torch .Tensor ],
256+ def _preprocess_result_list_rec1d (
257+ result_lst : List [torch .Tensor ],
240258) -> Tuple [List [torch .Tensor ], List [int ]]:
241- folded = []
242- ds = coeff_list [0 ].unsqueeze (- 2 ).shape
243- for to_fold_coeff in coeff_list :
244- folded .append (_fold_channels (to_fold_coeff .unsqueeze (- 2 )).squeeze (- 2 ))
245- return folded , list (ds )
259+ # Fold axes for the wavelets
260+ fold_coeffs = []
261+ ds = list (result_lst [0 ].shape )
262+ for uf_coeff in result_lst :
263+ f_coeff , _ = _fold_axes (uf_coeff , 1 )
264+ fold_coeffs .append (f_coeff )
265+ return fold_coeffs , ds
246266
247267
248268def wavedec (
249269 data : torch .Tensor ,
250270 wavelet : Union [Wavelet , str ],
251271 mode : str = "reflect" ,
252272 level : Optional [int ] = None ,
273+ axis : int = - 1 ,
253274) -> List [torch .Tensor ]:
254275 """Compute the analysis (forward) 1d fast wavelet transform.
255276
256277 Args:
257278 data (torch.Tensor): The input time series,
258- 1d inputs are interpreted as ``[time]``,
259- 2d inputs as ``[batch_size, time]``,
260- and 3d inputs as ``[batch_size, channels, time]``.
279+ By default the last axis is transformed.
261280 wavelet (Wavelet or str): A pywt wavelet compatible object or
262281 the name of a pywt wavelet.
263282 Please consider the output from ``pywt.wavelist(kind='discrete')``
@@ -274,9 +293,11 @@ def wavedec(
274293 Zero padding pads zeros.
275294 Constant padding replicates border values.
276295 Periodic padding cyclically repeats samples.
277-
278296 level (int): The scale level to be computed.
279297 Defaults to None.
298+ axis (int): Compute the transform over this axis instead of the
299+ last one. Defaults to -1.
300+
280301
281302 Returns:
282303 list: A list::
@@ -287,7 +308,8 @@ def wavedec(
287308 approximation and D detail coefficients.
288309
289310 Raises:
290- ValueError: If the dtype of the input data tensor is unsupported.
311+ ValueError: If the dtype of the input data tensor is unsupported or
312+ if more than one axis is provided.
291313
292314 Example:
293315 >>> import torch
@@ -300,21 +322,17 @@ def wavedec(
300322 >>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
301323 >>> mode='zero', level=2)
302324 """
303- fold = False
304- if data .dim () == 1 :
305- # assume time series
306- data = data .unsqueeze (0 ).unsqueeze (0 )
307- elif data .dim () == 2 :
308- # assume batched time series
309- data = data .unsqueeze (1 )
310- elif data .dim () == 3 :
311- # assume batch, channels, time -> fold channels
312- fold = True
313- data , ds = _wavedec_fold_channels_1d (data )
325+ if axis != - 1 :
326+ if isinstance (axis , int ):
327+ data = data .swapaxes (axis , - 1 )
328+ else :
329+ raise ValueError ("wavedec transforms a single axis only." )
314330
315331 if not _is_dtype_supported (data .dtype ):
316332 raise ValueError (f"Input dtype { data .dtype } not supported" )
317333
334+ data , ds = _preprocess_tensor_dec1d (data )
335+
318336 dec_lo , dec_hi , _ , _ = _get_filter_tensors (
319337 wavelet , flip = True , device = data .device , dtype = data .dtype
320338 )
@@ -332,28 +350,38 @@ def wavedec(
332350 res_lo , res_hi = torch .split (res , 1 , 1 )
333351 result_list .append (res_hi .squeeze (1 ))
334352 result_list .append (res_lo .squeeze (1 ))
353+ result_list .reverse ()
354+
355+ if ds :
356+ result_list = _postprocess_result_list_dec1d (result_list , ds )
335357
336- # unfold if necessary
337- if fold :
338- result_list = _wavedec_unfold_channels_1d_list (result_list , ds )
358+ if axis != - 1 :
359+ swap = []
360+ for coeff in result_list :
361+ swap .append (coeff .swapaxes (axis , - 1 ))
362+ result_list = swap
339363
340- return result_list [:: - 1 ]
364+ return result_list
341365
342366
343- def waverec (coeffs : List [torch .Tensor ], wavelet : Union [Wavelet , str ]) -> torch .Tensor :
367+ def waverec (
368+ coeffs : List [torch .Tensor ], wavelet : Union [Wavelet , str ], axis : int = - 1
369+ ) -> torch .Tensor :
344370 """Reconstruct a signal from wavelet coefficients.
345371
346372 Args:
347373 coeffs (list): The wavelet coefficient list produced by wavedec.
348374 wavelet (Wavelet or str): A pywt wavelet compatible object or
349375 the name of a pywt wavelet.
376+ axis (int): Transform this axis instead of the last one. Defaults to -1.
350377
351378 Returns:
352379 torch.Tensor: The reconstructed signal.
353380
354381 Raises:
355382 ValueError: If the dtype of the coeffs tensor is unsupported or if the
356- coefficients have incompatible shapes, dtypes or devices.
383+ coefficients have incompatible shapes, dtypes or devices or if
384+ more than one axis is provided.
357385
358386 Example:
359387 >>> import torch
@@ -379,11 +407,19 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
379407 elif torch_dtype != coeff .dtype :
380408 raise ValueError ("coefficients must have the same dtype" )
381409
410+ if axis != - 1 :
411+ swap = []
412+ if isinstance (axis , int ):
413+ for coeff in coeffs :
414+ swap .append (coeff .swapaxes (axis , - 1 ))
415+ coeffs = swap
416+ else :
417+ raise ValueError ("waverec transforms a single axis only." )
418+
382419 # fold channels, if necessary.
383- fold = False
384- if coeffs [0 ].dim () == 3 :
385- fold = True
386- coeffs , ds = _waverec_fold_channels_1d_list (coeffs )
420+ ds = None
421+ if coeffs [0 ].dim () >= 3 :
422+ coeffs , ds = _preprocess_result_list_rec1d (coeffs )
387423
388424 _ , _ , rec_lo , rec_hi = _get_filter_tensors (
389425 wavelet , flip = False , device = torch_device , dtype = torch_dtype
@@ -408,7 +444,10 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
408444 if padr > 0 :
409445 res_lo = res_lo [..., :- padr ]
410446
411- if fold :
412- res_lo = _unfold_channels (res_lo .unsqueeze (- 2 ), list (ds )).squeeze (- 2 )
447+ if ds :
448+ res_lo = _unfold_axes (res_lo , ds , 1 )
449+
450+ if axis != - 1 :
451+ res_lo = res_lo .swapaxes (axis , - 1 )
413452
414453 return res_lo
0 commit comments