Skip to content

Commit 942ad60

Browse files
committed
Improve typing in JIT code
1 parent 4282c9d commit 942ad60

File tree

7 files changed

+52
-106
lines changed

7 files changed

+52
-106
lines changed

examples/speed_tests/timeitconv_1d.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,8 @@
99
import ptwt
1010

1111

12-
class WaveletTuple(NamedTuple):
13-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
14-
15-
dec_lo: torch.Tensor
16-
dec_hi: torch.Tensor
17-
rec_lo: torch.Tensor
18-
rec_hi: torch.Tensor
19-
20-
21-
def _set_up_wavelet_tuple(wavelet, dtype):
22-
return WaveletTuple(
23-
torch.tensor(wavelet.dec_lo).type(dtype),
24-
torch.tensor(wavelet.dec_hi).type(dtype),
25-
torch.tensor(wavelet.rec_lo).type(dtype),
26-
torch.tensor(wavelet.rec_hi).type(dtype),
27-
)
28-
29-
3012
def _jit_wavedec_fun(data, wavelet):
31-
return ptwt.wavedec(data, wavelet, "periodic", level=10)
13+
return ptwt.wavedec(data, wavelet, mode="periodic", level=10)
3214

3315

3416
if __name__ == "__main__":
@@ -56,7 +38,7 @@ def _jit_wavedec_fun(data, wavelet):
5638
end = time.perf_counter()
5739
ptwt_time_cpu.append(end - start)
5840

