@@ -122,7 +122,7 @@ def __init__(
122122
123123 def transform (
124124 self , data : torch .Tensor , maxlevel : Optional [int ] = None
125- ) -> " WaveletPacket" :
125+ ) -> WaveletPacket :
126126 """Calculate the 1d wavelet packet transform for the input data.
127127
128128 Args:
@@ -139,7 +139,7 @@ def transform(
139139 self ._recursive_dwt (data , level = 0 , path = "" )
140140 return self
141141
142- def reconstruct (self ) -> " WaveletPacket" :
142+ def reconstruct (self ) -> WaveletPacket :
143143 """Recursively reconstruct the input starting from the leaf nodes.
144144
145145 Reconstruction replaces the input data originally assigned to this object.
@@ -326,7 +326,7 @@ def __init__(
326326
327327 def transform (
328328 self , data : torch .Tensor , maxlevel : Optional [int ] = None
329- ) -> " WaveletPacket2D" :
329+ ) -> WaveletPacket2D :
330330 """Calculate the 2d wavelet packet transform for the input data.
331331
332332 The transform function allows reusing the same object.
@@ -350,7 +350,7 @@ def transform(
350350 self ._recursive_dwt2d (data , level = 0 , path = "" )
351351 return self
352352
353- def reconstruct (self ) -> " WaveletPacket2D" :
353+ def reconstruct (self ) -> WaveletPacket2D :
354354 """Recursively reconstruct the input starting from the leaf nodes.
355355
356356 Note:
@@ -364,7 +364,7 @@ def reconstruct(self) -> "WaveletPacket2D":
364364 )
365365
366366 for level in reversed (range (self .maxlevel )):
367- for node in self .get_natural_order (level ):
367+ for node in WaveletPacket2D .get_natural_order (level ):
368368 data_a = self [node + "a" ]
369369 data_h = self [node + "h" ]
370370 data_v = self [node + "v" ]
@@ -386,17 +386,6 @@ def reconstruct(self) -> "WaveletPacket2D":
386386 self [node ] = rec
387387 return self
388388
389- def get_natural_order (self , level : int ) -> list [str ]:
390- """Get the natural ordering for a given decomposition level.
391-
392- Args:
393- level (int): The decomposition level.
394-
395- Returns:
396- A list with the filter order strings.
397- """
398- return ["" .join (p ) for p in product (["a" , "h" , "v" , "d" ], repeat = level )]
399-
400389 def _get_wavedec (self , shape : tuple [int , ...]) -> Callable [
401390 [torch .Tensor ],
402391 WaveletCoeff2d ,
@@ -525,54 +514,68 @@ def __getitem__(self, key: str) -> torch.Tensor:
525514 )
526515 return super ().__getitem__ (key )
527516
517+ @staticmethod
518+ def get_natural_order (level : int ) -> list [str ]:
519+ """Get the natural ordering for a given decomposition level.
528520
529- def get_freq_order ( level : int ) -> list [ list [ tuple [ str , ...]]] :
530- """Get the frequency order for a given packet decomposition level.
521+ Args :
522+ level (int): The decomposition level.
531523
532- Use this code to create two-dimensional frequency orderings.
524+ Returns:
525+ A list with the filter order strings.
526+ """
527+ return ["" .join (p ) for p in product (["a" , "h" , "v" , "d" ], repeat = level )]
533528
534- Args:
535- level (int): The number of decomposition scales.
529+ @staticmethod
530+ def get_freq_order (level : int ) -> list [list [str ]]:
531+ """Get the frequency order for a given packet decomposition level.
536532
537- Returns:
538- A list with the tree nodes in frequency order.
539-
540- Note:
541- Adapted from:
542- https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py
543-
544- The code elements denote the filter application order. The filters
545- are named following the pywt convention as:
546- a - LL, low-low coefficients
547- h - LH, low-high coefficients
548- v - HL, high-low coefficients
549- d - HH, high-high coefficients
550- """
551- wp_natural_path = product (["a" , "h" , "v" , "d" ], repeat = level )
533+ Use this code to create two-dimensional frequency orderings.
552534
553- def _get_graycode_order (level : int , x : str = "a" , y : str = "d" ) -> list [str ]:
554- graycode_order = [x , y ]
555- for _ in range (level - 1 ):
556- graycode_order = [x + path for path in graycode_order ] + [
557- y + path for path in graycode_order [::- 1 ]
558- ]
559- return graycode_order
560-
561- def _expand_2d_path (path : tuple [str , ...]) -> tuple [str , str ]:
562- expanded_paths = {"d" : "hh" , "h" : "hl" , "v" : "lh" , "a" : "ll" }
563- return (
564- "" .join ([expanded_paths [p ][0 ] for p in path ]),
565- "" .join ([expanded_paths [p ][1 ] for p in path ]),
566- )
567-
568- nodes_dict : dict [str , dict [str , tuple [str , ...]]] = {}
569- for (row_path , col_path ), node in [
570- (_expand_2d_path (node ), node ) for node in wp_natural_path
571- ]:
572- nodes_dict .setdefault (row_path , {})[col_path ] = node
573- graycode_order = _get_graycode_order (level , x = "l" , y = "h" )
574- nodes = [nodes_dict [path ] for path in graycode_order if path in nodes_dict ]
575- result = []
576- for row in nodes :
577- result .append ([row [path ] for path in graycode_order if path in row ])
578- return result
535+ Args:
536+ level (int): The number of decomposition scales.
537+
538+ Returns:
539+ A list with the tree nodes in frequency order.
540+
541+ Note:
542+ Adapted from:
543+ https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py
544+
545+ The code elements denote the filter application order. The filters
546+ are named following the pywt convention as:
547+ a - LL, low-low coefficients
548+ h - LH, low-high coefficients
549+ v - HL, high-low coefficients
550+ d - HH, high-high coefficients
551+ """
552+ wp_natural_path = product (["a" , "h" , "v" , "d" ], repeat = level )
553+
554+ def _get_graycode_order (level : int , x : str = "a" , y : str = "d" ) -> list [str ]:
555+ graycode_order = [x , y ]
556+ for _ in range (level - 1 ):
557+ graycode_order = [x + path for path in graycode_order ] + [
558+ y + path for path in graycode_order [::- 1 ]
559+ ]
560+ return graycode_order
561+
562+ def _expand_2d_path (path : tuple [str , ...]) -> tuple [str , str ]:
563+ expanded_paths = {"d" : "hh" , "h" : "hl" , "v" : "lh" , "a" : "ll" }
564+ return (
565+ "" .join ([expanded_paths [p ][0 ] for p in path ]),
566+ "" .join ([expanded_paths [p ][1 ] for p in path ]),
567+ )
568+
569+ nodes_dict : dict [str , dict [str , tuple [str , ...]]] = {}
570+ for (row_path , col_path ), node in [
571+ (_expand_2d_path (node ), node ) for node in wp_natural_path
572+ ]:
573+ nodes_dict .setdefault (row_path , {})[col_path ] = node
574+ graycode_order = _get_graycode_order (level , x = "l" , y = "h" )
575+ nodes = [nodes_dict [path ] for path in graycode_order if path in nodes_dict ]
576+ result = []
577+ for row in nodes :
578+ result .append (
579+ ["" .join (row [path ]) for path in graycode_order if path in row ]
580+ )
581+ return result
0 commit comments