Skip to content

Commit 3eb33a2

Browse files
committed
type update.
1 parent f775b44 commit 3eb33a2

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/ptwt/packets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import collections
55
from functools import partial
66
from itertools import product
7-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
7+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
88

99
import pywt
1010
import torch
@@ -364,7 +364,7 @@ def _get_wavedec(
364364

365365
def _get_waverec(
366366
self, shape: Tuple[int, ...]
367-
) -> Callable[[Any], torch.Tensor]: # TODO: Get the acutal type working.
367+
) -> Callable[[Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]], torch.Tensor]:
368368
if self.mode == "boundary":
369369
shape = tuple(shape)
370370
if shape not in self.matrix_waverec2_dict.keys():
@@ -374,7 +374,7 @@ def _get_waverec(
374374
separable=self.separable,
375375
)
376376
fun = self.matrix_waverec2_dict[shape]
377-
return fun
377+
return fun # type: ignore
378378
else:
379379
return partial(waverec2, wavelet=self.wavelet)
380380

0 commit comments

Comments
 (0)