Skip to content

Commit 38336ad

Browse files
authored
Merge pull request #36 from v0lta/inverse_packets
Inverse packets
2 parents a1f35ba + 0a5c49e commit 38336ad

File tree

3 files changed

+245
-45
lines changed

3 files changed

+245
-45
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Welcome to the PyTorch wavelet toolbox. This package implements:
3939
- ``MatrixWavedec`` and ``MatrixWaverec`` provide sparse-matrix-based fast wavelet transforms with boundary filters,
4040
- 2d sparse-matrix transforms with separable & non-separable boundary filters are available (experimental),
4141
- ``cwt`` computes a one-dimensional continuous forward transform,
42-
- single and two-dimensional wavelet packet forward transforms are available via the ``WaveletPacket`` and ``WaveletPacket2D`` objects,
42+
- single and two-dimensional wavelet packet forward and backward transforms are available via the ``WaveletPacket`` and ``WaveletPacket2D`` objects,
4343
- finally, this package provides adaptive wavelet support (experimental).
4444

4545
This toolbox supports pywt-wavelets. Complete documentation is available:

src/ptwt/packets.py

Lines changed: 169 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
import torch
1111

1212
from ._util import Wavelet, _as_wavelet
13-
from .conv_transform import wavedec
14-
from .conv_transform_2 import wavedec2
15-
from .matmul_transform import MatrixWavedec
16-
from .matmul_transform_2 import MatrixWavedec2
13+
from .conv_transform import wavedec, waverec
14+
from .conv_transform_2 import wavedec2, waverec2
15+
from .matmul_transform import MatrixWavedec, MatrixWaverec
16+
from .matmul_transform_2 import MatrixWavedec2, MatrixWaverec2
1717

1818
if TYPE_CHECKING:
1919
BaseDict = collections.UserDict[str, torch.Tensor]
@@ -30,7 +30,7 @@ def __init__(
3030
wavelet: Union[Wavelet, str],
3131
mode: str = "reflect",
3232
boundary_orthogonalization: str = "qr",
33-
max_level: Optional[int] = None,
33+
maxlevel: Optional[int] = None,
3434
) -> None:
3535
"""Create a wavelet packet decomposition object.
3636
@@ -47,42 +47,91 @@ def __init__(
4747
boundary_orthogonalization (str): The orthogonalization method
4848
to use. Only used if `mode` equals 'boundary'. Choose from
4949
'qr' or 'gramschmidt'. Defaults to 'qr'.
50-
max_level (int, optional): Value is passed on to `transform`.
50+
maxlevel (int, optional): Value is passed on to `transform`.
5151
The highest decomposition level to compute. If None, the maximum level
5252
is determined from the input data shape. Defaults to None.
53+
54+
Example:
55+
>>> import torch, pywt, ptwt
56+
>>> import numpy as np
57+
>>> import scipy.signal
58+
>>> import matplotlib.pyplot as plt
59+
>>> t = np.linspace(0, 10, 1500)
60+
>>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
61+
>>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
62+
wavelet=pywt.Wavelet("db3"), mode="reflect")
63+
>>> np_lst = []
64+
>>> for node in wp.get_level(5):
65+
>>> np_lst.append(wp[node])
66+
>>> viz = np.stack(np_lst).squeeze()
67+
>>> plt.imshow(np.abs(viz))
68+
>>> plt.show()
69+
5370
"""
5471
self.wavelet = _as_wavelet(wavelet)
5572
self.mode = mode
5673
self.boundary = boundary_orthogonalization
5774
self._matrix_wavedec_dict: Dict[int, MatrixWavedec] = {}
58-
self.max_level: Optional[int] = None
75+
self._matrix_waverec_dict: Dict[int, MatrixWaverec] = {}
76+
self.maxlevel: Optional[int] = None
5977
if data is not None:
6078
if len(data.shape) == 1:
6179
# add a batch dimension.
6280
data = data.unsqueeze(0)
63-
self.transform(data, max_level)
81+
self.transform(data, maxlevel)
6482
else:
6583
self.data = {}
6684

