Skip to content

Commit c256d9e

Browse files
authored
Merge pull request #87 from v0lta/improve-typing
Improve typing and docstrings
2 parents 78abc5f + c973350 commit c256d9e

36 files changed

+900
-882
lines changed

.flake8

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ ignore =
2323
# asserts are ok in test.
2424
S101
2525
C901
26+
extend-select = B950
27+
extend-ignore = E501,E701,E704
2628
exclude =
2729
.tox,
2830
.git,
@@ -37,7 +39,7 @@ exclude =
3739
.eggs,
3840
data.
3941
src/ptwt/__init__.py
40-
max-line-length = 90
42+
max-line-length = 80
4143
max-complexity = 20
4244
import-order-style = pycharm
4345
application-import-names =

docs/conf.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@
7070
html_favicon = "_static/favicon.ico"
7171
html_logo = "_static/shannon.png"
7272

73-
html_favicon = "favicon/favicon.ico"
74-
7573
# Add any paths that contain custom static files (such as style sheets) here,
7674
# relative to this directory. They are copied after the builtin static files,
7775
# so a file named "default.css" will overwrite the builtin "default.css".
@@ -82,3 +80,10 @@
8280

8381
# numbered figures
8482
numfig = True
83+
84+
autodoc_type_aliases = {
85+
"WaveletCoeff2d": "ptwt.constants.WaveletCoeff2d",
86+
"WaveletCoeff2dSeparable": "ptwt.constants.WaveletCoeff2dSeparable",
87+
"WaveletCoeffNd": "ptwt.constants.WaveletCoeffNd",
88+
"BaseMatrixWaveDec": "ptwt.matmul_transform.BaseMatrixWaveDec",
89+
}

docs/ptwt.rst

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ ptwt.packets module
3232

3333
.. automodule:: ptwt.packets
3434
:members:
35+
:special-members: __getitem__
3536
:undoc-members:
3637
:show-inheritance:
3738

@@ -68,6 +69,7 @@ ptwt.matmul\_transform module
6869

6970
.. automodule:: ptwt.matmul_transform
7071
:members:
72+
:special-members: __call__
7173
:undoc-members:
7274
:show-inheritance:
7375

@@ -76,6 +78,7 @@ ptwt.matmul\_transform\_2 module
7678

7779
.. automodule:: ptwt.matmul_transform_2
7880
:members:
81+
:special-members: __call__
7982
:undoc-members:
8083
:show-inheritance:
8184

@@ -84,6 +87,7 @@ ptwt.matmul\_transform\_3 module
8487

8588
.. automodule:: ptwt.matmul_transform_3
8689
:members:
90+
:special-members: __call__
8791
:undoc-members:
8892
:show-inheritance:
8993

@@ -96,14 +100,6 @@ ptwt.sparse\_math module
96100
:undoc-members:
97101
:show-inheritance:
98102

99-
ptwt.version module
100-
-------------------
101-
102-
.. automodule:: ptwt.version
103-
:members:
104-
:undoc-members:
105-
:show-inheritance:
106-
107103
ptwt.wavelets\_learnable module
108104
-------------------------------
109105

@@ -118,3 +114,11 @@ ptwt.constants
118114
:members:
119115
:undoc-members:
120116
:show-inheritance:
117+
118+
ptwt.version module
119+
-------------------
120+
121+
.. automodule:: ptwt.version
122+
:members:
123+
:undoc-members:
124+
:show-inheritance:

examples/deepfake_analysis/packet_plot.py

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,6 @@
1111
import ptwt
1212

1313

