Skip to content

Commit 70fdeb1

Browse files
authored
Merge pull request #63 from v0lta/v0.1.6
V0.1.6
2 parents 6091fa0 + bdf8a30 commit 70fdeb1

25 files changed

+732
-308
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ ignore =
2222
N400
2323
# asserts are ok in test.
2424
S101
25+
C901
2526
exclude =
2627
.tox,
2728
.git,

docs/ptwt.rst

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ ptwt.conv\_transform\_3 module
2727
:undoc-members:
2828
:show-inheritance:
2929

30-
ptwt.separable\_conv\_transform module
31-
--------------------------------------
30+
ptwt.packets module
31+
-------------------
3232

33-
.. automodule:: ptwt.separable_conv_transform
33+
.. automodule:: ptwt.packets
3434
:members:
3535
:undoc-members:
3636
:show-inheritance:
3737

38+
3839
ptwt.continuous\_transform module
3940
---------------------------------
4041

@@ -43,10 +44,11 @@ ptwt.continuous\_transform module
4344
:undoc-members:
4445
:show-inheritance:
4546

46-
ptwt.packets module
47-
-------------------
4847

49-
.. automodule:: ptwt.packets
48+
ptwt.separable\_conv\_transform module
49+
--------------------------------------
50+
51+
.. automodule:: ptwt.separable_conv_transform
5052
:members:
5153
:undoc-members:
5254
:show-inheritance:

src/ptwt/_stationary_transform.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""This module implements stationary wavelet transforms."""
2+
3+
from typing import List, Optional, Union
4+
5+
import pywt
6+
import torch
7+
8+
from src.ptwt._util import Wavelet
9+
from src.ptwt.conv_transform import _get_filter_tensors
10+
11+
12+
def swt(
13+
data: torch.Tensor,
14+
wavelet: Union[Wavelet, str],
15+
level: Optional[int] = None,
16+
) -> List[torch.Tensor]:
17+
"""Compute a multilevel 1d stationary wavelet transform.
18+
19+
Args:
20+
data (torch.Tensor): The input data of shape [batch_size, time].
21+
wavelet (Union[Wavelet, str]): The wavelet to use.
22+
level (Optional[int], optional): The number of levels to compute
23+
24+
Returns:
25+
List[torch.Tensor]: Same as wavedec.
26+
Equivalent to pywt.swt with trim_approx=True.
27+
"""
28+
if data.dim() == 1:
29+
# assume time series
30+
data = data.unsqueeze(0).unsqueeze(0)
31+
elif data.dim() == 2:
32+
# assume batched time series
33+
data = data.unsqueeze(1)
34+
35+
dec_lo, dec_hi, _, _ = _get_filter_tensors(
36+
wavelet, flip=True, device=data.device, dtype=data.dtype
37+
)
38+
filt_len = dec_lo.shape[-1]
39+
filt = torch.stack([dec_lo, dec_hi], 0)
40+
41+
if level is None:
42+
level = pywt.swt_max_level(data.shape[-1])
43+
44+
result_lst = []
45+
res_lo = data
46+
for current_level in range(level):
47+
dilation = 2**current_level
48+
padl, padr = dilation * (filt_len // 2 - 1), dilation * (filt_len // 2)
49+
res_lo = torch.nn.functional.pad(res_lo, [padl, padr], mode="circular")
50+
res = torch.nn.functional.conv1d(res_lo, filt, stride=1, dilation=dilation)
51+
res_lo, res_hi = torch.split(res, 1, 1)
52+
# Trim_approx == False
53+
# result_lst.append((res_lo.squeeze(1), res_hi.squeeze(1)))
54+
result_lst.append(res_hi.squeeze(1))
55+
result_lst.append(res_lo.squeeze(1))
56+
return result_lst[::-1]

src/ptwt/_util.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Utility methods to compute wavelet decompositions from a dataset."""
2-
from typing import Optional, Protocol, Sequence, Tuple, Union
2+
from typing import List, Optional, Protocol, Sequence, Tuple, Union
33

