1010import torch
1111
1212from ._util import Wavelet , _as_wavelet
13- from .conv_transform import wavedec
14- from .conv_transform_2 import wavedec2
15- from .matmul_transform import MatrixWavedec
16- from .matmul_transform_2 import MatrixWavedec2
13+ from .conv_transform import wavedec , waverec
14+ from .conv_transform_2 import wavedec2 , waverec2
15+ from .matmul_transform import MatrixWavedec , MatrixWaverec
16+ from .matmul_transform_2 import MatrixWavedec2 , MatrixWaverec2
1717
1818if TYPE_CHECKING :
1919 BaseDict = collections .UserDict [str , torch .Tensor ]
@@ -30,7 +30,7 @@ def __init__(
3030 wavelet : Union [Wavelet , str ],
3131 mode : str = "reflect" ,
3232 boundary_orthogonalization : str = "qr" ,
33- max_level : Optional [int ] = None ,
33+ maxlevel : Optional [int ] = None ,
3434 ) -> None :
3535 """Create a wavelet packet decomposition object.
3636
@@ -47,42 +47,91 @@ def __init__(
4747 boundary_orthogonalization (str): The orthogonalization method
4848 to use. Only used if `mode` equals 'boundary'. Choose from
4949 'qr' or 'gramschmidt'. Defaults to 'qr'.
50- max_level (int, optional): Value is passed on to `transform`.
50+ maxlevel (int, optional): Value is passed on to `transform`.
5151 The highest decomposition level to compute. If None, the maximum level
5252 is determined from the input data shape. Defaults to None.
53+
54+ Example:
55+ >>> import torch, pywt, ptwt
56+ >>> import numpy as np
57+ >>> import scipy.signal
58+ >>> import matplotlib.pyplot as plt
59+ >>> t = np.linspace(0, 10, 1500)
60+ >>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
61+ >>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
62+ wavelet=pywt.Wavelet("db3"), mode="reflect")
63+ >>> np_lst = []
64+ >>> for node in wp.get_level(5):
65+ >>> np_lst.append(wp[node])
66+ >>> viz = np.stack(np_lst).squeeze()
67+ >>> plt.imshow(np.abs(viz))
68+ >>> plt.show()
69+
5370 """
5471 self .wavelet = _as_wavelet (wavelet )
5572 self .mode = mode
5673 self .boundary = boundary_orthogonalization
5774 self ._matrix_wavedec_dict : Dict [int , MatrixWavedec ] = {}
58- self .max_level : Optional [int ] = None
75+ self ._matrix_waverec_dict : Dict [int , MatrixWaverec ] = {}
76+ self .maxlevel : Optional [int ] = None
5977 if data is not None :
6078 if len (data .shape ) == 1 :
6179 # add a batch dimension.
6280 data = data .unsqueeze (0 )
63- self .transform (data , max_level )
81+ self .transform (data , maxlevel )
6482 else :
6583 self .data = {}
6684
6785 def transform (
68- self , data : torch .Tensor , max_level : Optional [int ] = None
86+ self , data : torch .Tensor , maxlevel : Optional [int ] = None
6987 ) -> "WaveletPacket" :
7088 """Calculate the 1d wavelet packet transform for the input data.
7189
7290 Args:
7391 data (torch.Tensor): The input data array of shape [time]
7492 or [batch_size, time].
75- max_level (int, optional): The highest decomposition level to compute.
93+ maxlevel (int, optional): The highest decomposition level to compute.
7694 If None, the maximum level is determined from the input data shape.
7795 Defaults to None.
7896 """
7997 self .data = {}
80- if max_level is None :
81- max_level = pywt .dwt_max_level (data .shape [- 1 ], self .wavelet .dec_len )
82- self .max_level = max_level
98+ if maxlevel is None :
99+ maxlevel = pywt .dwt_max_level (data .shape [- 1 ], self .wavelet .dec_len )
100+ self .maxlevel = maxlevel
83101 self ._recursive_dwt (data , level = 0 , path = "" )
84102 return self
85103
104+ def reconstruct (self ) -> "WaveletPacket" :
105+ """Recursively reconstruct the input starting from the leaf nodes.
106+
107+ Reconstruction replaces the input-data originally assigned to this object.
108+
109+ Note:
110+ Only changes to leaf node data impacts the results,
111+ since changes in all other nodes will be replaced with
112+ a reconstruction from the leafs.
113+
114+ Example:
115+ >>> import numpy as np
116+ >>> import ptwt, torch
117+ >>> signal = np.random.randn(1, 16)
118+ >>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar",
119+ mode="boundary", maxlevel=2)
120+ >>> ptwp["aa"].data *= 0
121+ >>> ptwp.reconstruct()
122+ >>> print(ptwp[""])
123+ """
124+ if self .maxlevel is None :
125+ self .maxlevel = pywt .dwt_max_level (self ["" ].shape [- 1 ], self .wavelet .dec_len )
126+
127+ for level in reversed (range (self .maxlevel )):
128+ for node in self .get_level (level ):
129+ data_a = self [node + "a" ]
130+ data_b = self [node + "d" ]
131+ rec = self ._get_waverec (data_a .shape [- 1 ])([data_a , data_b ])
132+ self [node ] = rec
133+ return self
134+
86135 def _get_wavedec (
87136 self ,
88137 length : int ,
@@ -96,6 +145,19 @@ def _get_wavedec(
96145 else :
97146 return partial (wavedec , wavelet = self .wavelet , level = 1 , mode = self .mode )
98147
148+ def _get_waverec (
149+ self ,
150+ length : int ,
151+ ) -> Callable [[List [torch .Tensor ]], torch .Tensor ]:
152+ if self .mode == "boundary" :
153+ if length not in self ._matrix_waverec_dict .keys ():
154+ self ._matrix_waverec_dict [length ] = MatrixWaverec (
155+ self .wavelet , boundary = self .boundary
156+ )
157+ return self ._matrix_waverec_dict [length ]
158+ else :
159+ return partial (waverec , wavelet = self .wavelet )
160+
99161 def get_level (self , level : int ) -> List [str ]:
100162 """Return the graycode ordered paths to the filter tree nodes.
101163
@@ -113,10 +175,13 @@ def _get_graycode_order(self, level: int, x: str = "a", y: str = "d") -> List[st
113175 graycode_order = [x + path for path in graycode_order ] + [
114176 y + path for path in graycode_order [::- 1 ]
115177 ]
116- return graycode_order
178+ if level == 0 :
179+ return ["" ]
180+ else :
181+ return graycode_order
117182
118183 def _recursive_dwt (self , data : torch .Tensor , level : int , path : str ) -> None :
119- if not self .max_level :
184+ if not self .maxlevel :
120185 raise AssertionError
121186
122187 # TODO: This is a workaround since the convolutional transforms insert a
@@ -125,7 +190,7 @@ def _recursive_dwt(self, data: torch.Tensor, level: int, path: str) -> None:
125190 data = data .squeeze (1 )
126191
127192 self .data [path ] = data
128- if level < self .max_level :
193+ if level < self .maxlevel :
129194 res_lo , res_hi = self ._get_wavedec (data .shape [- 1 ])(data )
130195 self ._recursive_dwt (res_lo , level + 1 , path + "a" )
131196 self ._recursive_dwt (res_hi , level + 1 , path + "d" )
@@ -144,22 +209,26 @@ def __getitem__(self, key: str) -> torch.Tensor:
144209 ValueError: If the wavelet packet tree is not initialized.
145210 KeyError: If no wavelet coefficients are indexed by the specified key.
146211 """
147- if self .max_level is None :
212+ if self .maxlevel is None :
148213 raise ValueError (
149214 "The wavelet packet tree must be initialized via 'transform' before "
150215 "its values can be accessed!"
151216 )
152- if key not in self and len (key ) > self .max_level :
217+ if key not in self and len (key ) > self .maxlevel :
153218 raise KeyError (
154219 f"The requested level { len (key )} with key '{ key } ' is too large and "
155220 "cannot be accessed! This wavelet packet tree is initialized with "
156- f"maximum level { self .max_level } ."
221+ f"maximum level { self .maxlevel } ."
157222 )
158223 return super ().__getitem__ (key )
159224
160225
161226class WaveletPacket2D (BaseDict ):
162- """Two dimensional wavelet packets."""
227+ """Two dimensional wavelet packets.
228+
229+ Example code illustrating the use of this class is available at:
230+ https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/main/examples/deepfake_analysis
231+ """
163232
164233 def __init__ (
165234 self ,
@@ -168,7 +237,7 @@ def __init__(
168237 mode : str = "reflect" ,
169238 boundary_orthogonalization : str = "qr" ,
170239 separable : bool = False ,
171- max_level : Optional [int ] = None ,
240+ maxlevel : Optional [int ] = None ,
172241 ) -> None :
173242 """Create a 2D-Wavelet packet tree.
174243
@@ -188,7 +257,7 @@ def __init__(
188257 separable (bool): If true and the sparse matrix backend is selected,
189258 a separable transform is performed, i.e. each image axis is
190259 transformed separately. Defaults to False.
191- max_level (int, optional): Value is passed on to `transform`.
260+ maxlevel (int, optional): Value is passed on to `transform`.
192261 The highest decomposition level to compute. If None, the maximum level
193262 is determined from the input data shape. Defaults to None.
194263 """
@@ -197,15 +266,16 @@ def __init__(
197266 self .boundary = boundary_orthogonalization
198267 self .separable = separable
199268 self .matrix_wavedec2_dict : Dict [Tuple [int , ...], MatrixWavedec2 ] = {}
269+ self .matrix_waverec2_dict : Dict [Tuple [int , ...], MatrixWaverec2 ] = {}
200270
201- self .max_level : Optional [int ] = None
271+ self .maxlevel : Optional [int ] = None
202272 if data is not None :
203- self .transform (data , max_level )
273+ self .transform (data , maxlevel )
204274 else :
205275 self .data = {}
206276
207277 def transform (
208- self , data : torch .Tensor , max_level : Optional [int ] = None
278+ self , data : torch .Tensor , maxlevel : Optional [int ] = None
209279 ) -> "WaveletPacket2D" :
210280 """Calculate the 2d wavelet packet transform for the input data.
211281
@@ -214,18 +284,64 @@ def transform(
214284 Args:
215285 data (torch.tensor): The input data tensor
216286 of shape [batch_size, height, width]
217- max_level (int, optional): The highest decomposition level to compute.
287+ maxlevel (int, optional): The highest decomposition level to compute.
218288 If None, the maximum level is determined from the input data shape.
219289 Defaults to None.
220290 """
221291 self .data = {}
222- if max_level is None :
223- max_level = pywt .dwt_max_level (min (data .shape [- 2 :]), self .wavelet .dec_len )
224- self .max_level = max_level
292+ if maxlevel is None :
293+ maxlevel = pywt .dwt_max_level (min (data .shape [- 2 :]), self .wavelet .dec_len )
294+ self .maxlevel = maxlevel
225295
226296 self ._recursive_dwt2d (data , level = 0 , path = "" )
227297 return self
228298
299+ def reconstruct (self ) -> "WaveletPacket2D" :
300+ """Recursively reconstruct the input starting from the leaf nodes.
301+
302+ Note:
303+ Only changes to leaf node data impacts the results,
304+ since changes in all other nodes will be replaced with
305+ a reconstruction from the leafs.
306+ """
307+ if self .maxlevel is None :
308+ self .maxlevel = pywt .dwt_max_level (
309+ min (self ["" ].shape [- 2 :]), self .wavelet .dec_len
310+ )
311+
312+ for level in reversed (range (self .maxlevel )):
313+ for node in self .get_natural_order (level ):
314+ if self .mode == "boundary" :
315+ data_a = self [node + "a" ]
316+ data_h = self [node + "h" ]
317+ data_v = self [node + "v" ]
318+ data_d = self [node + "d" ]
319+ rec = self ._get_waverec (data_a .shape [- 2 :])(
320+ (data_a , (data_h , data_v , data_d ))
321+ )
322+ self [node ] = rec
323+ else :
324+ data_a = self [node + "a" ].unsqueeze (1 )
325+ data_h = self [node + "h" ].unsqueeze (1 )
326+ data_v = self [node + "v" ].unsqueeze (1 )
327+ data_d = self [node + "d" ].unsqueeze (1 )
328+ rec = self ._get_waverec (data_a .shape [- 2 :])(
329+ (data_a , (data_h , data_v , data_d ))
330+ )
331+ self [node ] = rec .squeeze (1 )
332+ return self
333+
334+ def get_natural_order (self , level : int ) -> List [str ]:
335+ """Get the natural ordering for a given decomposition level.
336+
337+ Args:
338+ level (int): The decomposition level.
339+
340+ Returns:
341+ list: A list with the filter order strings.
342+ """
343+ return ["" .join (p ) for p in list (product (["a" , "h" , "v" , "d" ], repeat = level ))]
344+
229345 def _get_wavedec (
230346 self , shape : Tuple [int , ...]
231347 ) -> Callable [
@@ -246,8 +362,27 @@ def _get_wavedec(
246362 else :
247363 return partial (wavedec2 , wavelet = self .wavelet , level = 1 , mode = self .mode )
248364
365+ def _get_waverec (
366+ self , shape : Tuple [int , ...]
367+ ) -> Callable [
368+ [Tuple [torch .Tensor , Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]]],
369+ torch .Tensor ,
370+ ]:
371+ if self .mode == "boundary" :
372+ shape = tuple (shape )
373+ if shape not in self .matrix_waverec2_dict .keys ():
374+ self .matrix_waverec2_dict [shape ] = MatrixWaverec2 (
375+ self .wavelet ,
376+ boundary = self .boundary ,
377+ separable = self .separable ,
378+ )
379+ fun = self .matrix_waverec2_dict [shape ]
380+ return fun # type: ignore
381+ else :
382+ return partial (waverec2 , wavelet = self .wavelet )
383+
249384 def _recursive_dwt2d (self , data : torch .Tensor , level : int , path : str ) -> None :
250- if not self .max_level :
385+ if not self .maxlevel :
251386 raise AssertionError
252387
253388 # TODO: This is a workaround since the convolutional transforms insert a
@@ -256,7 +391,7 @@ def _recursive_dwt2d(self, data: torch.Tensor, level: int, path: str) -> None:
256391 data = data .squeeze (1 )
257392
258393 self .data [path ] = data
259- if level < self .max_level :
394+ if level < self .maxlevel :
260395 result_a , (result_h , result_v , result_d ) = self ._get_wavedec (
261396 data .shape [- 2 :]
262397 )(data )
@@ -281,16 +416,16 @@ def __getitem__(self, key: str) -> torch.Tensor:
281416 ValueError: If the wavelet packet tree is not initialized.
282417 KeyError: If no wavelet coefficients are indexed by the specified key.
283418 """
284- if self .max_level is None :
419+ if self .maxlevel is None :
285420 raise ValueError (
286421 "The wavelet packet tree must be initialized via 'transform' before "
287422 "its values can be accessed!"
288423 )
289- if key not in self and len (key ) > self .max_level :
424+ if key not in self and len (key ) > self .maxlevel :
290425 raise KeyError (
291426 f"The requested level { len (key )} with key '{ key } ' is too large and "
292427 "cannot be accessed! This wavelet packet tree is initialized with "
293- f"maximum level { self .max_level } ."
428+ f"maximum level { self .maxlevel } ."
294429 )
295430 return super ().__getitem__ (key )
296431
0 commit comments