Skip to content

Commit 8d88406

Browse files
authored
Merge pull request #109 from v0lta/add-pytyped
pyptoject.toml
2 parents 2607c98 + 59a86a1 commit 8d88406

26 files changed

+201
-206
lines changed

.github/workflows/tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
os: [ ubuntu-latest ]
16-
python-version: [3.9, 3.12]
16+
python-version: [3.11, 3.12]
1717
steps:
1818
- uses: actions/checkout@v2
1919
- name: Set up Python ${{ matrix.python-version }}
@@ -30,7 +30,7 @@ jobs:
3030
runs-on: ubuntu-latest
3131
strategy:
3232
matrix:
33-
python-version: [3.9, 3.12]
33+
python-version: [3.11, 3.12]
3434
steps:
3535
- uses: actions/checkout@v2
3636
- name: Set up Python ${{ matrix.python-version }}
@@ -46,7 +46,7 @@ jobs:
4646
runs-on: ubuntu-latest
4747
strategy:
4848
matrix:
49-
python-version: [3.9, 3.12]
49+
python-version: [3.11, 3.12]
5050
steps:
5151
- uses: actions/checkout@v2
5252
- name: Set up Python ${{ matrix.python-version }}

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ coverage_html_report/
44
htmlcov
55
.noseids
66
log/
7+
.virtual_python
78
examples/deepfake_analysis/ffhq_style_gan/
89

910

noxfile.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def format(session):
7878
"""Fix common convention problems automatically."""
7979
session.install("black")
8080
session.install("isort")
81-
session.run("isort", ".")
82-
session.run("black", ".")
81+
session.run("isort", "src", "tests", "noxfile.py")
82+
session.run("black", "src", "tests", "noxfile.py")
8383

8484

8585
@nox.session(name="coverage")
@@ -100,27 +100,6 @@ def clean_coverage(session):
100100
session.run("rm", "-r", "htmlcov", external=True)
101101

102102

103-
@nox.session(name="build")
104-
def build(session):
105-
"""Build a pip package."""
106-
session.install("wheel")
107-
session.install("setuptools")
108-
session.run("python", "setup.py", "-q", "sdist", "bdist_wheel")
109-
110-
111-
@nox.session(name="finish")
112-
def finish(session):
113-
"""Finish this version increase the version number and upload to pypi."""
114-
session.install("bump2version")
115-
session.install("twine")
116-
session.run("bumpversion", "release", external=True)
117-
build(session)
118-
session.run("twine", "upload", "--skip-existing", "dist/*", external=True)
119-
session.run("git", "push", external=True)
120-
session.run("bumpversion", "patch", external=True)
121-
session.run("git", "push", external=True)
122-
123-
124103
@nox.session(name="check-package")
125104
def pyroma(session):
126105
"""Run pyroma to check if the package is ok."""

pyproject.toml

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
[tool.pdm]
2+
package-dir = "src"
3+
4+
[project]
5+
name = "ptwt"
6+
description = "Differentiable and gpu enabled fast wavelet transforms in PyTorch"
7+
version = "0.1.10-dev"
8+
keywords = ["Wavelets", "Wavelet Transform", "Fast Wavelet Transform", "Boundary Wavelets", "PyTorch"]
9+
readme = "README.rst"
10+
authors = [
11+
{name = "Moritz Wolter and Felix Blanke", email = "[email protected]"},
12+
]
13+
maintainers = [
14+
{name = "Moritz Wolter and Felix Blanke", email = "[email protected]"},
15+
]
16+
classifiers = [
17+
"Development Status :: 4 - Beta",
18+
"Environment :: Console",
19+
"Intended Audience :: Science/Research",
20+
"License :: OSI Approved :: European Union Public Licence 1.2 (EUPL 1.2)",
21+
"Operating System :: OS Independent",
22+
"Programming Language :: Python",
23+
"Programming Language :: Python :: 3 :: Only",
24+
"Programming Language :: Python :: 3.12",
25+
"Programming Language :: Python :: 3.11",
26+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
27+
]
28+
requires-python = ">=3.11"
29+
dependencies = [
30+
"PyWavelets",
31+
"numpy",
32+
"torch",
33+
]
34+
license = {text = "EUPL-1.2"}
35+
36+
[project.urls]
37+
Homepage = "https://github.com/v0lta/PyTorch-Wavelet-Toolbox"
38+
Downloads = "https://github.com/v0lta/PyTorch-Wavelet-Toolbox/releases"
39+
"Bug Tracker" = "https://github.com/v0lta/PyTorch-Wavelet-Toolbox/issues"
40+
"Source Code" = "https://github.com/v0lta/PyTorch-Wavelet-Toolbox"
41+
42+
[project.optional-dependencies]
43+
tests = [
44+
"pooch",
45+
"pytest",
46+
"scipy>=1.10",
47+
]
48+
typing = [
49+
"mypy>=1.11",
50+
"pytest",
51+
]
52+
examples = [
53+
"matplotlib",
54+
]
55+
56+
[build-system]
57+
requires = ["pdm-backend"]
58+
build-backend = "pdm.backend"

