Skip to content

Commit fcc4845

Browse files
committed
Merge branch 'v0.1.5' of github.com:v0lta/PyTorch-Wavelet-Toolbox into v0.1.5
2 parents ea85810 + 16a74f9 commit fcc4845

File tree

8 files changed

+141
-92
lines changed

8 files changed

+141
-92
lines changed

README.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ Welcome to the PyTorch wavelet toolbox. This package implements:
4545
- single and two-dimensional wavelet packet forward and backward transforms are available via the ``WaveletPacket`` and ``WaveletPacket2D`` objects,
4646
- finally, this package provides adaptive wavelet support (experimental).
4747

48-
This toolbox supports pywt-wavelets. Complete documentation is available:
49-
https://pytorch-wavelet-toolbox.readthedocs.io/
48+
This toolbox extends `PyWavelets <https://pywavelets.readthedocs.io/en/latest/>`_ . We additionally provide GPU and gradient support via a PyTorch backend.
49+
Complete documentation is available at: https://pytorch-wavelet-toolbox.readthedocs.io/
5050

5151

5252
**Installation**
@@ -149,8 +149,8 @@ Reconsidering the 1d case, try:
149149
150150
151151
The process for the 2d transforms ``MatrixWavedec2``, ``MatrixWaverec2`` works similarly.
152-
By default, a non-separable transformation is used.
153-
To use a separable transformation, pass ``separable=True`` to ``MatrixWavedec2`` and ``MatrixWaverec2``.
152+
By default, a separable transformation is used.
153+
To use a non-separable transformation, pass ``separable=False`` to ``MatrixWavedec2`` and ``MatrixWaverec2``.
154154
Separable transformations use a 1d transformation along both axes, which might be faster since fewer matrix entries
155155
have to be orthogonalized.
156156

