Commit 78abc5f

Merge pull request #84 from NiclasPi/fix-padding
SWT: Make circular padding wrap more than once if needed
2 parents 638a3c0 + 3326e2b commit 78abc5f

File tree

6 files changed: +194 −101 lines
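
For context: PyTorch's `torch.nn.functional.pad` with `mode="circular"` rejects pad widths larger than the padded dimension, while the stationary wavelet transform's dilated filters need exactly such widths at deeper levels. A minimal sketch of the failure mode this merge works around (illustrative values, not taken from the PR):

```python
import torch
import torch.nn.functional as F

x = torch.arange(4.0).reshape(1, 1, 4)  # a length-4 signal

# Wrapping around at most once is accepted:
F.pad(x, [4, 4], mode="circular")

# Wrapping around more than once is rejected by PyTorch
# (roughly: "Padding value causes wrapping around more than once"):
# F.pad(x, [5, 5], mode="circular")  # raises a RuntimeError
```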

docs/ptwt.rst

Lines changed: 10 additions & 0 deletions

@@ -53,6 +53,16 @@ ptwt.separable\_conv\_transform module
    :undoc-members:
    :show-inheritance:
 
+
+ptwt.stationary\_transform module
+---------------------------------
+
+.. automodule:: ptwt.stationary_transform
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
 ptwt.matmul\_transform module
 -----------------------------
 
src/ptwt/conv_transform.py

Lines changed: 30 additions & 18 deletions
@@ -218,7 +218,7 @@ def _adjust_padding_at_reconstruction(
 
 def _preprocess_tensor_dec1d(
     data: torch.Tensor,
-) -> Tuple[torch.Tensor, Union[List[int], None]]:
+) -> Tuple[torch.Tensor, List[int]]:
     """Preprocess input tensor dimensions.
 
     Args:
@@ -227,13 +227,13 @@ def _preprocess_tensor_dec1d(
     Returns:
         Tuple[torch.Tensor, Union[List[int], None]]:
             A data tensor of shape [new_batch, 1, to_process]
-            and the original shape, if the shape has changed.
+            and the original shape.
     """
-    ds = None
-    if len(data.shape) == 1:
+    ds = list(data.shape)
+    if len(ds) == 1:
         # assume time series
         data = data.unsqueeze(0).unsqueeze(0)
-    elif len(data.shape) == 2:
+    elif len(ds) == 2:
         # assume batched time series
         data = data.unsqueeze(1)
     else:
@@ -243,18 +243,33 @@ def _preprocess_tensor_dec1d(
 
 
 def _postprocess_result_list_dec1d(
-    result_lst: List[torch.Tensor], ds: List[int]
+    result_list: List[torch.Tensor], ds: List[int], axis: int
 ) -> List[torch.Tensor]:
-    # Unfold axes for the wavelets
-    return [_unfold_axes(fres, ds, 1) for fres in result_lst]
+    if len(ds) == 1:
+        result_list = [r_el.squeeze(0) for r_el in result_list]
+    elif len(ds) > 2:
+        # Unfold axes for the wavelets
+        result_list = [_unfold_axes(fres, ds, 1) for fres in result_list]
+    else:
+        result_list = result_list
+
+    if axis != -1:
+        result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]
+
+    return result_list
 
 
 def _preprocess_result_list_rec1d(
     result_lst: List[torch.Tensor],
 ) -> Tuple[List[torch.Tensor], List[int]]:
     # Fold axes for the wavelets
     ds = list(result_lst[0].shape)
-    fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst]
+    if len(ds) == 1:
+        fold_coeffs = [uf_coeff.unsqueeze(0) for uf_coeff in result_lst]
+    elif len(ds) > 2:
+        fold_coeffs = [_fold_axes(uf_coeff, 1)[0] for uf_coeff in result_lst]
+    else:
+        fold_coeffs = result_lst
     return fold_coeffs, ds
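
For readers unfamiliar with the fold/unfold helpers used in `_postprocess_result_list_dec1d` and `_preprocess_result_list_rec1d`, here is a simplified sketch of the idea, with plain `reshape` calls standing in for `_fold_axes`/`_unfold_axes`: extra leading axes are flattened into the batch dimension before the 1d transform and restored afterwards. The new `len(ds)` checks merely distinguish unbatched (1d), batched (2d), and folded (>2d) inputs.

```python
import torch

x = torch.randn(8, 3, 64)        # [batch, channels, time]
ds = list(x.shape)               # remember the original shape

folded = x.reshape(-1, 1, 64)    # fold: [batch * channels, 1, time]
# ... the 1d transform runs on `folded` ...
restored = folded.reshape(*ds)   # unfold: back to [batch, channels, time]
assert torch.equal(x, restored)
```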

@@ -350,11 +365,7 @@ def wavedec(
         result_list.append(res_lo.squeeze(1))
     result_list.reverse()
 
-    if ds:
-        result_list = _postprocess_result_list_dec1d(result_list, ds)
-
-    if axis != -1:
-        result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]
+    result_list = _postprocess_result_list_dec1d(result_list, ds, axis)
 
     return result_list
 
@@ -412,9 +423,8 @@ def waverec(
         raise ValueError("waverec transforms a single axis only.")
 
     # fold channels, if necessary.
-    ds = None
-    if coeffs[0].dim() >= 3:
-        coeffs, ds = _preprocess_result_list_rec1d(coeffs)
+    ds = list(coeffs[0].shape)
+    coeffs, ds = _preprocess_result_list_rec1d(coeffs)
 
     _, _, rec_lo, rec_hi = _get_filter_tensors(
         wavelet, flip=False, device=torch_device, dtype=torch_dtype
@@ -439,7 +449,9 @@ def waverec(
         if padr > 0:
             res_lo = res_lo[..., :-padr]
 
-    if ds:
+    if len(ds) == 1:
+        res_lo = res_lo.squeeze(0)
+    elif len(ds) > 2:
         res_lo = _unfold_axes(res_lo, ds, 1)
 
     if axis != -1:
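
Taken together, these changes let `wavedec` and `waverec` accept unbatched 1d inputs, which `test_conv_fwt1d_nobatch` below verifies. A short usage sketch (shapes chosen for illustration):

```python
import torch
from ptwt.conv_transform import wavedec, waverec

signal = torch.randn(64, dtype=torch.float64)  # no batch dimension
coeffs = wavedec(signal, "db2")                # unbatched input is now accepted
rec = waverec(coeffs, "db2")                   # returns the original 1d shape
assert rec.shape == signal.shape
```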

src/ptwt/matmul_transform.py

Lines changed: 6 additions & 14 deletions
@@ -382,15 +382,7 @@ def __call__(self, input_signal: torch.Tensor) -> List[torch.Tensor]:
         result_list = [s.T for s in split_list[::-1]]
 
         # unfold if necessary
-        if ds:
-            result_list = _postprocess_result_list_dec1d(result_list, ds)
-
-        if self.axis != -1:
-            swap = []
-            for coeff in result_list:
-                swap.append(coeff.swapaxes(self.axis, -1))
-            result_list = swap
-
+        result_list = _postprocess_result_list_dec1d(result_list, ds, self.axis)
         return result_list
 
@@ -616,9 +608,7 @@ def __call__(self, coefficients: List[torch.Tensor]) -> torch.Tensor:
                 swap.append(coeff.swapaxes(self.axis, -1))
             coefficients = swap
 
-        ds = None
-        if coefficients[0].ndim > 2:
-            coefficients, ds = _preprocess_result_list_rec1d(coefficients)
+        coefficients, ds = _preprocess_result_list_rec1d(coefficients)
 
         level = len(coefficients) - 1
         input_length = coefficients[-1].shape[-1] * 2
@@ -670,8 +660,10 @@ def __call__(self, coefficients: List[torch.Tensor]) -> torch.Tensor:
 
         res_lo = lo.T
 
-        if ds:
-            res_lo = _unfold_axes(res_lo.unsqueeze(-2), list(ds), 1).squeeze(-2)
+        if len(ds) == 1:
+            res_lo = res_lo.squeeze(0)
+        elif len(ds) > 2:
+            res_lo = _unfold_axes(res_lo, ds, 1)
 
         if self.axis != -1:
             res_lo = res_lo.swapaxes(self.axis, -1)
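
The matrix-based transforms receive the same shape handling. A hedged sketch, assuming the `MatrixWavedec`/`MatrixWaverec` constructors take a wavelet name and an optional level as in the released API (their `__call__` methods are what this diff patches):

```python
import torch
from ptwt.matmul_transform import MatrixWavedec, MatrixWaverec

matrix_wavedec = MatrixWavedec("haar", level=3)
matrix_waverec = MatrixWaverec("haar")

signal = torch.randn(64, dtype=torch.float64)  # unbatched input
coeffs = matrix_wavedec(signal)
rec = matrix_waverec(coeffs)
assert rec.shape == signal.shape
```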
src/ptwt/stationary_transform.py

Lines changed: 51 additions & 51 deletions
@@ -1,9 +1,10 @@
 """This module implements stationary wavelet transforms."""
 
-from typing import List, Optional, Union
+from typing import List, Optional, Sequence, Union
 
 import pywt
 import torch
+import torch.nn.functional as F  # noqa:N812
 
 from ._util import Wavelet, _as_wavelet, _unfold_axes
 from .conv_transform import (
@@ -14,14 +15,48 @@
 )
 
 
-def _swt(
+def _circular_pad(x: torch.Tensor, padding_dimensions: Sequence[int]) -> torch.Tensor:
+    """Pad a tensor in circular mode, more than once if needed."""
+    trailing_dimension = x.shape[-1]
+
+    # if every padding dimension is smaller than or equal to the trailing
+    # dimension, we do not need to wrap manually
+    if not any(
+        padding_dimension > trailing_dimension
+        for padding_dimension in padding_dimensions
+    ):
+        return F.pad(x, padding_dimensions, mode="circular")
+
+    # repeatedly pad by at most the trailing dimension
+    # until all padding dimensions are zero
+    while any(padding_dimension > 0 for padding_dimension in padding_dimensions):
+        # reduce every padding dimension to at most the trailing dimension width
+        reduced_padding_dimensions = [
+            min(trailing_dimension, padding_dimension)
+            for padding_dimension in padding_dimensions
+        ]
+        # pad using the reduced dimensions,
+        # which will never raise the circular-wrap error
+        x = F.pad(x, reduced_padding_dimensions, mode="circular")
+        # subtract the pad width that was just applied, and repeat
+        # while any pad width is greater than zero
+        padding_dimensions = [
+            max(padding_dimension - trailing_dimension, 0)
+            for padding_dimension in padding_dimensions
+        ]
+
+    return x
+
+
+def swt(
     data: torch.Tensor,
     wavelet: Union[Wavelet, str],
     level: Optional[int] = None,
     axis: Optional[int] = -1,
 ) -> List[torch.Tensor]:
     """Compute a multilevel 1d stationary wavelet transform.
 
+    This function is equivalent to pywt's swt with `trim_approx=True` and `norm=False`.
+
     Args:
         data (torch.Tensor): The input data of shape [batch_size, time].
         wavelet (Union[Wavelet, str]): The wavelet to use.
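
To see why the loop in `_circular_pad` is correct: each partial circular pad keeps the tensor a periodic extension of the original signal, so the next wrap continues the same extension. A worked sketch using plain `F.pad`:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[1.0, 2.0, 3.0]]])  # trailing dimension of length 3

# A left pad of 4 exceeds the trailing dimension, so a single F.pad call
# would fail. _circular_pad splits it into a pad of 3 plus the remaining 1:
step1 = F.pad(x, [3, 0], mode="circular")      # [1, 2, 3, 1, 2, 3]
step2 = F.pad(step1, [1, 0], mode="circular")  # [3, 1, 2, 3, 1, 2, 3]
# step2 equals the periodic extension of [1, 2, 3] padded left by 4.
```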
@@ -56,57 +91,20 @@ def _swt(
     for current_level in range(level):
         dilation = 2**current_level
         padl, padr = dilation * (filt_len // 2 - 1), dilation * (filt_len // 2)
-        res_lo = torch.nn.functional.pad(res_lo, [padl, padr], mode="circular")
+        res_lo = _circular_pad(res_lo, [padl, padr])
         res = torch.nn.functional.conv1d(res_lo, filt, stride=1, dilation=dilation)
         res_lo, res_hi = torch.split(res, 1, 1)
         # Trim_approx == False
         # result_list.append((res_lo.squeeze(1), res_hi.squeeze(1)))
         result_list.append(res_hi.squeeze(1))
     result_list.append(res_lo.squeeze(1))
 
-    if ds:
-        result_list = _postprocess_result_list_dec1d(result_list, ds)
-
-    if axis != -1:
-        result_list = [coeff.swapaxes(axis, -1) for coeff in result_list]
+    result_list = _postprocess_result_list_dec1d(result_list, ds, axis)
 
     return result_list[::-1]
 
 
-def _conv_transpose_dedilate(
-    conv_res: torch.Tensor,
-    rec_filt: torch.Tensor,
-    dilation: int,
-    length: int,
-) -> torch.Tensor:
-    """Undo the forward dilated convolution from the analysis transform.
-
-    Args:
-        conv_res (torch.Tensor): The dilated coefficients
-            of shape [batch, 2, length].
-        rec_filt (torch.Tensor): The reconstruction filter pair
-            of shape [1, 2, filter_length].
-        dilation (int): The dilation factor.
-        length (int): The signal length.
-
-    Returns:
-        torch.Tensor: The deconvolution result.
-    """
-    to_conv_t_list = [
-        conv_res[..., fl : (fl + dilation * rec_filt.shape[-1]) : dilation]
-        for fl in range(length)
-    ]
-    to_conv_t = torch.cat(to_conv_t_list, 0)
-    padding = rec_filt.shape[-1] - 1
-    rec = torch.nn.functional.conv_transpose1d(
-        to_conv_t, rec_filt, stride=1, padding=padding, output_padding=0
-    )
-    rec = rec / 2.0
-    splits = torch.split(rec, rec.shape[0] // len(to_conv_t_list))
-    return torch.cat(splits, -1)
-
-
-def _iswt(
+def iswt(
     coeffs: List[torch.Tensor],
     wavelet: Union[pywt.Wavelet, str],
     axis: Optional[int] = -1,
@@ -134,10 +132,7 @@ def _iswt(
     else:
         raise ValueError("iswt transforms a single axis only.")
 
-    ds = None
-    length = coeffs[0].shape[-1]
-    if coeffs[0].ndim > 2:
-        coeffs, ds = _preprocess_result_list_rec1d(coeffs)
+    coeffs, ds = _preprocess_result_list_rec1d(coeffs)
 
     wavelet = _as_wavelet(wavelet)
     _, _, rec_lo, rec_hi = _get_filter_tensors(
@@ -151,13 +146,18 @@ def _iswt(
         dilation = 2 ** (len(coeffs[1:]) - c_pos - 1)
         res_lo = torch.stack([res_lo, res_hi], 1)
         padl, padr = dilation * (filt_len // 2), dilation * (filt_len // 2 - 1)
-        res_lo = torch.nn.functional.pad(res_lo, (padl, padr), mode="circular")
-        res_lo = _conv_transpose_dedilate(
-            res_lo, rec_filt, dilation=dilation, length=length
+        # res_lo = torch.nn.functional.pad(res_lo, (padl, padr), mode="circular")
+        res_lo_pad = _circular_pad(res_lo, (padl, padr))
+        res_lo = torch.mean(
+            torch.nn.functional.conv_transpose1d(
+                res_lo_pad, rec_filt, dilation=dilation, groups=2, padding=(padl + padr)
+            ),
+            1,
         )
-        res_lo = res_lo.squeeze(1)
 
-    if ds:
+    if len(ds) == 1:
+        res_lo = res_lo.squeeze(0)
+    elif len(ds) > 2:
         res_lo = _unfold_axes(res_lo, ds, 1)
 
     if axis != -1:
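
With `swt` and `iswt` now public, a round trip at a level where the pad width `dilation * (filt_len // 2)` exceeds the signal length exercises the new multi-wrap padding. A hedged sketch (wavelet, level, and sizes are illustrative):

```python
import torch
from ptwt.stationary_transform import swt, iswt

signal = torch.randn(1, 32, dtype=torch.float64)  # [batch, time]
coeffs = swt(signal, "db3", level=5)  # deep level: pad width grows past 32
rec = iswt(coeffs, "db3")
assert torch.allclose(signal, rec)
```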

tests/test_convolution_fwt.py

Lines changed: 17 additions & 0 deletions
@@ -82,6 +82,23 @@ def test_conv_fwt1d_channel(size: List[int], wavelet: str) -> None:
     assert np.allclose(data.numpy(), rec.numpy())
 
 
+@pytest.mark.parametrize("size", [[32], [64]])
+@pytest.mark.parametrize("wavelet", ["haar", "db2"])
+def test_conv_fwt1d_nobatch(size: List[int], wavelet: str) -> None:
+    """1d conv for inputs without batch dim."""
+    data = torch.randn(*size).type(torch.float64)
+    ptwt_coeff = wavedec(data, wavelet)
+    pywt_coeff = pywt.wavedec(data.numpy(), wavelet, mode="reflect")
+    assert all(
+        [
+            np.allclose(ptwtc.numpy(), pywtc)
+            for ptwtc, pywtc in zip(ptwt_coeff, pywt_coeff)
+        ]
+    )
+    rec = waverec(ptwt_coeff, wavelet)
+    assert np.allclose(data.numpy(), rec.numpy())
+
+
 def test_ripples_haar_lvl3() -> None:
     """Compute example from page 7 of Ripples in Mathematics, Jensen, la Cour-Harbo."""
 