Skip to content

Commit f390042

Browse files
authored
Use typing.Literal for boundary mode, padding mode, and orthogonalization mode (#77)
1 parent 6565e60 commit f390042

18 files changed

+227
-164
lines changed

docs/ptwt.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,9 @@ ptwt.wavelets\_learnable module
102102
:undoc-members:
103103
:show-inheritance:
104104

105-
105+
ptwt.constants
106+
-------------------------------
107+
.. automodule:: ptwt.constants
108+
:members:
109+
:undoc-members:
110+
:show-inheritance:

examples/speed_tests/timeitconv_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _to_jit_wavedec_2(data, wavelet):
3333
means we have to stack the lists in the output.
3434
"""
3535
assert data.shape == (32, 1e3, 1e3), "Changing the chape requires re-tracing."
36-
coeff = ptwt.wavedec2(data, wavelet, "periodic", level=5)
36+
coeff = ptwt.wavedec2(data, wavelet, mode="periodic", level=5)
3737
coeff2 = []
3838
for c in coeff:
3939
if isinstance(c, torch.Tensor):

src/ptwt/_util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Utility methods to compute wavelet decompositions from a dataset."""
22

3+
import typing
34
from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union
45

56
import numpy as np
67
import pywt
78
import torch
89

10+
from ptwt.constants import OrthogonalizeMethod
11+
912

1013
class Wavelet(Protocol):
1114
"""Wavelet object interface, based on the pywt wavelet object."""
@@ -43,8 +46,8 @@ def _as_wavelet(wavelet: Union[Wavelet, str]) -> Wavelet:
4346
return wavelet
4447

4548

46-
def _is_boundary_mode_supported(boundary_mode: Optional[str]) -> bool:
47-
return boundary_mode in ["qr", "gramschmidt"]
49+
def _is_boundary_mode_supported(boundary_mode: Optional[OrthogonalizeMethod]) -> bool:
50+
return boundary_mode in typing.get_args(OrthogonalizeMethod)
4851

4952

5053
def _is_dtype_supported(dtype: torch.dtype) -> bool:

src/ptwt/constants.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""Constants and types used throughout the PyTorch Wavelet Toolbox."""
2+
3+
from typing import Literal, Union
4+
5+
__all__ = [
6+
"BoundaryMode",
7+
"ExtendedBoundaryMode",
8+
"PaddingMode",
9+
"OrthogonalizeMethod",
10+
]
11+
12+
BoundaryMode = Literal["constant", "zero", "reflect", "periodic", "symmetric"]
13+
"""
14+
This is a type literal for the way of padding.
15+
16+
- Refection padding mirrors samples along the border.
17+
- Zero padding pads zeros.
18+
- Constant padding replicates border values.
19+
- Periodic padding cyclically repeats samples.
20+
- Symmetric padding mirrors samples along the border
21+
"""
22+
23+
ExtendedBoundaryMode = Union[Literal["boundary"], BoundaryMode]
24+
25+
PaddingMode = Literal["full", "valid", "same", "sameshift"]
26+
"""
27+
The padding mode is used when construction convolution matrices.
28+
"""
29+
30+
OrthogonalizeMethod = Literal["qr", "gramschmidt"]
31+
"""
32+
The method for orthogonalizing a matrix.
33+
34+
1. 'qr' relies on pytorch's dense qr implementation, it is fast but memory hungry.
35+
2. 'gramschmidt' option is sparse, memory efficient, and slow.
36+
37+
Choose 'gramschmidt' if 'qr' runs out of memory.
38+
"""

src/ptwt/conv_transform.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
# Created by moritz wolter, 14.04.20
7-
from typing import List, Optional, Sequence, Tuple, Union
7+
from typing import List, Optional, Sequence, Tuple, Union, cast
88

99
import pywt
1010
import torch
@@ -18,6 +18,7 @@
1818
_pad_symmetric,
1919
_unfold_axes,
2020
)
21+
from .constants import BoundaryMode
2122

2223

2324
def _create_tensor(
@@ -106,7 +107,7 @@ def _get_pad(data_len: int, filt_len: int) -> Tuple[int, int]:
106107
return padr, padl
107108

108109

109-
def _translate_boundary_strings(pywt_mode: str) -> str:
110+
def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
110111
"""Translate pywt mode strings to PyTorch mode strings.
111112
112113
We support constant, zero, reflect, and periodic.
@@ -118,24 +119,25 @@ def _translate_boundary_strings(pywt_mode: str) -> str:
118119
119120
"""
120121
if pywt_mode == "constant":
121-
pt_mode = "replicate"
122+
return "replicate"
122123
elif pywt_mode == "zero":
123-
pt_mode = "constant"
124+
return "constant"
124125
elif pywt_mode == "reflect":
125-
pt_mode = pywt_mode
126+
return pywt_mode
126127
elif pywt_mode == "periodic":
127-
pt_mode = "circular"
128+
return "circular"
128129
elif pywt_mode == "symmetric":
129130
# pytorch does not support symmetric mode,
130131
# we have our own implementation.
131-
pt_mode = pywt_mode
132-
else:
133-
raise ValueError("Padding mode not supported.")
134-
return pt_mode
132+
return pywt_mode
133+
raise ValueError(f"Padding mode not supported: {pywt_mode}")
135134

136135

137136
def _fwt_pad(
138-
data: torch.Tensor, wavelet: Union[Wavelet, str], mode: str = "reflect"
137+
data: torch.Tensor,
138+
wavelet: Union[Wavelet, str],
139+
*,
140+
mode: Optional[BoundaryMode] = None,
139141
) -> torch.Tensor:
140142
"""Pad the input signal to make the fwt matrix work.
141143
@@ -145,29 +147,26 @@ def _fwt_pad(
145147
data (torch.Tensor): Input data ``[batch_size, 1, time]``
146148
wavelet (Wavelet or str): A pywt wavelet compatible object or
147149
the name of a pywt wavelet.
148-
mode (str): The desired way to pad. The following methods are supported::
149-
150-
"reflect", "zero", "constant", "periodic", "symmetric".
151-
152-
Refection padding mirrors samples along the border.
153-
Zero padding pads zeros.
154-
Constant padding replicates border values.
155-
Periodic padding cyclically repeats samples.
156-
This function defaults to reflect.
150+
mode :
151+
The desired padding mode for extending the signal along the edges.
152+
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
157153
158154
Returns:
159155
torch.Tensor: A PyTorch tensor with the padded input data
160156
161157
"""
162158
wavelet = _as_wavelet(wavelet)
159+
163160
# convert pywt to pytorch convention.
164-
mode = _translate_boundary_strings(mode)
161+
if mode is None:
162+
mode = cast(BoundaryMode, "reflect")
163+
pytorch_mode = _translate_boundary_strings(mode)
165164

166165
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
167-
if mode == "symmetric":
166+
if pytorch_mode == "symmetric":
168167
data_pad = _pad_symmetric(data, [(padl, padr)])
169168
else:
170-
data_pad = torch.nn.functional.pad(data, [padl, padr], mode=mode)
169+
data_pad = torch.nn.functional.pad(data, [padl, padr], mode=pytorch_mode)
171170
return data_pad
172171

173172

@@ -263,7 +262,8 @@ def _preprocess_result_list_rec1d(
263262
def wavedec(
264263
data: torch.Tensor,
265264
wavelet: Union[Wavelet, str],
266-
mode: str = "reflect",
265+
*,
266+
mode: BoundaryMode = "reflect",
267267
level: Optional[int] = None,
268268
axis: int = -1,
269269
) -> List[torch.Tensor]:
@@ -276,18 +276,9 @@ def wavedec(
276276
the name of a pywt wavelet.
277277
Please consider the output from ``pywt.wavelist(kind='discrete')``
278278
for possible choices.
279-
mode (str): The desired padding mode. Padding extends the signal along
280-
the edges. Supported methods are::
281-
282-
"reflect", "zero", "constant", "periodic", "symmetric".
283-
284-
Defaults to "reflect".
285-
286-
Symmetric padding mirrors samples along the border.
287-
Refection padding reflects samples along the border.
288-
Zero padding pads zeros.
289-
Constant padding replicates border values.
290-
Periodic padding cyclically repeats samples.
279+
mode :
280+
The desired padding mode for extending the signal along the edges.
281+
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
291282
level (int): The scale level to be computed.
292283
Defaults to None.
293284
axis (int): Compute the transform over this axis instead of the

src/ptwt/conv_transform_2.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from functools import partial
8-
from typing import List, Optional, Tuple, Union
8+
from typing import List, Optional, Tuple, Union, cast
99

1010
import pywt
1111
import torch
@@ -25,6 +25,7 @@
2525
_undo_swap_axes,
2626
_unfold_axes,
2727
)
28+
from .constants import BoundaryMode
2829
from .conv_transform import (
2930
_adjust_padding_at_reconstruction,
3031
_get_filter_tensors,
@@ -58,7 +59,10 @@ def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
5859

5960

6061
def _fwt_pad2(
61-
data: torch.Tensor, wavelet: Union[Wavelet, str], mode: str = "reflect"
62+
data: torch.Tensor,
63+
wavelet: Union[Wavelet, str],
64+
*,
65+
mode: Optional[BoundaryMode] = None,
6266
) -> torch.Tensor:
6367
"""Pad data for the 2d FWT.
6468
@@ -68,25 +72,26 @@ def _fwt_pad2(
6872
data (torch.Tensor): Input data with 4 dimensions.
6973
wavelet (Wavelet or str): A pywt wavelet compatible object or
7074
the name of a pywt wavelet.
71-
mode (str): The padding mode.
72-
Supported modes are::
73-
74-
"reflect", "zero", "constant", "periodic", "symmetric".
75-
76-
"reflect" is the default mode.
75+
mode :
76+
The desired padding mode for extending the signal along the edges.
77+
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
7778
7879
Returns:
7980
The padded output tensor.
8081
8182
"""
82-
mode = _translate_boundary_strings(mode)
83+
if mode is None:
84+
mode = cast(BoundaryMode, "reflect")
85+
pytorch_mode = _translate_boundary_strings(mode)
8386
wavelet = _as_wavelet(wavelet)
8487
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
8588
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
86-
if mode == "symmetric":
89+
if pytorch_mode == "symmetric":
8790
data_pad = _pad_symmetric(data, [(padt, padb), (padl, padr)])
8891
else:
89-
data_pad = torch.nn.functional.pad(data, [padl, padr, padt, padb], mode=mode)
92+
data_pad = torch.nn.functional.pad(
93+
data, [padl, padr, padt, padb], mode=pytorch_mode
94+
)
9095
return data_pad
9196

9297

@@ -122,7 +127,8 @@ def _preprocess_tensor_dec2d(
122127
def wavedec2(
123128
data: torch.Tensor,
124129
wavelet: Union[Wavelet, str],
125-
mode: str = "reflect",
130+
*,
131+
mode: BoundaryMode = "reflect",
126132
level: Optional[int] = None,
127133
axes: Tuple[int, int] = (-2, -1),
128134
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
@@ -140,11 +146,9 @@ def wavedec2(
140146
wavelet (Wavelet or str): A pywt wavelet compatible object or
141147
the name of a pywt wavelet. Refer to the output of
142148
``pywt.wavelist(kind="discrete")`` for a list of possible choices.
143-
mode (str): The padding mode. Options are::
144-
145-
"reflect", "zero", "constant", "periodic", "symmetric".
146-
147-
This function defaults to "reflect".
149+
mode :
150+
The desired padding mode for extending the signal along the edges.
151+
Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
148152
level (int): The number of desired scales.
149153
Defaults to None.
150154
axes (Tuple[int, int]): Compute the transform over these axes instead of the

src/ptwt/conv_transform_3.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_undo_swap_axes,
2525
_unfold_axes,
2626
)
27+
from .constants import BoundaryMode
2728
from .conv_transform import (
2829
_adjust_padding_at_reconstruction,
2930
_get_filter_tensors,
@@ -63,7 +64,7 @@ def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
6364

6465

6566
def _fwt_pad3(
66-
data: torch.Tensor, wavelet: Union[Wavelet, str], mode: str
67+
data: torch.Tensor, wavelet: Union[Wavelet, str], *, mode: BoundaryMode
6768
) -> torch.Tensor:
6869
"""Pad data for the 3d-FWT.
6970
@@ -73,37 +74,38 @@ def _fwt_pad3(
7374
data (torch.Tensor): Input data with 4 dimensions.
7475
wavelet (Wavelet or str): A pywt wavelet compatible object or
7576
the name of a pywt wavelet.
76-
mode (str): The padding mode. Supported modes are::
77-
78-
"reflect", "zero", "constant", "periodic", "symmetric".
77+
mode :
78+
The desired padding mode for extending the signal along the edges.
79+
See :data:`ptwt.constants.BoundaryMode`.
7980
8081
Returns:
8182
The padded output tensor.
8283
8384
"""
84-
mode = _translate_boundary_strings(mode)
85+
pytorch_mode = _translate_boundary_strings(mode)
8586

8687
wavelet = _as_wavelet(wavelet)
8788
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
8889
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
8990
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
90-
if mode == "symmetric":
91+
if pytorch_mode == "symmetric":
9192
data_pad = _pad_symmetric(
9293
data, [(pad_front, pad_back), (pad_top, pad_bottom), (pad_left, pad_right)]
9394
)
9495
else:
9596
data_pad = torch.nn.functional.pad(
9697
data,
9798
[pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back],
98-
mode=mode,
99+
mode=pytorch_mode,
99100
)
100101
return data_pad
101102

102103

103104
def wavedec3(
104105
data: torch.Tensor,
105106
wavelet: Union[Wavelet, str],
106-
mode: str = "zero",
107+
*,
108+
mode: BoundaryMode = "zero",
107109
level: Optional[int] = None,
108110
axes: Tuple[int, int, int] = (-3, -2, -1),
109111
) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
@@ -114,11 +116,9 @@ def wavedec3(
114116
[batch_size, length, height, width]
115117
wavelet (Union[Wavelet, str]): The wavelet to transform with.
116118
``pywt.wavelist(kind='discrete')`` lists possible choices.
117-
mode (str): The padding mode. Possible options are::
118-
119-
"reflect", "zero", "constant", "periodic", "symmetric".
120-
121-
Defaults to "zero".
119+
mode :
120+
The desired padding mode for extending the signal along the edges.
121+
Defaults to "zero". See :data:`ptwt.constants.BoundaryMode`.
122122
level (Optional[int]): The maximum decomposition level.
123123
This argument defaults to None.
124124
axes (Tuple[int, int, int]): Compute the transform over these axes

0 commit comments

Comments
 (0)