44
import pywt
55
import torch
@@ -64,3 +64,57 @@ def _get_len(wavelet: Union[Tuple[torch.Tensor, ...], str, Wavelet]) -> int:
6464
return wavelet[0].shape[0]
6565
else:
6666
return len(_as_wavelet(wavelet))
67+
68+
69+
def _pad_symmetric_1d(signal: torch.Tensor, pad_list: Tuple[int, int]) -> torch.Tensor:
70+
padl, padr = pad_list
71+
dimlen = signal.shape[0]
72+
if padl > dimlen or padr > dimlen:
73+
if padl > dimlen:
74+
signal = _pad_symmetric_1d(signal, (dimlen, 0))
75+
padl = padl - dimlen
76+
if padr > dimlen:
77+
signal = _pad_symmetric_1d(signal, (0, dimlen))
78+
padr = padr - dimlen
79+
return _pad_symmetric_1d(signal, (padl, padr))
80+
else:
81+
cat_list = [signal]
82+
if padl > 0:
83+
topadl = signal[:padl].flip(0)
84+
cat_list.insert(0, topadl)
85+
if padr > 0:
86+
topadr = signal[-padr::].flip(0)
87+
cat_list.append(topadr)
88+
return torch.cat(cat_list, axis=0) # type: ignore
89+
90+
91+
def _pad_symmetric(
92+
signal: torch.Tensor, pad_lists: List[Tuple[int, int]]
93+
) -> torch.Tensor:
94+
if len(signal.shape) < len(pad_lists):
95+
raise ValueError("not enough dimensions to pad.")
96+
97+
dims = len(signal.shape) - 1
98+
for pos, pad_list in enumerate(pad_lists[::-1]):
99+
current_axis = dims - pos
100+
signal = signal.transpose(0, current_axis)
101+
signal = _pad_symmetric_1d(signal, pad_list)
102+
signal = signal.transpose(current_axis, 0)
103+
return signal
104+
105+
106+
def _fold_channels(data: torch.Tensor) -> torch.Tensor:
107+
"""Fold [batch, channel, height width] into [batch*channel, height, widht]."""
108+
ds = data.shape
109+
fold_data = torch.permute(data, [2, 3, 0, 1])
110+
fold_data = torch.reshape(fold_data, [ds[2], ds[3], ds[0] * ds[1]])
111+
return torch.permute(fold_data, [2, 0, 1])
112+
113+
114+
def _unfold_channels(data: torch.Tensor, ds: List[int]) -> torch.Tensor:
115+
"""Unfold [batch*channel, height, widht] into [batch, channel, height, width]."""
116+
unfold_data = torch.permute(data, [1, 2, 0])
117+
unfold_data = torch.reshape(
118+
unfold_data, [data.shape[1], data.shape[2], ds[0], ds[1]]
119+
)
120+
return torch.permute(unfold_data, [2, 3, 0, 1])