setup.cfg

Lines changed: 0 additions & 84 deletions
This file was deleted.

setup.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/ptwt/_util.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,11 @@ def rename_kwargs(
758758
kwargs: Param.kwargs, # type: ignore
759759
aliases: dict[str, str],
760760
) -> None:
761-
"""Rename deprecated kwarg."""
761+
"""Rename deprecated kwarg.
762+
763+
Raises:
764+
TypeError: If both arguments are present.
765+
"""
762766
for alias, new in aliases.items():
763767
if alias in kwargs:
764768
if new in kwargs:

src/ptwt/continuous_transform.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from pywt._functions import scale2frequency
1212
from torch.fft import fft, ifft
1313

14+
__all__ = ["cwt"]
15+
1416

1517
def _next_fast_len(n: int) -> int:
1618
"""Round up size to the nearest power of two.
@@ -24,10 +26,10 @@ def _next_fast_len(n: int) -> int:
2426

2527
def cwt(
2628
data: torch.Tensor,
27-
scales: Union[np.ndarray, torch.Tensor], # type: ignore
29+
scales: Union[np.ndarray, torch.Tensor],
2830
wavelet: Union[ContinuousWavelet, str],
2931
sampling_period: float = 1.0,
30-
) -> tuple[torch.Tensor, np.ndarray]: # type: ignore
32+
) -> tuple[torch.Tensor, np.ndarray]:
3133
"""Compute the single-dimensional continuous wavelet transform.
3234
3335
This function is a PyTorch port of pywt.cwt as found at:
@@ -185,11 +187,11 @@ def _integrate_wavelet(
185187
"""
186188

187189
def _integrate(
188-
arr: Union[np.ndarray, torch.Tensor], # type: ignore
189-
step: Union[np.ndarray, torch.Tensor], # type: ignore
190-
) -> Union[np.ndarray, torch.Tensor]: # type: ignore
190+
arr: Union[np.ndarray, torch.Tensor],
191+
step: Union[np.ndarray, torch.Tensor],
192+
) -> Union[np.ndarray, torch.Tensor]:
191193
if type(arr) is np.ndarray:
192-
integral = np.cumsum(arr)
194+
integral: Any = np.cumsum(arr)
193195
elif type(arr) is torch.Tensor:
194196
integral = torch.cumsum(arr, -1)
195197
else:
@@ -212,12 +214,12 @@ def _integrate(
212214
return _integrate(psi, step), x
213215

214216
elif len(functions_approximations) == 3: # orthogonal wavelet
215-
_, psi, x = functions_approximations
217+
_, psi, x = functions_approximations # type: ignore
216218
step = x[1] - x[0]
217219
return _integrate(psi, step), x
218220

219221
else: # biorthogonal wavelet
220-
_, psi_d, _, psi_r, x = functions_approximations
222+
_, psi_d, _, psi_r, x = functions_approximations # type: ignore
221223
step = x[1] - x[0]
222224
return _integrate(psi_d, step), _integrate(psi_r, step), x
223225

@@ -248,7 +250,11 @@ def __init__(self, name: str):
248250
)
249251

250252
def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
251-
"""Return numerical values for the wavelet on a grid."""
253+
"""Return numerical values for the wavelet on a grid.
254+
255+
Raises:
256+
NotImplementedError: If this call is not overwritten by a child.
257+
"""
252258
raise NotImplementedError
253259

254260
@property

src/ptwt/conv_transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
)
2727
from .constants import BoundaryMode, Wavelet, WaveletCoeff1d
2828

29+
__all__ = ["wavedec", "waverec"]
30+
2931

3032
def _fwt_pad(
3133
data: torch.Tensor,

src/ptwt/conv_transform_2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
)
2828
from .constants import BoundaryMode, Wavelet, WaveletCoeff2d, WaveletDetailTuple2d
2929

30+
__all__ = ["wavedec2", "waverec2"]
31+
3032

3133
def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
3234
"""Construct two-dimensional filters using outer products.

0 commit comments

Comments
 (0)