Skip to content

Commit f775b44

Browse files
committed
test batched inverse transforms, return tree.
1 parent 9aef3d5 commit f775b44

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ Welcome to the PyTorch wavelet toolbox. This package implements:
3737
- ``MatrixWavedec`` and ``MatrixWaverec`` provide sparse-matrix-based fast wavelet transforms with boundary filters,
3838
- 2d sparse-matrix transforms with separable & non-separable boundary filters are available (experimental),
3939
- ``cwt`` computes a one-dimensional continuous forward transform,
40-
- single and two-dimensional wavelet packet forward transforms are available via the ``WaveletPacket`` and ``WaveletPacket2D`` objects,
40+
- single and two-dimensional wavelet packet forward and backward transforms are available via the ``WaveletPacket`` and ``WaveletPacket2D`` objects,
4141
- finally, this package provides adaptive wavelet support (experimental).
4242

4343
This toolbox supports pywt-wavelets. Complete documentation is available:

src/ptwt/packets.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def transform(
101101
self._recursive_dwt(data, level=0, path="")
102102
return self
103103

104-
def reconstruct(self) -> None:
104+
def reconstruct(self) -> "WaveletPacket":
105105
"""Recursively reconstruct the input starting from the leaf nodes.
106106
107107
Reconstruction replaces the input-data originally assigned to this object.
@@ -130,6 +130,7 @@ def reconstruct(self) -> None:
130130
data_b = self[node + "d"]
131131
rec = self._get_waverec(data_a.shape[-1])([data_a, data_b])
132132
self[node] = rec
133+
return self
133134

134135
def _get_wavedec(
135136
self,
@@ -295,7 +296,7 @@ def transform(
295296
self._recursive_dwt2d(data, level=0, path="")
296297
return self
297298

298-
def reconstruct(self) -> None:
299+
def reconstruct(self) -> "WaveletPacket2D":
299300
"""Recursively reconstruct the input starting from the leaf nodes.
300301
301302
Note:
@@ -328,6 +329,7 @@ def reconstruct(self) -> None:
328329
(data_a, (data_h, data_v, data_d))
329330
)
330331
self[node] = rec.squeeze(1)
332+
return self
331333