src/ptwt/continuous_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def cwt(
2727
wavelet: Union[ContinuousWavelet, str],
2828
sampling_period: float = 1.0,
2929
) -> Tuple[torch.Tensor, np.ndarray]: # type: ignore
30-
"""Compute the single dimensional continuous wavelet transform.
30+
"""Compute the single-dimensional continuous wavelet transform.
3131
3232
This function is a PyTorch port of pywt.cwt as found at:
3333
https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py
@@ -166,7 +166,7 @@ def _integrate_wavelet(
166166
167167
Parameters
168168
----------
169-
wavelet : Wavelet instance or str
169+
wavelet: Wavelet instance or str
170170
Wavelet to integrate. If a string, should be the name of a wavelet.
171171
precision : int, optional
172172
Precision that will be used for wavelet function

src/ptwt/conv_transform.py

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@
88
import pywt
99
import torch
1010

11-
from ._util import Wavelet, _as_wavelet, _get_len, _is_dtype_supported
11+
from ._util import (
12+
Wavelet,
13+
_as_wavelet,
14+
_fold_channels,
15+
_get_len,
16+
_is_dtype_supported,
17+
_pad_symmetric,
18+
_unfold_channels,
19+
)
1220

1321

1422
def _create_tensor(
@@ -26,7 +34,7 @@ def _create_tensor(
2634
return torch.tensor(filter, device=device, dtype=dtype).unsqueeze(0)
2735

2836

29-
def get_filter_tensors(
37+
def _get_filter_tensors(
3038
wavelet: Union[Wavelet, str],
3139
flip: bool,
3240
device: Union[torch.device, str],
@@ -116,6 +124,10 @@ def _translate_boundary_strings(pywt_mode: str) -> str:
116124
pt_mode = pywt_mode
117125
elif pywt_mode == "periodic":
118126
pt_mode = "circular"
127+
elif pywt_mode == "symmetric":
128+
# pytorch does not support symmetric mode,
129+
# we have our own implementation.
130+
pt_mode = pywt_mode
119131
else:
120132
raise ValueError("Padding mode not supported.")
121133
return pt_mode
@@ -129,12 +141,12 @@ def _fwt_pad(
129141
The padding assumes a future step will transform the last axis.
130142
131143
Args:
132-
data (torch.Tensor): Input data [batch_size, 1, time]
144+
data (torch.Tensor): Input data ``[batch_size, 1, time]``
133145
wavelet (Wavelet or str): A pywt wavelet compatible object or
134146
the name of a pywt wavelet.
135147
mode (str): The desired way to pad. The following methods are supported::
136148
137-
"reflect", "zero", "constant", "periodic".
149+
"reflect", "zero", "constant", "periodic", "symmetric".
138150
139151
Refection padding mirrors samples along the border.
140152
Zero padding pads zeros.
@@ -151,7 +163,10 @@ def _fwt_pad(
151163
mode = _translate_boundary_strings(mode)
152164

153165
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
154-
data_pad = torch.nn.functional.pad(data, [padl, padr], mode=mode)
166+
if mode == "symmetric":
167+
data_pad = _pad_symmetric(data, [(padl, padr)])
168+
else:
169+
data_pad = torch.nn.functional.pad(data, [padl, padr], mode=mode)
155170
return data_pad
156171

157172

@@ -196,10 +211,40 @@ def _adjust_padding_at_reconstruction(
196211
elif next_size == pred_size - 1:
197212
pad_end += 1
198213
else:
199-
raise AssertionError("padding error, please open an issue on github")
214+
raise AssertionError(
215+
"padding error, please check if dec and rec wavelets are identical."
216+
)
200217
return pad_end, pad_start
201218

202219

220+
def _wavedec_fold_channels_1d(data: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
221+
data = data.unsqueeze(-2)
222+
ds = data.shape
223+
data = _fold_channels(data)
224+
return data, list(ds)
225+
226+
227+
def _wavedec_unfold_channels_1d_list(
228+
result_list: List[torch.Tensor], ds: List[int]
229+
) -> List[torch.Tensor]:
230+
unfold_res = []
231+
for res_coeff in result_list:
232+
unfold_res.append(
233+
_unfold_channels(res_coeff.unsqueeze(1), list(ds)).squeeze(-2)
234+
)
235+
return unfold_res
236+
237+
238+
def _waverec_fold_channels_1d_list(
239+
coeff_list: List[torch.Tensor],
240+
) -> Tuple[List[torch.Tensor], List[int]]:
241+
folded = []
242+
ds = coeff_list[0].unsqueeze(-2).shape
243+
for to_fold_coeff in coeff_list:
244+
folded.append(_fold_channels(to_fold_coeff.unsqueeze(-2)).squeeze(-2))
245+
return folded, list(ds)
246+
247+
203248
def wavedec(
204249
data: torch.Tensor,
205250
wavelet: Union[Wavelet, str],
@@ -209,19 +254,27 @@ def wavedec(
209254
"""Compute the analysis (forward) 1d fast wavelet transform.
210255
211256
Args:
212-
data (torch.Tensor): Input time series of shape [batch_size, 1, time]
213-
1d inputs are interpreted as [time],
214-
2d inputs are interpreted as [batch_size, time].
257+
data (torch.Tensor): The input time series,
258+
1d inputs are interpreted as ``[time]``,
259+
2d inputs as ``[batch_size, time]``,
260+
and 3d inputs as ``[batch_size, channels, time]``.
215261
wavelet (Wavelet or str): A pywt wavelet compatible object or
216262
the name of a pywt wavelet.
217263
Please consider the output from ``pywt.wavelist(kind='discrete')``
218264
for possible choices.
219265
mode (str): The desired padding mode. Padding extends the signal along
220266
the edges. Supported methods are::
221267
222-
"reflect", "zero", "constant", "periodic".
268+
"reflect", "zero", "constant", "periodic", "symmetric".
223269
224270
Defaults to "reflect".
271+
272+
Symmetric padding mirrors samples along the border.
273+
Refection padding reflects samples along the border.
274+
Zero padding pads zeros.
275+
Constant padding replicates border values.
276+
Periodic padding cyclically repeats samples.
277+
225278
level (int): The scale level to be computed.
226279
Defaults to None.
227280
@@ -246,19 +299,23 @@ def wavedec(
246299
>>> # compute the forward fwt coefficients
247300
>>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
248301
>>> mode='zero', level=2)
249-
250302
"""
303+
fold = False
251304
if data.dim() == 1:
252305
# assume time series
253306
data = data.unsqueeze(0).unsqueeze(0)
254307
elif data.dim() == 2:
255308
# assume batched time series
256309
data = data.unsqueeze(1)
310+
elif data.dim() == 3:
311+
# assume batch, channels, time -> fold channels
312+
fold = True
313+
data, ds = _wavedec_fold_channels_1d(data)
257314

258315
if not _is_dtype_supported(data.dtype):
259316
raise ValueError(f"Input dtype {data.dtype} not supported")
260317

261-
dec_lo, dec_hi, _, _ = get_filter_tensors(
318+
dec_lo, dec_hi, _, _ = _get_filter_tensors(
262319
wavelet, flip=True, device=data.device, dtype=data.dtype
263320
)
264321
filt_len = dec_lo.shape[-1]
@@ -267,15 +324,20 @@ def wavedec(
267324
if level is None:
268325
level = pywt.dwt_max_level(data.shape[-1], filt_len)
269326

270-
result_lst = []
327+
result_list = []
271328
res_lo = data
272329
for _ in range(level):
273330
res_lo = _fwt_pad(res_lo, wavelet, mode=mode)
274331
res = torch.nn.functional.conv1d(res_lo, filt, stride=2)
275332
res_lo, res_hi = torch.split(res, 1, 1)
276-
result_lst.append(res_hi.squeeze(1))
277-
result_lst.append(res_lo.squeeze(1))
278-
return result_lst[::-1]
333+
result_list.append(res_hi.squeeze(1))
334+
result_list.append(res_lo.squeeze(1))
335+
336+
# unfold if necessary
337+
if fold:
338+
result_list = _wavedec_unfold_channels_1d_list(result_list, ds)
339+
340+
return result_list[::-1]
279341

280342

281343
def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.Tensor:
@@ -317,7 +379,13 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
317379
elif torch_dtype != coeff.dtype:
318380
raise ValueError("coefficients must have the same dtype")
319381

320-
_, _, rec_lo, rec_hi = get_filter_tensors(
382+
# fold channels, if necessary.
383+
fold = False
384+
if coeffs[0].dim() == 3:
385+
fold = True
386+
coeffs, ds = _waverec_fold_channels_1d_list(coeffs)
387+
388+
_, _, rec_lo, rec_hi = _get_filter_tensors(
321389
wavelet, flip=False, device=torch_device, dtype=torch_dtype
322390
)
323391
filt_len = rec_lo.shape[-1]
@@ -339,4 +407,8 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
339407
res_lo = res_lo[..., padl:]
340408
if padr > 0:
341409
res_lo = res_lo[..., :-padr]
410+
411+
if fold:
412+
res_lo = _unfold_channels(res_lo.unsqueeze(-2), list(ds)).squeeze(-2)
413+
342414
return res_lo

0 commit comments

Comments
 (0)