examples/deepfake_analysis/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Recreating these experiments requires roughly 8GB of free disc space and 10GB of
2525
To reproduce these plots:
2626
1. Download [ffhq_style_gan.zip](https://drive.google.com/uc?id=1MOHKuEVqURfCKAN9dwp1o2tuR19OTQCF&export=download) and
2727
2. Extract the image pairs here.
28-
3. Check the file structure. In the `ffhq_style_gan` the folder structure should be:
28+
3. Check the file structure. In `ffhq_style_gan` the folder structure should be:
2929
```
3030
source_data
3131
├── A_ffhq

src/ptwt/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,4 @@
88
from .matmul_transform_2 import MatrixWavedec2, MatrixWaverec2
99
from .matmul_transform_3 import MatrixWavedec3, MatrixWaverec3
1010
from .packets import WaveletPacket, WaveletPacket2D
11-
from .separable_conv_transform import (
12-
fswavedec2,
13-
fswavedec3,
14-
fswaverec2,
15-
fswaverec3,
16-
)
11+
from .separable_conv_transform import fswavedec2, fswavedec3, fswaverec2, fswaverec3

src/ptwt/conv_transform.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,20 @@ def _flatten_2d_coeff_lst(
186186
return flat_coeff_lst
187187

188188

189+
def _adjust_padding_at_reconstruction(
190+
res_ll_size: int, coeff_size: int, pad_end: int, pad_start: int
191+
) -> Tuple[int, int]:
192+
pred_size = res_ll_size - (pad_start + pad_end)
193+
next_size = coeff_size
194+
if next_size == pred_size:
195+
pass
196+
elif next_size == pred_size - 1:
197+
pad_end += 1
198+
else:
199+
raise AssertionError("padding error, please open an issue on github")
200+
return pad_end, pad_start
201+
202+
189203
def wavedec(
190204
data: torch.Tensor,
191205
wavelet: Union[Wavelet, str],
@@ -318,14 +332,9 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
318332
padl = (2 * filt_len - 3) // 2
319333
padr = (2 * filt_len - 3) // 2
320334
if c_pos < len(coeffs) - 2:
321-
pred_len = res_lo.shape[-1] - (padl + padr)
322-
next_len = coeffs[c_pos + 2].shape[-1]
323-
if next_len != pred_len:
324-
padr += 1
325-
pred_len = res_lo.shape[-1] - (padl + padr)
326-
assert (
327-
next_len == pred_len
328-
), "padding error, please open an issue on github "
335+
padr, padl = _adjust_padding_at_reconstruction(
336+
res_lo.shape[-1], coeffs[c_pos + 2].shape[-1], padr, padl
337+
)
329338
if padl > 0:
330339
res_lo = res_lo[..., padl:]
331340
if padr > 0:

src/ptwt/conv_transform_2.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
import torch
1212

1313
from ._util import Wavelet, _as_wavelet, _get_len, _is_dtype_supported, _outer
14-
from .conv_transform import _get_pad, _translate_boundary_strings, get_filter_tensors
14+
from .conv_transform import (
15+
_adjust_padding_at_reconstruction,
16+
_get_pad,
17+
_translate_boundary_strings,
18+
get_filter_tensors,
19+
)
1520

1621

1722
def construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
@@ -244,22 +249,13 @@ def waverec2(
244249
padt = (2 * filt_len - 3) // 2
245250
padb = (2 * filt_len - 3) // 2
246251
if c_pos < len(coeffs) - 2:
247-
pred_len = res_ll.shape[-1] - (padl + padr)
248-
next_len = coeffs[c_pos + 2][0].shape[-1]
249-
pred_len2 = res_ll.shape[-2] - (padt + padb)
250-
next_len2 = coeffs[c_pos + 2][0].shape[-2]
251-
if next_len != pred_len:
252-
padr += 1
253-
pred_len = res_ll.shape[-1] - (padl + padr)
254-
assert (
255-
next_len == pred_len
256-
), "padding error, please open an issue on github "
257-
if next_len2 != pred_len2:
258-
padb += 1
259-
pred_len2 = res_ll.shape[-2] - (padt + padb)
260-
assert (
261-
next_len2 == pred_len2
262-
), "padding error, please open an issue on github "
252+
padr, padl = _adjust_padding_at_reconstruction(
253+
res_ll.shape[-1], coeffs[c_pos + 2][0].shape[-1], padr, padl
254+
)
255+
padb, padt = _adjust_padding_at_reconstruction(
256+
res_ll.shape[-2], coeffs[c_pos + 2][0].shape[-2], padb, padt
257+
)
258+
263259
if padt > 0:
264260
res_ll = res_ll[..., padt:, :]
265261
if padb > 0:

src/ptwt/conv_transform_3.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33
The functions here are based on torch.nn.functional.conv3d and it's transpose.
44
"""
55

6-
from typing import Dict, List, Optional, Union
6+
from typing import Dict, List, Optional, Sequence, Union, cast
77

88
import pywt
99
import torch
1010

1111
from ._util import Wavelet, _as_wavelet, _get_len, _is_dtype_supported, _outer
12-
from .conv_transform import _get_pad, _translate_boundary_strings, get_filter_tensors
12+
from .conv_transform import (
13+
_adjust_padding_at_reconstruction,
14+
_get_pad,
15+
_translate_boundary_strings,
16+
get_filter_tensors,
17+
)
1318

1419

1520
def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
@@ -153,7 +158,7 @@ def wavedec3(
153158

154159

155160
def waverec3(
156-
coeffs: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
161+
coeffs: Sequence[Union[torch.Tensor, Dict[str, torch.Tensor]]],
157162
wavelet: Union[Wavelet, str],
158163
) -> torch.Tensor:
159164
"""Reconstruct a signal from wavelet coefficients.
@@ -199,7 +204,8 @@ def waverec3(
199204
filt_len = rec_lo.shape[-1]
200205
rec_filt = _construct_3d_filt(lo=rec_lo, hi=rec_hi)
201206

202-
for c_pos, coeff_dict in enumerate(coeffs[1:]):
207+
coeff_dicts = cast(Sequence[Dict[str, torch.Tensor]], coeffs[1:])
208+
for c_pos, coeff_dict in enumerate(coeff_dicts):
203209
if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7:
204210
raise ValueError(
205211
f"Unexpected detail coefficient type: {type(coeff_dict)}. Detail "
@@ -238,31 +244,16 @@ def waverec3(
238244
padr = (2 * filt_len - 3) // 2
239245
padt = (2 * filt_len - 3) // 2
240246
padb = (2 * filt_len - 3) // 2
241-
if c_pos < len(coeffs) - 2:
242-
pred_len = res_lll.shape[-1] - (padl + padr)
243-
next_len = coeffs[c_pos + 2]["aad"].shape[-1] # type: ignore
244-
pred_len2 = res_lll.shape[-2] - (padt + padb)
245-
next_len2 = coeffs[c_pos + 2]["aad"].shape[-2] # type: ignore
246-
pred_len3 = res_lll.shape[-3] - (padfr + padba)
247-
next_len3 = coeffs[c_pos + 2]["aad"].shape[-3] # type: ignore
248-
if next_len != pred_len:
249-
padr += 1
250-
pred_len = res_lll.shape[-1] - (padl + padr)
251-
assert (
252-
next_len == pred_len
253-
), "padding error, please open an issue on github "
254-
if next_len2 != pred_len2:
255-
padb += 1
256-
pred_len2 = res_lll.shape[-2] - (padt + padb)
257-
assert (
258-
next_len2 == pred_len2
259-
), "padding error, please open an issue on github "
260-
if next_len3 != pred_len3:
261-
padba += 1
262-
pred_len3 = res_lll.shape[-3] - (padba + padfr)
263-
assert (
264-
next_len3 == pred_len3
265-
), "padding error, please open an issue on github "
247+
if c_pos + 1 < len(coeff_dicts):
248+
padr, padl = _adjust_padding_at_reconstruction(
249+
res_lll.shape[-1], coeff_dicts[c_pos + 1]["aad"].shape[-1], padr, padl
250+
)
251+
padb, padt = _adjust_padding_at_reconstruction(
252+
res_lll.shape[-2], coeff_dicts[c_pos + 1]["aad"].shape[-2], padb, padt
253+
)
254+
padba, padfr = _adjust_padding_at_reconstruction(
255+
res_lll.shape[-3], coeff_dicts[c_pos + 1]["aad"].shape[-3], padba, padfr
256+
)
266257
if padt > 0:
267258
res_lll = res_lll[..., padt:, :]
268259
if padb > 0:

src/ptwt/packets.py

Lines changed: 83 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import collections
55
from functools import partial
66
from itertools import product
7-
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
7+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union, cast
88

99
import pywt
1010
import torch
@@ -14,6 +14,7 @@
1414
from .conv_transform_2 import wavedec2, waverec2
1515
from .matmul_transform import MatrixWavedec, MatrixWaverec
1616
from .matmul_transform_2 import MatrixWavedec2, MatrixWaverec2
17+
from .separable_conv_transform import fswavedec2, fswaverec2
1718

1819
if TYPE_CHECKING:
1920
BaseDict = collections.UserDict[str, torch.Tensor]
@@ -129,6 +130,12 @@ def reconstruct(self) -> "WaveletPacket":
129130
data_a = self[node + "a"]
130131
data_b = self[node + "d"]
131132
rec = self._get_waverec(data_a.shape[-1])([data_a, data_b])
133+
if level > 0:
134+
if rec.shape[-1] != self[node].shape[-1]:
135+
assert (
136+
rec.shape[-1] == self[node].shape[-1] + 1
137+
), "padding error, please open an issue on github"
138+
rec = rec[..., :-1]
132139
self[node] = rec
133140
return self
134141

@@ -249,9 +256,8 @@ def __init__(
249256
to use in the sparse matrix backend. Only used if `mode`
250257
equals 'boundary'. Choose from 'qr' or 'gramschmidt'.
251258
Defaults to 'qr'.
252-
separable (bool): If true and the sparse matrix backend is selected,
253-
a separable transform is performed, i.e. each image axis is
254-
transformed separately. Defaults to False.
259+
separable (bool): If true, a separable transform is performed,
260+
i.e. each image axis is transformed separately. Defaults to False.
255261
maxlevel (int, optional): Value is passed on to `transform`.
256262
The highest decomposition level to compute. If None, the maximum level
257263
is determined from the input data shape. Defaults to None.
@@ -288,6 +294,10 @@ def transform(
288294
maxlevel = pywt.dwt_max_level(min(data.shape[-2:]), self.wavelet.dec_len)
289295
self.maxlevel = maxlevel
290296

297+
if data.dim() == 2:
298+
# add batch dim to unbatched input
299+
data = data.unsqueeze(0)
300+
291301
self._recursive_dwt2d(data, level=0, path="")
292302
return self
293303

@@ -306,24 +316,25 @@ def reconstruct(self) -> "WaveletPacket2D":
306316

307317
for level in reversed(range(self.maxlevel)):
308318
for node in self.get_natural_order(level):
309-
if self.mode == "boundary":
310-
data_a = self[node + "a"]
311-
data_h = self[node + "h"]
312-
data_v = self[node + "v"]
313-
data_d = self[node + "d"]
314-
rec = self._get_waverec(data_a.shape[-2:])(
315-
(data_a, (data_h, data_v, data_d))
316-
)
317-
self[node] = rec
318-
else:
319-
data_a = self[node + "a"]
320-
data_h = self[node + "h"]
321-
data_v = self[node + "v"]
322-
data_d = self[node + "d"]
323-
rec = self._get_waverec(data_a.shape[-2:])(
324-
(data_a, (data_h, data_v, data_d))
325-
)
326-
self[node] = rec.squeeze(1)
319+
data_a = self[node + "a"]
320+
data_h = self[node + "h"]
321+
data_v = self[node + "v"]
322+
data_d = self[node + "d"]
323+
rec = self._get_waverec(data_a.shape[-2:])(
324+
[data_a, (data_h, data_v, data_d)]
325+
)
326+
if level > 0:
327+
if rec.shape[-1] != self[node].shape[-1]:
328+
assert (
329+
rec.shape[-1] == self[node].shape[-1] + 1
330+
), "padding error, please open an issue on github"
331+
rec = rec[..., :-1]
332+
if rec.shape[-2] != self[node].shape[-2]:
333+
assert (
334+
rec.shape[-2] == self[node].shape[-2] + 1
335+
), "padding error, please open an issue on github"
336+
rec = rec[..., :-1, :]
337+
self[node] = rec
327338
return self
328339

329340
def get_natural_order(self, level: int) -> List[str]:
@@ -354,13 +365,17 @@ def _get_wavedec(
354365
)
355366
fun = self.matrix_wavedec2_dict[shape]
356367
return fun
368+
elif self.separable:
369+
return self._transform_fsdict_to_tuple_func(
370+
partial(fswavedec2, wavelet=self.wavelet, level=1, mode=self.mode)
371+
)
357372
else:
358373
return partial(wavedec2, wavelet=self.wavelet, level=1, mode=self.mode)
359374

360375
def _get_waverec(
361376
self, shape: Tuple[int, ...]
362377
) -> Callable[
363-
[Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
378+
[List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]],
364379
torch.Tensor,
365380
]:
366381
if self.mode == "boundary":
@@ -371,11 +386,54 @@ def _get_waverec(
371386
boundary=self.boundary,
372387
separable=self.separable,
373388
)
374-
fun = self.matrix_waverec2_dict[shape]
375-
return fun # type: ignore
389+
return self.matrix_waverec2_dict[shape]
390+
elif self.separable:
391+
return self._transform_tuple_to_fsdict_func(
392+
partial(fswaverec2, wavelet=self.wavelet)
393+
)
376394
else:
377395
return partial(waverec2, wavelet=self.wavelet)
378396

397+
def _transform_fsdict_to_tuple_func(
398+
self,
399+
fs_dict_func: Callable[
400+
[torch.Tensor], List[Union[torch.Tensor, Dict[str, torch.Tensor]]]
401+
],
402+
) -> Callable[
403+
[torch.Tensor],
404+
List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
405+
]:
406+
def _tuple_func(
407+
data: torch.Tensor,
408+
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
409+
a_coeff, fsdict = fs_dict_func(data)
410+
fsdict = cast(Dict[str, torch.Tensor], fsdict)
411+
return [
412+
cast(torch.Tensor, a_coeff),
413+
(fsdict["ad"], fsdict["da"], fsdict["dd"]),
414+
]
415+
416+
return _tuple_func
417+
418+
def _transform_tuple_to_fsdict_func(
419+
self,
420+
fsdict_func: Callable[
421+
[List[Union[torch.Tensor, Dict[str, torch.Tensor]]]], torch.Tensor
422+
],
423+
) -> Callable[
424+
[List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]],
425+
torch.Tensor,
426+
]:
427+
def _fsdict_func(
428+
coeffs: List[
429+
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
430+
]
431+
) -> torch.Tensor:
432+
a, (h, v, d) = coeffs
433+
return fsdict_func([cast(torch.Tensor, a), {"ad": h, "da": v, "dd": d}])
434+
435+
return _fsdict_func
436+
379437
def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
380438
if not self.maxlevel:
381439
raise AssertionError

tests/test_separable_conv_fwt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
from src.ptwt.matmul_transform_2 import MatrixWavedec2
1010
from src.ptwt.matmul_transform_3 import MatrixWavedec3
1111
from src.ptwt.separable_conv_transform import (
12+
_fswavedec,
13+
_fswaverec,
1214
_separable_conv_wavedecn,
1315
_separable_conv_waverecn,
14-
_fswavedec,
1516
fswavedec2,
1617
fswavedec3,
17-
_fswaverec,
1818
fswaverec2,
1919
fswaverec3,
2020
)

0 commit comments

Comments
 (0)