Skip to content

Commit 0ac20c5

Browse files
committed
Refactor of order methods for WaveletPacket2d
Make the functions creating the natural and frequency packet order for WaveletPacket2d static methods of WaveletPacket2d. This changes * `get_natural_order` from instance to static function * `get_frequency_order` from separate func to static function in the scope of WaveletPacket2d Further, to make both methods consistent, `get_frequency_order` now returns concatenated strings instead of a tuple of single char strings.
1 parent b8a510a commit 0ac20c5

File tree

3 files changed

+71
-107
lines changed

3 files changed

+71
-107
lines changed

examples/deepfake_analysis/packet_plot.py

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,6 @@
1111
import ptwt
1212

1313

14-
def get_freq_order(level: int):
15-
"""Get the frequency order for a given packet decomposition level.
16-
Adapted from:
17-
https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py
18-
The code elements denote the filter application order. The filters
19-
are named following the pywt convention as:
20-
a - LL, low-low coefficients
21-
h - LH, low-high coefficients
22-
v - HL, high-low coefficients
23-
d - HH, high-high coefficients
24-
"""
25-
wp_natural_path = list(product(["a", "h", "v", "d"], repeat=level))
26-
27-
def _get_graycode_order(level, x="a", y="d"):
28-
graycode_order = [x, y]
29-
for _ in range(level - 1):
30-
graycode_order = [x + path for path in graycode_order] + [
31-
y + path for path in graycode_order[::-1]
32-
]
33-
return graycode_order
34-
35-
def _expand_2d_path(path):
36-
expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"}
37-
return (
38-
"".join([expanded_paths[p][0] for p in path]),
39-
"".join([expanded_paths[p][1] for p in path]),
40-
)
41-
42-
nodes: dict = {}
43-
for (row_path, col_path), node in [
44-
(_expand_2d_path(node), node) for node in wp_natural_path
45-
]:
46-
nodes.setdefault(row_path, {})[col_path] = node
47-
graycode_order = _get_graycode_order(level, x="l", y="h")
48-
nodes_list: list = [nodes[path] for path in graycode_order if path in nodes]
49-
wp_frequency_path = []
50-
for row in nodes_list:
51-
wp_frequency_path.append([row[path] for path in graycode_order if path in row])
52-
return wp_frequency_path, wp_natural_path
53-
54-
5514
def generate_frequency_packet_image(packet_array: np.ndarray, degree: int):
5615
"""Create a ready-to-polt image with frequency-order packages.
5716
Given a packet array in natural order, creat an image which is
@@ -63,7 +22,8 @@ def generate_frequency_packet_image(packet_array: np.ndarray, degree: int):
6322
Returns:
6423
[np.ndarray]: The image of shape [original_height, original_width]
6524
"""
66-
wp_freq_path, wp_natural_path = get_freq_order(degree)
25+
wp_freq_path = ptwt.WaveletPacket2D.get_freq_order(degree)
26+
wp_natural_path = ptwt.WaveletPacket2D.get_natural_order(degree)
6727

6828
image = []
6929
# go through the rows.
@@ -107,7 +67,8 @@ def load_images(path: str) -> list:
10767

10868

10969
if __name__ == "__main__":
110-
frequency_path, natural_path = get_freq_order(level=3)
70+
freq_path = ptwt.WaveletPacket2D.get_freq_order(level=3)
71+
frequency_path = ptwt.WaveletPacket2D.get_natural_order(level=3)
11172
print("Loading ffhq images:")
11273
ffhq_images = load_images("./ffhq_style_gan/source_data/A_ffhq")
11374
print("processing ffhq")

src/ptwt/packets.py

Lines changed: 65 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(
122122

123123
def transform(
124124
self, data: torch.Tensor, maxlevel: Optional[int] = None
125-
) -> "WaveletPacket":
125+
) -> WaveletPacket:
126126
"""Calculate the 1d wavelet packet transform for the input data.
127127
128128
Args:
@@ -139,7 +139,7 @@ def transform(
139139
self._recursive_dwt(data, level=0, path="")
140140
return self
141141

142-
def reconstruct(self) -> "WaveletPacket":
142+
def reconstruct(self) -> WaveletPacket:
143143
"""Recursively reconstruct the input starting from the leaf nodes.
144144
145145
Reconstruction replaces the input data originally assigned to this object.
@@ -326,7 +326,7 @@ def __init__(
326326

327327
def transform(
328328
self, data: torch.Tensor, maxlevel: Optional[int] = None
329-
) -> "WaveletPacket2D":
329+
) -> WaveletPacket2D:
330330
"""Calculate the 2d wavelet packet transform for the input data.
331331
332332
The transform function allows reusing the same object.
@@ -350,7 +350,7 @@ def transform(
350350
self._recursive_dwt2d(data, level=0, path="")
351351
return self
352352

353-
def reconstruct(self) -> "WaveletPacket2D":
353+
def reconstruct(self) -> WaveletPacket2D:
354354
"""Recursively reconstruct the input starting from the leaf nodes.
355355
356356
Note:
@@ -364,7 +364,7 @@ def reconstruct(self) -> "WaveletPacket2D":
364364
)
365365

366366
for level in reversed(range(self.maxlevel)):
367-
for node in self.get_natural_order(level):
367+
for node in WaveletPacket2D.get_natural_order(level):
368368
data_a = self[node + "a"]
369369
data_h = self[node + "h"]
370370
data_v = self[node + "v"]
@@ -386,17 +386,6 @@ def reconstruct(self) -> "WaveletPacket2D":
386386
self[node] = rec
387387
return self
388388

