Skip to content

Commit 5730752

Browse files
committed
typing
1 parent 78b8389 commit 5730752

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

src/ptwt/matmul_transform_2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ def __call__(
770770
)
771771

772772
batch_size = coefficients[-1][0].shape[0]
773-
ll = coefficients[0]
773+
ll: torch.Tensor = coefficients[0] # type: ignore
774774
if not isinstance(ll, torch.Tensor):
775775
raise ValueError(
776776
"First element of coeffs must be the approximation coefficient tensor."

src/ptwt/packets.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import collections
55
from functools import partial
66
from itertools import product
7-
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
7+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union, Any
88

99
import pywt
1010
import torch
@@ -84,7 +84,7 @@ def transform(
8484
self._recursive_dwt(data, level=0, path="")
8585
return self
8686

87-
def reconstruct(self):
87+
def reconstruct(self) -> None:
8888
"""Recursively reconstruct the input starting from the leaf nodes.
8989
9090
Reconstruction replaces the input-data originally assigned to this object.
@@ -104,11 +104,14 @@ def reconstruct(self):
104104
>>> ptwp.reconstruct()
105105
>>> print(ptwp[""])
106106
"""
107+
if self.maxlevel is None:
108+
self.maxlevel = pywt.dwt_maxlevel(self[""].shape[-1], self.wavelet.dec_len)
109+
107110
for level in reversed(range(self.maxlevel)):
108111
for node in self.get_level(level):
109112
data_a = self[node + "a"]
110113
data_b = self[node + "d"]
111-
rec = self._get_waverec(data_a.shape[-1])((data_a, data_b))
114+
rec = self._get_waverec(data_a.shape[-1])([data_a, data_b])
112115
self[node] = rec
113116

114117
def _get_wavedec(
@@ -127,7 +130,7 @@ def _get_wavedec(
127130
def _get_waverec(
128131
self,
129132
length: int,
130-
) -> Callable[[torch.Tensor], List[torch.Tensor]]:
133+
) -> Callable[[List[torch.Tensor]], torch.Tensor]:
131134
if self.mode == "boundary":
132135
if length not in self._matrix_waverec_dict.keys():
133136
self._matrix_waverec_dict[length] = MatrixWaverec(
@@ -271,14 +274,19 @@ def transform(
271274
self._recursive_dwt2d(data, level=0, path="")
272275
return self
273276

274-
def reconstruct(self):
277+
def reconstruct(self) -> None:
275278
"""Recursively reconstruct the input starting from the leaf nodes.
276279
277280
Note:
278281
Only changes to leaf node data impacts the results,
279282
since changes in all other nodes will be replaced with
280283
a reconstruction from the leafs.
281284
"""
285+
if self.maxlevel is None:
286+
self.maxlevel = pywt.dwt_maxlevel(min(
287+
self[""].shape[-2:]),
288+
self.wavelet.dec_len)
289+
282290
for level in reversed(range(self.maxlevel)):
283291
for node in self.get_natural_order(level):
284292
if self.mode == "boundary":
@@ -300,7 +308,7 @@ def reconstruct(self):
300308
)
301309
self[node] = rec.squeeze(1)
302310

303-
def get_natural_order(self, level: int) -> list:
311+
def get_natural_order(self, level: int) -> List[str]:
304312
"""Get the natural ordering for a given decomposition level.
305313
306314
Args:
@@ -333,10 +341,8 @@ def _get_wavedec(
333341

334342
def _get_waverec(
335343
self, shape: Tuple[int, ...]
336-
) -> Callable[
337-
[torch.Tensor],
338-
List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
339-
]:
344+
) -> Callable[[Any], # TODO: Get the acutal type working.
345+
torch.Tensor]:
340346
if self.mode == "boundary":
341347
shape = tuple(shape)
342348
if shape not in self.matrix_waverec2_dict.keys():

0 commit comments

Comments
 (0)