Skip to content

Commit 38901ce

Browse files
authored
Merge pull request #81 from v0lta/typing-everywhere
Type annotate the tests
2 parents f390042 + 6401034 commit 38901ce

23 files changed

+283
-183
lines changed

noxfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def mypy(session):
4949
"--no-warn-return-any",
5050
"--explicit-package-bases",
5151
"src",
52+
"tests",
5253
)
5354

5455

src/ptwt/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pywt
88
import torch
99

10-
from ptwt.constants import OrthogonalizeMethod
10+
from .constants import OrthogonalizeMethod
1111

1212

1313
class Wavelet(Protocol):

src/ptwt/continuous_transform.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,10 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
291291
shannon = (
292292
torch.sqrt(self.bandwidth)
293293
* (
294-
torch.sin(torch.pi * self.bandwidth * grid_values) # type: ignore
294+
torch.sin(torch.pi * self.bandwidth * grid_values)
295295
/ (torch.pi * self.bandwidth * grid_values)
296296
)
297-
* torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore
297+
* torch.exp(1j * 2 * torch.pi * self.center * grid_values)
298298
)
299299
return shannon
300300

@@ -306,8 +306,8 @@ def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
306306
"""Return numerical values for the wavelet on a grid."""
307307
morlet = (
308308
1.0
309-
/ torch.sqrt(torch.pi * self.bandwidth) # type: ignore
309+
/ torch.sqrt(torch.pi * self.bandwidth)
310310
* torch.exp(-(grid_values**2) / self.bandwidth)
311-
* torch.exp(1j * 2 * torch.pi * self.center * grid_values) # type: ignore
311+
* torch.exp(1j * 2 * torch.pi * self.center * grid_values)
312312
)
313313
return morlet

src/ptwt/matmul_transform.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,11 @@ def orthogonalize(
157157
raise ValueError(f"Invalid orthogonalization method: {method}")
158158

159159

160-
class MatrixWavedec(object):
160+
class BaseMatrixWaveDec:
161+
"""A base class for matrix wavedec."""
162+
163+
164+
class MatrixWavedec(BaseMatrixWaveDec):
161165
"""Compute the sparse matrix fast wavelet transform.
162166
163167
Intermediate scale results must be divisible

src/ptwt/matmul_transform_2.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@
3030
_preprocess_tensor_dec2d,
3131
_waverec2d_fold_channels_2d_list,
3232
)
33-
from .matmul_transform import construct_boundary_a, construct_boundary_s, orthogonalize
33+
from .matmul_transform import (
34+
BaseMatrixWaveDec,
35+
construct_boundary_a,
36+
construct_boundary_s,
37+
orthogonalize,
38+
)
3439
from .sparse_math import (
3540
batch_mm,
3641
cat_sparse_identity_matrix,
@@ -217,7 +222,7 @@ def _matrix_pad_2(height: int, width: int) -> Tuple[int, int, Tuple[bool, bool]]
217222
return height, width, pad_tuple
218223

219224

220-
class MatrixWavedec2(object):
225+
class MatrixWavedec2(BaseMatrixWaveDec):
221226
"""Experimental sparse matrix 2d wavelet transform.
222227
223228
For a completely pad-free transform,

src/ptwt/matmul_transform_3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
wavelet: Union[Wavelet, str],
5959
level: Optional[int] = None,
6060
axes: Tuple[int, int, int] = (-3, -2, -1),
61-
boundary: Optional[str] = "qr",
61+
boundary: OrthogonalizeMethod = "qr",
6262
):
6363
"""Create a *separable* three-dimensional fast boundary wavelet transform.
6464
@@ -69,7 +69,7 @@ def __init__(
6969
wavelet (Union[Wavelet, str]): The wavelet to use.
7070
level (Optional[int]): The desired decomposition level.
7171
Defaults to None.
72-
boundary (Optional[str]): The matrix orthogonalization method.
72+
boundary: The matrix orthogonalization method.
7373
Defaults to "qr".
7474
7575
Raises:

src/ptwt/packets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def __init__(
106106
if len(data.shape) == 1:
107107
# add a batch dimension.
108108
data = data.unsqueeze(0)
109-
self.transform(data, maxlevel) # type: ignore
109+
self.transform(data, maxlevel)
110110
else:
111111
self.data = {}
112112

src/ptwt/separable_conv_transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ def _separable_conv_waverecn(
167167

168168
approx: torch.Tensor = coeffs[0]
169169
for level_dict in coeffs[1:]:
170-
keys = list(level_dict.keys())
171-
level_dict["a" * max(map(len, keys))] = approx
172-
approx = _separable_conv_idwtn(level_dict, wavelet)
170+
keys = list(level_dict.keys()) # type: ignore
171+
level_dict["a" * max(map(len, keys))] = approx # type: ignore
172+
approx = _separable_conv_idwtn(level_dict, wavelet) # type: ignore
173173
return approx
174174

175175

src/ptwt/sparse_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import torch
88

9-
from ptwt.constants import PaddingMode
9+
from .constants import PaddingMode
1010

1111

1212
def _dense_kron(

tests/_mackey_glass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _mackey(
5959
return x[:, discard:]
6060

6161

62-
class MackeyGenerator(object):
62+
class MackeyGenerator:
6363
"""Generates lorenz attractor data in 1 or 3d on the GPU."""
6464

6565
def __init__(

0 commit comments

Comments
 (0)