Skip to content

Commit 3967de6

Browse files
authored
Merge pull request #46 from v0lta/v0.1.5
merge v0.1.5
2 parents d426346 + 2e486fe commit 3967de6

32 files changed

+2227
-729
lines changed

.flake8

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,3 @@ import-order-style = pycharm
3636
application-import-names =
3737
ptwt
3838
tests
39-
format = ${cyan}%(path)s${reset}:${yellow_bold}%(row)d${reset}:${green_bold}%(col)d${reset}: ${red_bold}%(code)s${reset} %(text)s

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ htmlcov
55
.noseids
66
examples/ffhq_style_gan/
77
examples/ffhq_style_gan.zip
8+
examples/data
89

910

1011
# Byte-compiled / optimized / DLL files

CONTRIBUTING.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
## Contributing to the PyTorch-Wavelet-Toolbox
2+
3+
Contributions to the PyTorch-Wavelet-Toolbox are always welcome!
4+
5+
### Development workflow:
6+
We use nox to run our unit tests. Before creating your pull request follow these three steps.
7+
8+
1. Make sure all unit tests are passing.
9+
Run:
10+
``` bash
11+
nox -s test
12+
```
13+
to check.
14+
15+
2. Help yourself by running,
16+
``` bash
17+
nox -s format
18+
```
19+
to take care of linting issues, with an automatic fix.
20+
21+
3. Afterward, run
22+
``` bash
23+
nox -s lint
24+
```
25+
to learn where manual fixes are required for style compatability.

README.rst

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,27 @@ Pytorch Wavelet Toolbox (`ptwt`)
2626
:target: https://github.com/psf/black
2727
:alt: Black code style
2828

29+
.. image:: https://static.pepy.tech/personalized-badge/ptwt?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=Downloads
30+
:target: https://pepy.tech/project/ptwt
31+
32+
2933

3034

3135
Welcome to the PyTorch wavelet toolbox. This package implements:
3236

3337
- the fast wavelet transform (fwt) via ``wavedec`` and its inverse by providing the ``waverec`` function,
3438
- the two-dimensional fwt is called ``wavedec2`` the synthesis counterpart ``waverec2``,
3539
- ``wavedec3`` and ``waverec3`` cover the three-dimensional analysis and synthesis case,
36-
- ``MatrixWavedec`` and ``MatrixWaverec`` provide sparse-matrix-based fast wavelet transforms with boundary filters,
37-
- 2d sparse-matrix transforms with separable & non-separable boundary filters are available (experimental),
40+
- ``fswavedec2``, ``fswavedec3``, ``fswaverec2`` and ``fswaverec3`` support separable transformations.
41+
- ``MatrixWavedec`` and ``MatrixWaverec`` implement sparse-matrix-based fast wavelet transforms with boundary filters,
42+
- 2d sparse-matrix transforms with separable & non-separable boundary filters are available,
43+
- ``MatrixWavedec3`` and ``MatrixWaverec3`` allow separable 3D-fwt's with boundary filters.
3844
- ``cwt`` computes a one-dimensional continuous forward transform,
3945
- single and two-dimensional wavelet packet forward and backward transforms are available via the ``WaveletPacket`` and ``WaveletPacket2D`` objects,
4046
- finally, this package provides adaptive wavelet support (experimental).
4147

42-
This toolbox supports pywt-wavelets. Complete documentation is available:
43-
https://pytorch-wavelet-toolbox.readthedocs.io/
48+
This toolbox extends `PyWavelets <https://pywavelets.readthedocs.io/en/latest/>`_ . We additionally provide GPU and gradient support via a PyTorch backend.
49+
Complete documentation is available at: https://pytorch-wavelet-toolbox.readthedocs.io/
4450

4551

