11"""This module implements stationary wavelet transforms."""
22
3- from typing import List , Optional , Union
3+ from typing import List , Optional , Sequence , Union
44
55import pywt
66import torch
7+ import torch .nn .functional as F # noqa:N812
78
89from ._util import Wavelet , _as_wavelet , _unfold_axes
910from .conv_transform import (
1415)
1516
1617
-def _swt(
+def _circular_pad(x: torch.Tensor, padding_dimensions: Sequence[int]) -> torch.Tensor:
+    """Pad a tensor in circular mode, wrapping around more than once if needed."""
+    trailing_dimension = x.shape[-1]
+
+    # if every padding dimension is smaller than or equal to the trailing
+    # dimension, we do not need to wrap manually
+    if not any(
+        padding_dimension > trailing_dimension
+        for padding_dimension in padding_dimensions
+    ):
+        return F.pad(x, padding_dimensions, mode="circular")
+
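+    # torch.nn.functional.pad with mode="circular" raises an error whenever a
+    # requested pad width exceeds the size of the padded dimension, so larger
+    # pads are applied in several wrapping passes below.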
+    # repeatedly pad by at most the trailing dimension until every padding
+    # dimension is zero
+    while any(padding_dimension > 0 for padding_dimension in padding_dimensions):
+        # clamp every padding dimension to at most the trailing dimension width
+        reduced_padding_dimensions = [
+            min(trailing_dimension, padding_dimension)
+            for padding_dimension in padding_dimensions
+        ]
+        # pad using the reduced widths, which can never trigger the circular
+        # wrap error
+        x = F.pad(x, reduced_padding_dimensions, mode="circular")
+        # subtract the width that was just applied and repeat while any pad
+        # width is greater than zero
+        padding_dimensions = [
+            max(padding_dimension - trailing_dimension, 0)
+            for padding_dimension in padding_dimensions
+        ]
+
+    return x
+
+
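+# For intuition, a hypothetical check (not part of this module): padding a
+# length-3 signal by 5 on the left wraps it around almost twice.
+#
+#     x = torch.arange(3.0).reshape(1, 1, 3)   # tensor([0., 1., 2.])
+#     padded = _circular_pad(x, [5, 0])
+#     assert padded.shape[-1] == 8             # 5 wrapped samples + 3 original
+
+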
+def swt(
     data: torch.Tensor,
     wavelet: Union[Wavelet, str],
     level: Optional[int] = None,
     axis: Optional[int] = -1,
 ) -> List[torch.Tensor]:
     """Compute a multilevel 1d stationary wavelet transform.

+    This function is equivalent to pywt's swt with `trim_approx=True` and
+    `norm=False`.
+
     Args:
         data (torch.Tensor): The input data of shape [batch_size, time].
         wavelet (Union[Wavelet, str]): The wavelet to use.
@@ -56,57 +91,20 @@ def _swt(
     for current_level in range(level):
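         # the stationary transform skips downsampling; instead, the analysis
         # filters are dilated by 2**level at each scale ("algorithme à trous")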
         dilation = 2**current_level
         padl, padr = dilation * (filt_len // 2 - 1), dilation * (filt_len // 2)
-        res_lo = torch.nn.functional.pad(res_lo, [padl, padr], mode="circular")
+        res_lo = _circular_pad(res_lo, [padl, padr])
         res = torch.nn.functional.conv1d(res_lo, filt, stride=1, dilation=dilation)
         res_lo, res_hi = torch.split(res, 1, 1)
         # Trim_approx == False
         # result_list.append((res_lo.squeeze(1), res_hi.squeeze(1)))
         result_list.append(res_hi.squeeze(1))
     result_list.append(res_lo.squeeze(1))

-    if ds:
-        result_list = _postprocess_result_list_dec1d(result_list, ds)
-
-    if axis != -1:
-        result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]
+    result_list = _postprocess_result_list_dec1d(result_list, ds, axis)
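+    # _postprocess_result_list_dec1d now restores any folded dimensions and
+    # swaps the transformed axis back in a single call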
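     # reversed, the list matches pywt's coefficient order: [cAn, cDn, ..., cD1]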
     return result_list[::-1]


-def _conv_transpose_dedilate(
-    conv_res: torch.Tensor,
-    rec_filt: torch.Tensor,
-    dilation: int,
-    length: int,
-) -> torch.Tensor:
-    """Undo the forward dilated convolution from the analysis transform.
-
-    Args:
-        conv_res (torch.Tensor): The dilated coefficients
-            of shape [batch, 2, length].
-        rec_filt (torch.Tensor): The reconstruction filter pair
-            of shape [1, 2, filter_length].
-        dilation (int): The dilation factor.
-        length (int): The signal length.
-
-    Returns:
-        torch.Tensor: The deconvolution result.
-    """
-    to_conv_t_list = [
-        conv_res[..., fl : (fl + dilation * rec_filt.shape[-1]) : dilation]
-        for fl in range(length)
-    ]
-    to_conv_t = torch.cat(to_conv_t_list, 0)
-    padding = rec_filt.shape[-1] - 1
-    rec = torch.nn.functional.conv_transpose1d(
-        to_conv_t, rec_filt, stride=1, padding=padding, output_padding=0
-    )
-    rec = rec / 2.0
-    splits = torch.split(rec, rec.shape[0] // len(to_conv_t_list))
-    return torch.cat(splits, -1)
-
-
-def _iswt(
+def iswt(
     coeffs: List[torch.Tensor],
     wavelet: Union[pywt.Wavelet, str],
     axis: Optional[int] = -1,
@@ -134,10 +132,7 @@ def _iswt(
     else:
         raise ValueError("iswt transforms a single axis only.")

-    ds = None
-    length = coeffs[0].shape[-1]
-    if coeffs[0].ndim > 2:
-        coeffs, ds = _preprocess_result_list_rec1d(coeffs)
+    coeffs, ds = _preprocess_result_list_rec1d(coeffs)
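+    # preprocessing now runs unconditionally; ds records the original
+    # coefficient shape so the trailing dimensions can be restored afterwards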

     wavelet = _as_wavelet(wavelet)
     _, _, rec_lo, rec_hi = _get_filter_tensors(
@@ -151,13 +146,18 @@ def _iswt(
         dilation = 2 ** (len(coeffs[1:]) - c_pos - 1)
         res_lo = torch.stack([res_lo, res_hi], 1)
         padl, padr = dilation * (filt_len // 2), dilation * (filt_len // 2 - 1)
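         # note the mirrored pad widths relative to the forward transform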
-        res_lo = torch.nn.functional.pad(res_lo, (padl, padr), mode="circular")
-        res_lo = _conv_transpose_dedilate(
-            res_lo, rec_filt, dilation=dilation, length=length
+        # res_lo = torch.nn.functional.pad(res_lo, (padl, padr), mode="circular")
+        res_lo_pad = _circular_pad(res_lo, (padl, padr))
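+        # conv_transpose1d with groups=2 inverts the dilated analysis
+        # convolution for each filter; the channel mean then averages the
+        # two partial reconstructions, i.e. (lo + hi) / 2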
+        res_lo = torch.mean(
+            torch.nn.functional.conv_transpose1d(
+                res_lo_pad, rec_filt, dilation=dilation, groups=2, padding=(padl + padr)
+            ),
+            1,
         )
-        res_lo = res_lo.squeeze(1)

-    if ds:
+    if len(ds) == 1:
+        res_lo = res_lo.squeeze(0)
+    elif len(ds) > 2:
         res_lo = _unfold_axes(res_lo, ds, 1)

     if axis != -1:
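
A minimal roundtrip sketch of the renamed public functions (hypothetical
usage; the import path ptwt.stationary_transform is an assumption, adjust it
to wherever this module lives):

    import torch
    from ptwt.stationary_transform import swt, iswt

    signal = torch.randn(1, 32)            # [batch_size, time]
    coeffs = swt(signal, "haar", level=2)  # [cA2, cD2, cD1]
    rec = iswt(coeffs, "haar")
    # the circular-padded SWT is invertible, so this should hold up to float error
    assert torch.allclose(signal, rec, atol=1e-5)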