59-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
41+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
6042
jit_wavedec = torch.jit.trace(
6143
_jit_wavedec_fun,
6244
(data, wavelet),
@@ -81,7 +63,7 @@ def _jit_wavedec_fun(data, wavelet):
8163
end = time.perf_counter()
8264
ptwt_time_gpu.append(end - start)
8365

84-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
66+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
8567
jit_wavedec = torch.jit.trace(
8668
_jit_wavedec_fun,
8769
(data.cuda(), wavelet),

examples/speed_tests/timeitconv_2d.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,9 @@
99
import ptwt
1010

1111

12-
class WaveletTuple(NamedTuple):
13-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
14-
15-
dec_lo: torch.Tensor
16-
dec_hi: torch.Tensor
17-
rec_lo: torch.Tensor
18-
rec_hi: torch.Tensor
19-
20-
21-
def _set_up_wavelet_tuple(wavelet, dtype):
22-
return WaveletTuple(
23-
torch.tensor(wavelet.dec_lo).type(dtype),
24-
torch.tensor(wavelet.dec_hi).type(dtype),
25-
torch.tensor(wavelet.rec_lo).type(dtype),
26-
torch.tensor(wavelet.rec_hi).type(dtype),
27-
)
28-
29-
30-
def _to_jit_wavedec_2(data, wavelet):
12+
def _to_jit_wavedec_2(data: torch.Tensor, wavelet) -> list[torch.Tensor]:
3113
"""Ensure uniform datatypes in lists for the tracer.
32-
Going from List[Union[torch.Tensor, List[torch.Tensor]]] to List[torch.Tensor]
14+
Going from list[Union[torch.Tensor, list[torch.Tensor]]] to list[torch.Tensor]
3315
means we have to stack the lists in the output.
3416
"""
3517
assert data.shape == (32, 1e3, 1e3), "Changing the chape requires re-tracing."
@@ -79,7 +61,7 @@ def _to_jit_wavedec_2(data, wavelet):
7961

8062
ptwt_time_gpu.append(end - start)
8163

82-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
64+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
8365
jit_wavedec = torch.jit.trace(
8466
_to_jit_wavedec_2,
8567
(data.cuda(), wavelet),

examples/speed_tests/timeitconv_2d_separable.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,13 @@
1010
import ptwt
1111

1212

13-
class WaveletTuple(NamedTuple):
14-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
15-
16-
dec_lo: torch.Tensor
17-
dec_hi: torch.Tensor
18-
rec_lo: torch.Tensor
19-
rec_hi: torch.Tensor
20-
21-
22-
def _set_up_wavelet_tuple(wavelet, dtype):
23-
return WaveletTuple(
24-
torch.tensor(wavelet.dec_lo).type(dtype),
25-
torch.tensor(wavelet.dec_hi).type(dtype),
26-
torch.tensor(wavelet.rec_lo).type(dtype),
27-
torch.tensor(wavelet.rec_hi).type(dtype),
28-
)
29-
30-
3113
def _to_jit_wavedec_2(data, wavelet):
3214
"""Ensure uniform datatypes in lists for the tracer.
3315
Going from List[Union[torch.Tensor, List[torch.Tensor]]] to List[torch.Tensor]
3416
means we have to stack the lists in the output.
3517
"""
3618
assert data.shape == (32, 1e3, 1e3), "Changing the chape requires re-tracing."
37-
coeff = ptwt.fswavedec2(data, wavelet, "reflect", level=5)
19+
coeff = ptwt.fswavedec2(data, wavelet, mode="reflect", level=5)
3820
coeff2 = []
3921
for c in coeff:
4022
if isinstance(c, torch.Tensor):
@@ -103,7 +85,7 @@ def _to_jit_wavedec_2(data, wavelet):
10385
end = time.perf_counter()
10486
ptwt_time_gpu.append(end - start)
10587

106-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
88+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
10789
jit_wavedec = torch.jit.trace(
10890
_to_jit_wavedec_2,
10991
(data.cuda(), wavelet),

examples/speed_tests/timeitconv_3d.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,6 @@
99
import ptwt
1010

1111

12-
class WaveletTuple(NamedTuple):
13-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
14-
15-
dec_lo: torch.Tensor
16-
dec_hi: torch.Tensor
17-
rec_lo: torch.Tensor
18-
rec_hi: torch.Tensor
19-
20-
21-
def _set_up_wavelet_tuple(wavelet, dtype):
22-
return WaveletTuple(
23-
torch.tensor(wavelet.dec_lo).type(dtype),
24-
torch.tensor(wavelet.dec_hi).type(dtype),
25-
torch.tensor(wavelet.rec_lo).type(dtype),
26-
torch.tensor(wavelet.rec_hi).type(dtype),
27-
)
28-
29-
3012
def _to_jit_wavedec_3(data, wavelet):
3113
"""Ensure uniform datatypes in lists for the tracer.
3214
@@ -85,7 +67,7 @@ def _to_jit_wavedec_3(data, wavelet):
8567
end = time.perf_counter()
8668
ptwt_time_gpu.append(end - start)
8769

88-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
70+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
8971
jit_wavedec = torch.jit.trace(
9072
_to_jit_wavedec_3,
9173
(data.cuda(), wavelet),

src/ptwt/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Differentiable and gpu enabled fast wavelet transforms in PyTorch."""
22

3-
from ._util import Wavelet
4-
from .constants import WaveletCoeff2d, WaveletCoeffNd, WaveletCoeff2dSeparable
3+
from ._util import Wavelet, WaveletTensorTuple
4+
from .constants import WaveletCoeff2d, WaveletCoeff2dSeparable, WaveletCoeffNd
55
from .continuous_transform import cwt
66
from .conv_transform import wavedec, waverec
77
from .conv_transform_2 import wavedec2, waverec2

src/ptwt/_util.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import typing
66
from collections.abc import Sequence
7-
from typing import Any, Callable, Optional, Protocol, Union, cast, overload
7+
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union, cast, overload
88

99
import numpy as np
1010
import pywt
@@ -38,6 +38,42 @@ def __len__(self) -> int:
3838
return len(self.dec_lo)
3939

4040

41+
class WaveletTensorTuple(NamedTuple):
42+
"""Named tuple containing the wavelet filter bank to use in JIT code."""
43+
44+
dec_lo: torch.Tensor
45+
dec_hi: torch.Tensor
46+
rec_lo: torch.Tensor
47+
rec_hi: torch.Tensor
48+
49+
@property
50+
def dec_len(self) -> int:
51+
"""Length of decomposition filters."""
52+
return len(self.dec_lo)
53+
54+
@property
55+
def rec_len(self) -> int:
56+
"""Length of reconstruction filters."""
57+
return len(self.rec_lo)
58+
59+
@property
60+
def filter_bank(
61+
self,
62+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
63+
"""Filter bank of the wavelet."""
64+
return self
65+
66+
@classmethod
67+
def from_wavelet(cls, wavelet: Wavelet, dtype: torch.dtype) -> WaveletTensorTuple:
68+
"""Construct Wavelet named tuple from wavelet protocol member."""
69+
return cls(
70+
torch.tensor(wavelet.dec_lo, dtype=dtype),
71+
torch.tensor(wavelet.dec_hi, dtype=dtype),
72+
torch.tensor(wavelet.rec_lo, dtype=dtype),
73+
torch.tensor(wavelet.rec_hi, dtype=dtype),
74+
)
75+
76+
4177
def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
4278
"""Ensure the input argument to be a pywt wavelet compatible object.
4379

tests/test_jit.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Ensure pytorch's torch.jit.trace feature works properly."""
22

3-
from typing import NamedTuple, Optional, Union
3+
from typing import Optional, Union
44

55
import numpy as np
66
import pytest
@@ -13,24 +13,6 @@
1313
from tests._mackey_glass import MackeyGenerator
1414

1515

16-
class WaveletTuple(NamedTuple):
17-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
18-
19-
dec_lo: torch.Tensor
20-
dec_hi: torch.Tensor
21-
rec_lo: torch.Tensor
22-
rec_hi: torch.Tensor
23-
24-
25-
def _set_up_wavelet_tuple(wavelet: WaveletTuple, dtype: torch.dtype) -> WaveletTuple:
26-
return WaveletTuple(
27-
torch.tensor(wavelet.dec_lo).type(dtype),
28-
torch.tensor(wavelet.dec_hi).type(dtype),
29-
torch.tensor(wavelet.rec_lo).type(dtype),
30-
torch.tensor(wavelet.rec_hi).type(dtype),
31-
)
32-
33-
3416
def _to_jit_wavedec_fun(
3517
data: torch.Tensor, wavelet: Union[ptwt.Wavelet, str], level: Optional[int]
3618
) -> list[torch.Tensor]:
@@ -53,7 +35,7 @@ def test_conv_fwt_jit(
5335

5436
mackey_data_1 = torch.squeeze(generator(), -1).type(dtype)
5537
wavelet = pywt.Wavelet(wavelet_string)
56-
wavelet = _set_up_wavelet_tuple(wavelet, dtype)
38+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(wavelet, dtype)
5739

5840
with pytest.warns(Warning):
5941
jit_wavedec = torch.jit.trace( # type: ignore
@@ -105,7 +87,7 @@ def test_conv_fwt_jit_2d() -> None:
10587
rec = _to_jit_waverec_2(coeff, wavelet)
10688
assert np.allclose(rec.squeeze(1).numpy(), data.numpy())
10789

108-
wavelet = _set_up_wavelet_tuple(wavelet, dtype=torch.float64)
90+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(wavelet, dtype=torch.float64)
10991
with pytest.warns(Warning):
11092
jit_wavedec2 = torch.jit.trace( # type: ignore
11193
_to_jit_wavedec_2,
@@ -159,7 +141,7 @@ def test_conv_fwt_jit_3d() -> None:
159141
rec = _to_jit_waverec_3(coeff, wavelet)
160142
assert np.allclose(rec.squeeze(1).numpy(), data.numpy())
161143

162-
wavelet = _set_up_wavelet_tuple(wavelet, dtype=torch.float64)
144+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(wavelet, dtype=torch.float64)
163145
with pytest.warns(Warning):
164146
jit_wavedec3 = torch.jit.trace( # type: ignore
165147
_to_jit_wavedec_3,

0 commit comments

Comments
 (0)