332334
def get_natural_order(self, level: int) -> List[str]:
333335
"""Get the natural ordering for a given decomposition level.

tests/test_packets.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,18 @@ def _compare_trees1(
2626

2727
if transform_mode:
2828
twp = WaveletPacket(
29-
None, wavelet, mode=ptwt_boundary, max_level=max_lev
30-
).transform(torch.from_numpy(data), max_level=max_lev)
29+
None, wavelet, mode=ptwt_boundary, maxlevel=max_lev
30+
).transform(torch.from_numpy(data), maxlevel=max_lev)
3131
else:
3232
twp = WaveletPacket(
33-
torch.from_numpy(data), wavelet, mode=ptwt_boundary, max_level=max_lev
33+
torch.from_numpy(data), wavelet, mode=ptwt_boundary, maxlevel=max_lev
3434
)
3535

3636
# if multiple_transform flag is set, recalculcate the packets
3737
if multiple_transforms:
38-
twp.transform(torch.from_numpy(data), max_level=max_lev)
38+
twp.transform(torch.from_numpy(data), maxlevel=max_lev)
3939

40-
nodes = twp.get_level(twp.max_level)
40+
nodes = twp.get_level(twp.maxlevel)
4141
twp_lst = []
4242
for node in nodes:
4343
twp_lst.append(twp[node])
@@ -58,7 +58,7 @@ def _compare_trees1(
5858
np_res = np.concatenate(np_lst, -1)
5959
np_batches.append(np_res)
6060
np_batches = np.stack(np_batches, 0)
61-
assert wp.maxlevel == twp.max_level
61+
assert wp.maxlevel == twp.maxlevel
6262
assert np.allclose(torch_res, np_batches)
6363

6464

@@ -101,22 +101,22 @@ def _compare_trees2(
101101
if transform_mode:
102102
ptwt_wp_tree = WaveletPacket2D(
103103
None, wavelet=wavelet, mode=ptwt_boundary
104-
).transform(pt_data, max_level=max_lev)
104+
).transform(pt_data, maxlevel=max_lev)
105105
else:
106106
ptwt_wp_tree = WaveletPacket2D(
107-
pt_data, wavelet=wavelet, mode=ptwt_boundary, max_level=max_lev
107+
pt_data, wavelet=wavelet, mode=ptwt_boundary, maxlevel=max_lev
108108
)
109109

110110
# if multiple_transform flag is set, recalculcate the packets
111111
if multiple_transforms:
112-
ptwt_wp_tree.transform(pt_data, max_level=max_lev)
112+
ptwt_wp_tree.transform(pt_data, maxlevel=max_lev)
113113

114114
packets = []
115115
for node in wp_keys:
116116
packet = ptwt_wp_tree["".join(node)]
117117
packets.append(packet)
118118
packets_pt = torch.stack(packets, 1).numpy()
119-
assert wp_tree.maxlevel == ptwt_wp_tree.max_level
119+
assert wp_tree.maxlevel == ptwt_wp_tree.maxlevel
120120
assert np.allclose(packets_pt, batch_np_packets)
121121

122122

@@ -292,11 +292,12 @@ def test_access_errors_2d():
292292

293293
@pytest.mark.parametrize("level", [1, 2, 3])
294294
@pytest.mark.parametrize("base_key", ["a", "d"])
295-
@pytest.mark.parametrize("length", [64, 128])
295+
@pytest.mark.parametrize("length", [63, 64, 128])
296+
@pytest.mark.parametrize("batch_size", [1, 2])
296297
@pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"])
297-
def test_inverse_packet_1d(level, base_key, length, wavelet):
298+
def test_inverse_packet_1d(level, base_key, length, batch_size, wavelet):
298299
"""Test the 1d reconstruction code."""
299-
signal = np.random.randn(1, length)
300+
signal = np.random.randn(batch_size, length)
300301
mode = "reflect"
301302
wp = pywt.WaveletPacket(signal, wavelet, mode=mode, maxlevel=level)
302303
ptwp = WaveletPacket(torch.from_numpy(signal), wavelet, mode=mode, maxlevel=level)
@@ -309,22 +310,22 @@ def test_inverse_packet_1d(level, base_key, length, wavelet):
309310

310311
@pytest.mark.parametrize("level", [1, 3])
311312
@pytest.mark.parametrize("base_key", ["a", "h", "d"])
312-
@pytest.mark.parametrize("size", [(32, 32), (32, 64)])
313+
@pytest.mark.parametrize("size", [(1, 32, 32), (2, 31, 64)])
313314
@pytest.mark.parametrize("wavelet", ["db1", "db2", "sym4"])
314315
def test_inverse_packet_2d(level, base_key, size, wavelet):
315316
"""Test the 2d reconstruction code."""
316-
signal = np.random.randn(1, size[0], size[1])
317+
signal = np.random.randn(size[0], size[1], size[2])
317318
mode = "reflect"
318319
wp = pywt.WaveletPacket2D(signal, wavelet, mode=mode, maxlevel=level)
319320
ptwp = WaveletPacket2D(torch.from_numpy(signal), wavelet, mode=mode, maxlevel=level)
320321
wp[base_key * level].data *= 0
321322
ptwp[base_key * level].data *= 0
322323
wp.reconstruct(update=True)
323324
ptwp.reconstruct()
324-
assert np.allclose(wp[""].data, ptwp[""].numpy()[:, : size[0], : size[1]])
325+
assert np.allclose(wp[""].data, ptwp[""].numpy()[:, : size[1], : size[2]])
325326

326327

327-
def test_boundary_packet_1d():
328+
def test_inverse_boundary_packet_1d():
328329
"""Test the 2d boundary reconstruction code."""
329330
signal = np.random.randn(1, 16)
330331
wp = pywt.WaveletPacket(signal, "haar", mode="zero", maxlevel=2)
@@ -336,7 +337,7 @@ def test_boundary_packet_1d():
336337
assert np.allclose(wp[""].data, ptwp[""].numpy()[:, :16])
337338

338339

339-
def test_boundary_packet_2d():
340+
def test_inverse_boundary_packet_2d():
340341
"""Test the 2d boundary reconstruction code."""
341342
size = (16, 16)
342343
level = 2

0 commit comments

Comments
 (0)