389-
def get_natural_order(self, level: int) -> list[str]:
390-
"""Get the natural ordering for a given decomposition level.
391-
392-
Args:
393-
level (int): The decomposition level.
394-
395-
Returns:
396-
A list with the filter order strings.
397-
"""
398-
return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)]
399-
400389
def _get_wavedec(self, shape: tuple[int, ...]) -> Callable[
401390
[torch.Tensor],
402391
WaveletCoeff2d,
@@ -525,54 +514,68 @@ def __getitem__(self, key: str) -> torch.Tensor:
525514
)
526515
return super().__getitem__(key)
527516

517+
@staticmethod
518+
def get_natural_order(level: int) -> list[str]:
519+
"""Get the natural ordering for a given decomposition level.
528520
529-
def get_freq_order(level: int) -> list[list[tuple[str, ...]]]:
530-
"""Get the frequency order for a given packet decomposition level.
521+
Args:
522+
level (int): The decomposition level.
531523
532-
Use this code to create two-dimensional frequency orderings.
524+
Returns:
525+
A list with the filter order strings.
526+
"""
527+
return ["".join(p) for p in product(["a", "h", "v", "d"], repeat=level)]
533528

534-
Args:
535-
level (int): The number of decomposition scales.
529+
@staticmethod
530+
def get_freq_order(level: int) -> list[list[str]]:
531+
"""Get the frequency order for a given packet decomposition level.
536532
537-
Returns:
538-
A list with the tree nodes in frequency order.
539-
540-
Note:
541-
Adapted from:
542-
https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py
543-
544-
The code elements denote the filter application order. The filters
545-
are named following the pywt convention as:
546-
a - LL, low-low coefficients
547-
h - LH, low-high coefficients
548-
v - HL, high-low coefficients
549-
d - HH, high-high coefficients
550-
"""
551-
wp_natural_path = product(["a", "h", "v", "d"], repeat=level)
533+
Use this code to create two-dimensional frequency orderings.
552534
553-
def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
554-
graycode_order = [x, y]
555-
for _ in range(level - 1):
556-
graycode_order = [x + path for path in graycode_order] + [
557-
y + path for path in graycode_order[::-1]
558-
]
559-
return graycode_order
560-
561-
def _expand_2d_path(path: tuple[str, ...]) -> tuple[str, str]:
562-
expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"}
563-
return (
564-
"".join([expanded_paths[p][0] for p in path]),
565-
"".join([expanded_paths[p][1] for p in path]),
566-
)
567-
568-
nodes_dict: dict[str, dict[str, tuple[str, ...]]] = {}
569-
for (row_path, col_path), node in [
570-
(_expand_2d_path(node), node) for node in wp_natural_path
571-
]:
572-
nodes_dict.setdefault(row_path, {})[col_path] = node
573-
graycode_order = _get_graycode_order(level, x="l", y="h")
574-
nodes = [nodes_dict[path] for path in graycode_order if path in nodes_dict]
575-
result = []
576-
for row in nodes:
577-
result.append([row[path] for path in graycode_order if path in row])
578-
return result
535+
Args:
536+
level (int): The number of decomposition scales.
537+
538+
Returns:
539+
A list with the tree nodes in frequency order.
540+
541+
Note:
542+
Adapted from:
543+
https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py
544+
545+
The code elements denote the filter application order. The filters
546+
are named following the pywt convention as:
547+
a - LL, low-low coefficients
548+
h - LH, low-high coefficients
549+
v - HL, high-low coefficients
550+
d - HH, high-high coefficients
551+
"""
552+
wp_natural_path = product(["a", "h", "v", "d"], repeat=level)
553+
554+
def _get_graycode_order(level: int, x: str = "a", y: str = "d") -> list[str]:
555+
graycode_order = [x, y]
556+
for _ in range(level - 1):
557+
graycode_order = [x + path for path in graycode_order] + [
558+
y + path for path in graycode_order[::-1]
559+
]
560+
return graycode_order
561+
562+
def _expand_2d_path(path: tuple[str, ...]) -> tuple[str, str]:
563+
expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"}
564+
return (
565+
"".join([expanded_paths[p][0] for p in path]),
566+
"".join([expanded_paths[p][1] for p in path]),
567+
)
568+
569+
nodes_dict: dict[str, dict[str, tuple[str, ...]]] = {}
570+
for (row_path, col_path), node in [
571+
(_expand_2d_path(node), node) for node in wp_natural_path
572+
]:
573+
nodes_dict.setdefault(row_path, {})[col_path] = node
574+
graycode_order = _get_graycode_order(level, x="l", y="h")
575+
nodes = [nodes_dict[path] for path in graycode_order if path in nodes_dict]
576+
result = []
577+
for row in nodes:
578+
result.append(
579+
["".join(row[path]) for path in graycode_order if path in row]
580+
)
581+
return result

tests/test_packets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from scipy import datasets
1313

1414
from ptwt.constants import ExtendedBoundaryMode
15-
from ptwt.packets import WaveletPacket, WaveletPacket2D, get_freq_order
15+
from ptwt.packets import WaveletPacket, WaveletPacket2D
1616

1717

1818
def _compare_trees1(
@@ -236,7 +236,7 @@ def test_freq_order(level: int, wavelet_str: str, pywt_boundary: str) -> None:
236236
)
237237
# Get the full decomposition
238238
freq_tree = wp_tree.get_level(level, "freq")
239-
freq_order = get_freq_order(level)
239+
freq_order = WaveletPacket2D.get_freq_order(level)
240240

241241
for order_list, tree_list in zip(freq_tree, freq_order):
242242
for order_el, tree_el in zip(order_list, tree_list):

0 commit comments

Comments
 (0)