14-
def get_freq_order(level: int):
15-
"""Get the frequency order for a given packet decomposition level.
16-
Adapted from:
17-
https://github.com/PyWavelets/pywt/blob/master/pywt/_wavelet_packets.py
18-
The code elements denote the filter application order. The filters
19-
are named following the pywt convention as:
20-
a - LL, low-low coefficients
21-
h - LH, low-high coefficients
22-
v - HL, high-low coefficients
23-
d - HH, high-high coefficients
24-
"""
25-
wp_natural_path = list(product(["a", "h", "v", "d"], repeat=level))
26-
27-
def _get_graycode_order(level, x="a", y="d"):
28-
graycode_order = [x, y]
29-
for _ in range(level - 1):
30-
graycode_order = [x + path for path in graycode_order] + [
31-
y + path for path in graycode_order[::-1]
32-
]
33-
return graycode_order
34-
35-
def _expand_2d_path(path):
36-
expanded_paths = {"d": "hh", "h": "hl", "v": "lh", "a": "ll"}
37-
return (
38-
"".join([expanded_paths[p][0] for p in path]),
39-
"".join([expanded_paths[p][1] for p in path]),
40-
)
41-
42-
nodes: dict = {}
43-
for (row_path, col_path), node in [
44-
(_expand_2d_path(node), node) for node in wp_natural_path
45-
]:
46-
nodes.setdefault(row_path, {})[col_path] = node
47-
graycode_order = _get_graycode_order(level, x="l", y="h")
48-
nodes_list: list = [nodes[path] for path in graycode_order if path in nodes]
49-
wp_frequency_path = []
50-
for row in nodes_list:
51-
wp_frequency_path.append([row[path] for path in graycode_order if path in row])
52-
return wp_frequency_path, wp_natural_path
53-
54-
5514
def generate_frequency_packet_image(packet_array: np.ndarray, degree: int):
5615
"""Create a ready-to-polt image with frequency-order packages.
5716
Given a packet array in natural order, creat an image which is
@@ -63,7 +22,8 @@ def generate_frequency_packet_image(packet_array: np.ndarray, degree: int):
6322
Returns:
6423
[np.ndarray]: The image of shape [original_height, original_width]
6524
"""
66-
wp_freq_path, wp_natural_path = get_freq_order(degree)
25+
wp_freq_path = ptwt.WaveletPacket2D.get_freq_order(degree)
26+
wp_natural_path = ptwt.WaveletPacket2D.get_natural_order(degree)
6727

6828
image = []
6929
# go through the rows.
@@ -107,7 +67,8 @@ def load_images(path: str) -> list:
10767

10868

10969
if __name__ == "__main__":
110-
frequency_path, natural_path = get_freq_order(level=3)
70+
freq_path = ptwt.WaveletPacket2D.get_freq_order(level=3)
71+
frequency_path = ptwt.WaveletPacket2D.get_natural_order(level=3)
11172
print("Loading ffhq images:")
11273
ffhq_images = load_images("./ffhq_style_gan/source_data/A_ffhq")
11374
print("processing ffhq")

examples/speed_tests/timeitconv_1d.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,8 @@
99
import ptwt
1010

1111

12-
class WaveletTuple(NamedTuple):
13-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
14-
15-
dec_lo: torch.Tensor
16-
dec_hi: torch.Tensor
17-
rec_lo: torch.Tensor
18-
rec_hi: torch.Tensor
19-
20-
21-
def _set_up_wavelet_tuple(wavelet, dtype):
22-
return WaveletTuple(
23-
torch.tensor(wavelet.dec_lo).type(dtype),
24-
torch.tensor(wavelet.dec_hi).type(dtype),
25-
torch.tensor(wavelet.rec_lo).type(dtype),
26-
torch.tensor(wavelet.rec_hi).type(dtype),
27-
)
28-
29-
3012
def _jit_wavedec_fun(data, wavelet):
31-
return ptwt.wavedec(data, wavelet, "periodic", level=10)
13+
return ptwt.wavedec(data, wavelet, mode="periodic", level=10)
3214

3315

3416
if __name__ == "__main__":
@@ -56,7 +38,7 @@ def _jit_wavedec_fun(data, wavelet):
5638
end = time.perf_counter()
5739
ptwt_time_cpu.append(end - start)
5840