4652
**Installation**
@@ -101,7 +107,7 @@ To test an example run:
101107
import numpy as np
102108
import scipy.misc
103109
104-
face = np.transpose(scipy.misc.face(),
110+
face = np.transpose(scipy.datasets.face(),
105111
[2, 0, 1]).astype(np.float64)
106112
pytorch_face = torch.tensor(face)
107113
coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
@@ -143,8 +149,8 @@ Reconsidering the 1d case, try:
143149
144150
145151
The process for the 2d transforms ``MatrixWavedec2``, ``MatrixWaverec2`` works similarly.
146-
By default, a non-separable transformation is used.
147-
To use a separable transformation, pass ``separable=True`` to ``MatrixWavedec2`` and ``MatrixWaverec2``.
152+
By default, a separable transformation is used.
153+
To use a non-separable transformation, pass ``separable=False`` to ``MatrixWavedec2`` and ``MatrixWaverec2``.
148154
Separable transformations use a 1d transformation along both axes, which might be faster since fewer matrix entries
149155
have to be orthogonalized.
150156

@@ -162,15 +168,18 @@ See https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/main/examples/network_
162168

163169
**Testing**
164170

165-
The ``tests`` folder contains multiple tests to allow independent verification of this toolbox. After cloning the
166-
repository, and moving into the main directory, and installing ``nox`` with ``pip install nox`` run:
171+
The ``tests`` folder contains multiple tests to allow independent verification of this toolbox.
172+
The GitHub workflow executes a subset of all tests for efficiency reasons.
173+
After cloning the repository, moving into the main directory, and installing ``nox`` with ``pip install nox`` run
167174

168175
.. code-block:: sh
169176
170177
$ nox --session test
171178
172179
173180
181+
to run all existing tests.
182+
174183
Citation
175184
""""""""
176185

docs/ptwt.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ ptwt.conv\_transform\_3 module
2727
:undoc-members:
2828
:show-inheritance:
2929

30+
ptwt.separable\_conv\_transform module
31+
--------------------------------------
32+
33+
.. automodule:: ptwt.separable_conv_transform
34+
:members:
35+
:undoc-members:
36+
:show-inheritance:
3037

3138
ptwt.continuous\_transform module
3239
---------------------------------
@@ -60,6 +67,15 @@ ptwt.matmul\_transform\_2 module
6067
:undoc-members:
6168
:show-inheritance:
6269

70+
ptwt.matmul\_transform\_3 module
71+
--------------------------------
72+
73+
.. automodule:: ptwt.matmul_transform_3
74+
:members:
75+
:undoc-members:
76+
:show-inheritance:
77+
78+
6379
ptwt.sparse\_math module
6480
------------------------
6581

examples/deepfake_analysis/README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ Recreating these experiments requires roughly 8GB of free disc space and 10GB of
2525
To reproduce these plots:
2626
1. Download [ffhq_style_gan.zip](https://drive.google.com/uc?id=1MOHKuEVqURfCKAN9dwp1o2tuR19OTQCF&export=download) and
2727
2. Extract the image pairs here.
28-
3. Run `python packet_plot.py`
28+
3. Check the file structure. In `ffhq_style_gan` the folder structure should be:
29+
```
30+
source_data
31+
├── A_ffhq
32+
├── B_stylegan
33+
```
34+
4. Run `python packet_plot.py`
2935

3036
You can read more about GAN-detection in the paper [Wavelet-Packets for Deepfake Image Analysis and Detection](https://rdcu.be/cUIRt).
3137
A complete project building gan-detectors on top of wavelets is available at:
38.6 KB
Loading

examples/wavelet_packet_chirp_analysis/chirp_analysis.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,40 @@
77
# use from src.ptwt.packets if you cloned the repo instead of using pip.
88
from ptwt import WaveletPacket
99

10-
t = np.linspace(0, 10, 1500)
11-
w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
10+
fs = 1000
11+
t = np.linspace(0, 2, int(2//(1/fs)))
12+
w = np.sin(256*np.pi*t**2)
1213

13-
wavelet = pywt.Wavelet("db3")
14+
wavelet = pywt.Wavelet("sym8")
1415
wp = WaveletPacket(
15-
data=torch.from_numpy(w.astype(np.float32)), wavelet=wavelet, mode="reflect"
16+
data=torch.from_numpy(w.astype(np.float32)), wavelet=wavelet, mode="boundary"
1617
)
17-
nodes = wp.get_level(5)
18+
level = 5
19+
nodes = wp.get_level(level)
1820
np_lst = []
1921
for node in nodes:
2022
np_lst.append(wp[node])
2123
viz = np.stack(np_lst).squeeze()
2224

25+
n = list(range(int(np.power(2, level))))
26+
freqs = (fs/2)*(n/(np.power(2, level)))
27+
28+
xticks = list(range(viz.shape[-1]))[::6]
29+
xlabels = np.round(np.linspace(min(t), max(t), viz.shape[-1]), 2)[::6]
30+
2331
fig, axs = plt.subplots(2)
2432
axs[0].plot(t, w)
25-
axs[0].set_title("Linear Chirp, f(0)=1, f(10)=50")
26-
axs[0].set_xlabel("t [s]")
33+
axs[0].set_title("Analyzed signal")
34+
axs[0].set_xlabel("time [s]")
35+
axs[0].set_ylabel("magnitude")
2736

28-
axs[1].set_title("Wavelet analysis")
37+
axs[1].set_title("Wavelet packet analysis")
2938
axs[1].imshow(np.abs(viz))
30-
axs[1].set_xlabel("time")
31-
axs[1].set_ylabel("frequency")
39+
axs[1].set_xlabel("time [s]")
40+
axs[1].set_xticks(xticks)
41+
axs[1].set_xticklabels(xlabels)
42+
axs[1].set_ylabel("frequency [Hz]")
43+
axs[1].set_yticks(n[::4])
44+
axs[1].set_yticklabels(freqs[::4])
3245
axs[1].invert_yaxis()
3346
plt.show()

noxfile.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ def run_test_fast(session):
2121
@nox.session(name="lint")
2222
def lint(session):
2323
"""Check code conventions."""
24-
session.install("flake8==4.0.1")
24+
session.install("flake8")
2525
session.install(
26-
"flake8-colors",
2726
"flake8-black",
2827
"flake8-docstrings",
2928
"flake8-bugbear",
@@ -32,7 +31,6 @@ def lint(session):
3231
"pydocstyle",
3332
"darglint",
3433
)
35-
session.install("flake8-bandit==2.1.2", "bandit==1.7.2")
3634
session.run("flake8", "src", "tests", "noxfile.py")
3735

3836

@@ -48,8 +46,7 @@ def mypy(session):
4846
"--ignore-missing-imports",
4947
"--strict",
5048
"--no-warn-return-any",
51-
"--implicit-reexport",
52-
"--allow-untyped-calls",
49+
"--explicit-package-bases",
5350
"src",
5451
)
5552

setup.cfg

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,12 @@ classifiers =
4141
install_requires =
4242
PyWavelets
4343
torch
44-
scipy
44+
scipy>=1.10
45+
pooch
4546
matplotlib
4647
numpy
48+
pytest
49+
nox
4750

4851
packages = find:
4952
package_dir =
@@ -58,4 +61,4 @@ where = src
5861
##########################
5962
[darglint]
6063
docstring_style = google
61-
strictness = short
64+
strictness = short

0 commit comments

Comments
 (0)