6785
def transform(
68-
self, data: torch.Tensor, max_level: Optional[int] = None
86+
self, data: torch.Tensor, maxlevel: Optional[int] = None
6987
) -> "WaveletPacket":
7088
"""Calculate the 1d wavelet packet transform for the input data.
7189
7290
Args:
7391
data (torch.Tensor): The input data array of shape [time]
7492
or [batch_size, time].
75-
max_level (int, optional): The highest decomposition level to compute.
93+
maxlevel (int, optional): The highest decomposition level to compute.
7694
If None, the maximum level is determined from the input data shape.
7795
Defaults to None.
7896
"""
7997
self.data = {}
80-
if max_level is None:
81-
max_level = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len)
82-
self.max_level = max_level
98+
if maxlevel is None:
99+
maxlevel = pywt.dwt_max_level(data.shape[-1], self.wavelet.dec_len)
100+
self.maxlevel = maxlevel
83101
self._recursive_dwt(data, level=0, path="")
84102
return self
85103

104+
def reconstruct(self) -> "WaveletPacket":
105+
"""Recursively reconstruct the input starting from the leaf nodes.
106+
107+
Reconstruction replaces the input-data originally assigned to this object.
108+
109+
Note:
110+
Only changes to leaf node data impacts the results,
111+
since changes in all other nodes will be replaced with
112+
a reconstruction from the leafs.
113+
114+
Example:
115+
>>> import numpy as np
116+
>>> import ptwt, torch
117+
>>> signal = np.random.randn(1, 16)
118+
>>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar",
119+
mode="boundary", maxlevel=2)
120+
>>> ptwp["aa"].data *= 0
121+
>>> ptwp.reconstruct()
122+
>>> print(ptwp[""])
123+
"""
124+
if self.maxlevel is None:
125+
self.maxlevel = pywt.dwt_max_level(self[""].shape[-1], self.wavelet.dec_len)
126+
127+
for level in reversed(range(self.maxlevel)):
128+
for node in self.get_level(level):
129+
data_a = self[node + "a"]
130+
data_b = self[node + "d"]
131+
rec = self._get_waverec(data_a.shape[-1])([data_a, data_b])
132+
self[node] = rec
133+
return self
134+
86135
def _get_wavedec(
87136
self,
88137
length: int,
@@ -96,6 +145,19 @@ def _get_wavedec(
96145
else:
97146
return partial(wavedec, wavelet=self.wavelet, level=1, mode=self.mode)
98147

148+
def _get_waverec(
149+
self,
150+
length: int,
151+
) -> Callable[[List[torch.Tensor]], torch.Tensor]:
152+
if self.mode == "boundary":
153+
if length not in self._matrix_waverec_dict.keys():
154+
self._matrix_waverec_dict[length] = MatrixWaverec(
155+
self.wavelet, boundary=self.boundary
156+
)
157+
return self._matrix_waverec_dict[length]
158+
else:
159+
return partial(waverec, wavelet=self.wavelet)
160+
99161
def get_level(self, level: int) -> List[str]:
100162
"""Return the graycode ordered paths to the filter tree nodes.
101163
@@ -113,10 +175,13 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> List[st
113175
graycode_order = [x + path for path in graycode_order] + [
114176
y + path for path in graycode_order[::-1]
115177
]
116-
return graycode_order
178+
if level == 0:
179+
return [""]
180+
else:
181+
return graycode_order
117182

118183
def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
119-
if not self.max_level:
184+
if not self.maxlevel:
120185
raise AssertionError
121186

122187
# TODO: This is a workaround since the convolutional transforms insert a
@@ -125,7 +190,7 @@ def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
125190
data = data.squeeze(1)
126191

127192
self.data[path] = data
128-
if level < self.max_level:
193+
if level < self.maxlevel:
129194
res_lo, res_hi = self._get_wavedec(data.shape[-1])(data)
130195
self._recursive_dwt(res_lo, level + 1, path + "a")
131196
self._recursive_dwt(res_hi, level + 1, path + "d")
@@ -144,22 +209,26 @@ def __getitem__(self, key: str) -> torch.Tensor:
144209
ValueError: If the wavelet packet tree is not initialized.
145210
KeyError: If no wavelet coefficients are indexed by the specified key.
146211
"""
147-
if self.max_level is None:
212+
if self.maxlevel is None:
148213
raise ValueError(
149214
"The wavelet packet tree must be initialized via 'transform' before "
150215
"its values can be accessed!"
151216
)
152-
if key not in self and len(key) > self.max_level:
217+
if key not in self and len(key) > self.maxlevel:
153218
raise KeyError(
154219
f"The requested level {len(key)} with key '{key}' is too large and "
155220
"cannot be accessed! This wavelet packet tree is initialized with "
156-
f"maximum level {self.max_level}."
221+
f"maximum level {self.maxlevel}."
157222
)
158223
return super().__getitem__(key)
159224

160225

161226
class WaveletPacket2D(BaseDict):
162-
"""Two dimensional wavelet packets."""
227+
"""Two dimensional wavelet packets.
228+
229+
Example code illustrating the use of this class is available at:
230+
https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/main/examples/deepfake_analysis
231+
"""
163232

164233
def __init__(
165234
self,
@@ -168,7 +237,7 @@ def __init__(
168237
mode: str = "reflect",
169238
boundary_orthogonalization: str = "qr",
170239
separable: bool = False,
171-
max_level: Optional[int] = None,
240+
maxlevel: Optional[int] = None,
172241
) -> None:
173242
"""Create a 2D-Wavelet packet tree.
174243
@@ -188,7 +257,7 @@ def __init__(
188257
separable (bool): If true and the sparse matrix backend is selected,
189258
a separable transform is performed, i.e. each image axis is
190259
transformed separately. Defaults to False.
191-
max_level (int, optional): Value is passed on to `transform`.
260+
maxlevel (int, optional): Value is passed on to `transform`.
192261
The highest decomposition level to compute. If None, the maximum level
193262
is determined from the input data shape. Defaults to None.
194263
"""
@@ -197,15 +266,16 @@ def __init__(
197266
self.boundary = boundary_orthogonalization
198267
self.separable = separable
199268
self.matrix_wavedec2_dict: Dict[Tuple[int, ...], MatrixWavedec2] = {}
269+
self.matrix_waverec2_dict: Dict[Tuple[int, ...], MatrixWaverec2] = {}
200270

201-
self.max_level: Optional[int] = None
271+
self.maxlevel: Optional[int] = None
202272
if data is not None:
203-
self.transform(data, max_level)
273+
self.transform(data, maxlevel)
204274
else:
205275
self.data = {}
206276

207277
def transform(
208-
self, data: torch.Tensor, max_level: Optional[int] = None
278+
self, data: torch.Tensor, maxlevel: Optional[int] = None
209279
) -> "WaveletPacket2D":
210280
"""Calculate the 2d wavelet packet transform for the input data.
211281
@@ -214,18 +284,64 @@ def transform(
214284
Args:
215285
data (torch.tensor): The input data tensor
216286
of shape [batch_size, height, width]
217-
max_level (int, optional): The highest decomposition level to compute.
287+
maxlevel (int, optional): The highest decomposition level to compute.
218288
If None, the maximum level is determined from the input data shape.
219289
Defaults to None.
220290
"""
221291
self.data = {}
222-
if max_level is None:
223-
max_level = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len)
224-
self.max_level = max_level
292+
if maxlevel is None:
293+
maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len)
294+
self.maxlevel = maxlevel
225295

226296
self._recursive_dwt2d(data, level=0, path="")
227297
return self
228298

299+
def reconstruct(self) -> "WaveletPacket2D":
300+
"""Recursively reconstruct the input starting from the leaf nodes.
301+
302+
Note:
303+
Only changes to leaf node data impacts the results,
304+
since changes in all other nodes will be replaced with
305+
a reconstruction from the leafs.
306+
"""
307+
if self.maxlevel is None:
308+
self.maxlevel = pywt.dwt_max_level(
309+
min(self[""].shape[-2:]), self.wavelet.dec_len
310+
)
311+
312+
for level in reversed(range(self.maxlevel)):
313+
for node in self.get_natural_order(level):
314+
if self.mode == "boundary":
315+
data_a = self[node + "a"]
316+
data_h = self[node + "h"]
317+
data_v = self[node + "v"]
318+
data_d = self[node + "d"]
319+
rec = self._get_waverec(data_a.shape[-2:])(
320+
(data_a, (data_h, data_v, data_d))
321+
)
322+
self[node] = rec
323+
else:
324+
data_a = self[node + "a"].unsqueeze(1)
325+
data_h = self[node + "h"].unsqueeze(1)
326+
data_v = self[node + "v"].unsqueeze(1)
327+
data_d = self[node + "d"].unsqueeze(1)
328+
rec = self._get_waverec(data_a.shape[-2:])(
329+
(data_a, (data_h, data_v, data_d))
330+
)
331+
self[node] = rec.squeeze(1)
332+
return self
333+
334+
def get_natural_order(self, level: int) -> List[str]:
335+
"""Get the natural ordering for a given decomposition level.
336+
337+
Args:
338+
level (int): The decomposition level.
339+
340+
Returns:
341+
list: A list with the filter order strings.
342+
"""
343+
return ["".join(p) for p in list(product(["a", "h", "v", "d"], repeat=level))]
344+
229345
def _get_wavedec(
230346
self, shape: Tuple[int, ...]
231347
) -> Callable[
@@ -246,8 +362,27 @@ def _get_wavedec(
246362
else:
247363
return partial(wavedec2, wavelet=self.wavelet, level=1, mode=self.mode)
248364

365+
def _get_waverec(
366+
self, shape: Tuple[int, ...]
367+
) -> Callable[
368+
[Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
369+
torch.Tensor,
370+
]:
371+
if self.mode == "boundary":
372+
shape = tuple(shape)
373+
if shape not in self.matrix_waverec2_dict.keys():
374+
self.matrix_waverec2_dict[shape] = MatrixWaverec2(
375+
self.wavelet,
376+
boundary=self.boundary,
377+
separable=self.separable,
378+
)
379+
fun = self.matrix_waverec2_dict[shape]
380+
return fun # type: ignore
381+
else:
382+
return partial(waverec2, wavelet=self.wavelet)
383+
249384
def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
250-
if not self.max_level:
385+
if not self.maxlevel:
251386
raise AssertionError
252387

253388
# TODO: This is a workaround since the convolutional transforms insert a
@@ -256,7 +391,7 @@ def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
256391
data = data.squeeze(1)
257392

258393
self.data[path] = data
259-
if level < self.max_level:
394+
if level < self.maxlevel:
260395
result_a, (result_h, result_v, result_d) = self._get_wavedec(
261396
data.shape[-2:]
262397
)(data)
@@ -281,16 +416,16 @@ def __getitem__(self, key: str) -> torch.Tensor:
281416
ValueError: If the wavelet packet tree is not initialized.
282417
KeyError: If no wavelet coefficients are indexed by the specified key.
283418
"""
284-
if self.max_level is None:
419+
if self.maxlevel is None:
285420
raise ValueError(
286421
"The wavelet packet tree must be initialized via 'transform' before "
287422
"its values can be accessed!"
288423
)
289-
if key not in self and len(key) > self.max_level:
424+
if key not in self and len(key) > self.maxlevel:
290425
raise KeyError(
291426
f"The requested level {len(key)} with key '{key}' is too large and "
292427
"cannot be accessed! This wavelet packet tree is initialized with "
293-
f"maximum level {self.max_level}."
428+
f"maximum level {self.maxlevel}."
294429
)
295430
return super().__getitem__(key)
296431

0 commit comments

Comments
 (0)