59-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
41+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
6042
jit_wavedec = torch.jit.trace(
6143
_jit_wavedec_fun,
6244
(data, wavelet),
@@ -81,7 +63,7 @@ def _jit_wavedec_fun(data, wavelet):
8163
end = time.perf_counter()
8264
ptwt_time_gpu.append(end - start)
8365

84-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
66+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
8567
jit_wavedec = torch.jit.trace(
8668
_jit_wavedec_fun,
8769
(data.cuda(), wavelet),

examples/speed_tests/timeitconv_2d.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,9 @@
99
import ptwt
1010

1111

12-
class WaveletTuple(NamedTuple):
13-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
14-
15-
dec_lo: torch.Tensor
16-
dec_hi: torch.Tensor
17-
rec_lo: torch.Tensor
18-
rec_hi: torch.Tensor
19-
20-
21-
def _set_up_wavelet_tuple(wavelet, dtype):
22-
return WaveletTuple(
23-
torch.tensor(wavelet.dec_lo).type(dtype),
24-
torch.tensor(wavelet.dec_hi).type(dtype),
25-
torch.tensor(wavelet.rec_lo).type(dtype),
26-
torch.tensor(wavelet.rec_hi).type(dtype),
27-
)
28-
29-
30-
def _to_jit_wavedec_2(data, wavelet):
12+
def _to_jit_wavedec_2(data: torch.Tensor, wavelet) -> list[torch.Tensor]:
3113
"""Ensure uniform datatypes in lists for the tracer.
32-
Going from List[Union[torch.Tensor, List[torch.Tensor]]] to List[torch.Tensor]
14+
Going from list[Union[torch.Tensor, list[torch.Tensor]]] to list[torch.Tensor]
3315
means we have to stack the lists in the output.
3416
"""
3517
assert data.shape == (32, 1e3, 1e3), "Changing the chape requires re-tracing."
@@ -79,7 +61,7 @@ def _to_jit_wavedec_2(data, wavelet):
7961

8062
ptwt_time_gpu.append(end - start)
8163

82-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
64+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
8365
jit_wavedec = torch.jit.trace(
8466
_to_jit_wavedec_2,
8567
(data.cuda(), wavelet),

examples/speed_tests/timeitconv_2d_separable.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,13 @@
1010
import ptwt
1111

1212

13-
class WaveletTuple(NamedTuple):
14-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
15-
16-
dec_lo: torch.Tensor
17-
dec_hi: torch.Tensor
18-
rec_lo: torch.Tensor
19-
rec_hi: torch.Tensor
20-
21-
22-
def _set_up_wavelet_tuple(wavelet, dtype):
23-
return WaveletTuple(
24-
torch.tensor(wavelet.dec_lo).type(dtype),
25-
torch.tensor(wavelet.dec_hi).type(dtype),
26-
torch.tensor(wavelet.rec_lo).type(dtype),
27-
torch.tensor(wavelet.rec_hi).type(dtype),
28-
)
29-
30-
3113
def _to_jit_wavedec_2(data, wavelet):
3214
"""Ensure uniform datatypes in lists for the tracer.
3315
Going from List[Union[torch.Tensor, List[torch.Tensor]]] to List[torch.Tensor]
3416
means we have to stack the lists in the output.
3517
"""
3618
assert data.shape == (32, 1e3, 1e3), "Changing the chape requires re-tracing."
37-
coeff = ptwt.fswavedec2(data, wavelet, "reflect", level=5)
19+
coeff = ptwt.fswavedec2(data, wavelet, mode="reflect", level=5)
3820
coeff2 = []
3921
for c in coeff:
4022
if isinstance(c, torch.Tensor):
@@ -103,7 +85,7 @@ def _to_jit_wavedec_2(data, wavelet):
10385
end = time.perf_counter()
10486
ptwt_time_gpu.append(end - start)
10587

106-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
88+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
10789
jit_wavedec = torch.jit.trace(
10890
_to_jit_wavedec_2,
10991
(data.cuda(), wavelet),

examples/speed_tests/timeitconv_3d.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,6 @@
99
import ptwt
1010

1111

12-
class WaveletTuple(NamedTuple):
13-
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
14-
15-
dec_lo: torch.Tensor
16-
dec_hi: torch.Tensor
17-
rec_lo: torch.Tensor
18-
rec_hi: torch.Tensor
19-
20-
21-
def _set_up_wavelet_tuple(wavelet, dtype):
22-
return WaveletTuple(
23-
torch.tensor(wavelet.dec_lo).type(dtype),
24-
torch.tensor(wavelet.dec_hi).type(dtype),
25-
torch.tensor(wavelet.rec_lo).type(dtype),
26-
torch.tensor(wavelet.rec_hi).type(dtype),
27-
)
28-
29-
3012
def _to_jit_wavedec_3(data, wavelet):
3113
"""Ensure uniform datatypes in lists for the tracer.
3214
@@ -85,7 +67,7 @@ def _to_jit_wavedec_3(data, wavelet):
8567
end = time.perf_counter()
8668
ptwt_time_gpu.append(end - start)
8769

88-
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
70+
wavelet = ptwt.WaveletTensorTuple.from_wavelet(pywt.Wavelet("db5"), torch.float32)
8971
jit_wavedec = torch.jit.trace(
9072
_to_jit_wavedec_3,
9173
(data.cuda(), wavelet),

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ tests =
6767
# pooch is an optional scipy dependency for getting datasets
6868
pooch
6969
typing =
70-
mypy
70+
mypy @ git+https://github.com/python/mypy
7171
# needed otherwise pytest decorators don't get typed properly
7272
pytest
7373
examples =

src/ptwt/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Differentiable and gpu enabled fast wavelet transforms in PyTorch."""
22

3-
from ._util import Wavelet
3+
from ._util import Wavelet, WaveletTensorTuple
4+
from .constants import WaveletCoeff2d, WaveletCoeff2dSeparable, WaveletCoeffNd
45
from .continuous_transform import cwt
56
from .conv_transform import wavedec, waverec
67
from .conv_transform_2 import wavedec2, waverec2

0 commit comments

Comments
 (0)