44import collections
55from functools import partial
66from itertools import product
7- from typing import TYPE_CHECKING , Callable , Dict , List , Optional , Tuple , Union
7+ from typing import TYPE_CHECKING , Callable , Dict , List , Optional , Tuple , Union , Any
88
99import pywt
1010import torch
@@ -84,7 +84,7 @@ def transform(
8484 self ._recursive_dwt (data , level = 0 , path = "" )
8585 return self
8686
87- def reconstruct (self ):
87+ def reconstruct (self ) -> None :
8888 """Recursively reconstruct the input starting from the leaf nodes.
8989
9090 Reconstruction replaces the input-data originally assigned to this object.
@@ -104,11 +104,14 @@ def reconstruct(self):
104104 >>> ptwp.reconstruct()
105105 >>> print(ptwp[""])
106106 """
107+ if self .maxlevel is None :
108+ self .maxlevel = pywt .dwt_maxlevel (self ["" ].shape [- 1 ], self .wavelet .dec_len )
109+
107110 for level in reversed (range (self .maxlevel )):
108111 for node in self .get_level (level ):
109112 data_a = self [node + "a" ]
110113 data_b = self [node + "d" ]
111- rec = self ._get_waverec (data_a .shape [- 1 ])(( data_a , data_b ) )
114+ rec = self ._get_waverec (data_a .shape [- 1 ])([ data_a , data_b ] )
112115 self [node ] = rec
113116
114117 def _get_wavedec (
@@ -127,7 +130,7 @@ def _get_wavedec(
127130 def _get_waverec (
128131 self ,
129132 length : int ,
130- ) -> Callable [[torch .Tensor ], List [ torch .Tensor ] ]:
133+ ) -> Callable [[List [ torch .Tensor ]], torch .Tensor ]:
131134 if self .mode == "boundary" :
132135 if length not in self ._matrix_waverec_dict .keys ():
133136 self ._matrix_waverec_dict [length ] = MatrixWaverec (
@@ -271,14 +274,19 @@ def transform(
271274 self ._recursive_dwt2d (data , level = 0 , path = "" )
272275 return self
273276
274- def reconstruct (self ):
277+ def reconstruct (self ) -> None :
275278 """Recursively reconstruct the input starting from the leaf nodes.
276279
277280 Note:
278281 Only changes to leaf node data impacts the results,
279282 since changes in all other nodes will be replaced with
280283 a reconstruction from the leafs.
281284 """
285+ if self .maxlevel is None :
286+ self .maxlevel = pywt .dwt_maxlevel (min (
287+ self ["" ].shape [- 2 :]),
288+ self .wavelet .dec_len )
289+
282290 for level in reversed (range (self .maxlevel )):
283291 for node in self .get_natural_order (level ):
284292 if self .mode == "boundary" :
@@ -300,7 +308,7 @@ def reconstruct(self):
300308 )
301309 self [node ] = rec .squeeze (1 )
302310
303- def get_natural_order (self , level : int ) -> list :
311+ def get_natural_order (self , level : int ) -> List [ str ] :
304312 """Get the natural ordering for a given decomposition level.
305313
306314 Args:
@@ -333,10 +341,8 @@ def _get_wavedec(
333341
334342 def _get_waverec (
335343 self , shape : Tuple [int , ...]
336- ) -> Callable [
337- [torch .Tensor ],
338- List [Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]]],
339- ]:
344+ ) -> Callable [[Any ], # TODO: Get the acutal type working.
345+ torch .Tensor ]:
340346 if self .mode == "boundary" :
341347 shape = tuple (shape )
342348 if shape not in self .matrix_waverec2_dict .keys ():
0 commit comments