Skip to content

Commit a95b4b5

Browse files
authored
Merge pull request #76 from v0lta/format-everything
Apply formatting everywhere
2 parents f46f7bc + b0ea2c3 commit a95b4b5

File tree

15 files changed

+313
-223
lines changed

15 files changed

+313
-223
lines changed

docs/conf.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
#
1313
import os
1414
import sys
15-
sys.path.insert(0, os.path.abspath('../src'))
15+
16+
sys.path.insert(0, os.path.abspath("../src"))
1617

1718

1819
# -- Project information -----------------------------------------------------
1920

20-
project = 'PyTorch-Wavelet-Toolbox'
21-
copyright = '2022, Moritz Wolter'
22-
author = 'Moritz Wolter'
21+
project = "PyTorch-Wavelet-Toolbox"
22+
copyright = "2022, Moritz Wolter"
23+
author = "Moritz Wolter"
2324

2425

2526
# -- General configuration ---------------------------------------------------
@@ -28,27 +29,27 @@
2829
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
2930
# ones.
3031
extensions = [
31-
'sphinx.ext.napoleon',
32-
'sphinx.ext.autodoc',
33-
'sphinx.ext.intersphinx',
34-
'sphinx.ext.todo',
35-
'sphinx.ext.coverage',
36-
'sphinx.ext.viewcode',
32+
"sphinx.ext.napoleon",
33+
"sphinx.ext.autodoc",
34+
"sphinx.ext.intersphinx",
35+
"sphinx.ext.todo",
36+
"sphinx.ext.coverage",
37+
"sphinx.ext.viewcode",
3738
]
3839

3940
napoleon_google_docstring = True
4041
# napoleon_use_admonition_for_examples = True
4142

4243
# Add any paths that contain templates here, relative to this directory.
43-
templates_path = ['_templates']
44+
templates_path = ["_templates"]
4445

4546
# List of patterns, relative to source directory, that match files and
4647
# directories to ignore when looking for source files.
4748
# This pattern also affects html_static_path and html_extra_path.
48-
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
49+
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
4950

5051
# document __init__ in the docpages
51-
autoclass_content = 'both'
52+
autoclass_content = "both"
5253

5354

5455
# -- Options for HTML output -------------------------------------------------
@@ -57,26 +58,26 @@
5758
# a list of builtin themes.
5859
#
5960

60-
html_theme = 'alabaster'
61+
html_theme = "alabaster"
6162
html_sidebars = {
62-
'**': [
63-
'about.html',
64-
'navigation.html',
65-
'searchbox.html',
63+
"**": [
64+
"about.html",
65+
"navigation.html",
66+
"searchbox.html",
6667
]
6768
}
6869

6970
html_theme_options = {
70-
'github_user': 'v0lta',
71-
'github_repo': 'PyTorch-Wavelet-Toolbox',
72-
'github_banner': 'false',
73-
'show_related': 'true',
74-
'page_width': 'auto',
75-
'sidebar_width': '250px'
71+
"github_user": "v0lta",
72+
"github_repo": "PyTorch-Wavelet-Toolbox",
73+
"github_banner": "false",
74+
"show_related": "true",
75+
"page_width": "auto",
76+
"sidebar_width": "250px",
7677
}
7778

7879

79-
html_favicon = 'favicon/favicon.ico'
80+
html_favicon = "favicon/favicon.ico"
8081

8182
# Add any paths that contain custom static files (such as style sheets) here,
8283
# relative to this directory. They are copied after the builtin static files,

docs/favicon/plot_shannon.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
import numpy as np
21
import matplotlib.pyplot as plt
2+
import numpy as np
33

4-
if __name__ == '__main__':
4+
if __name__ == "__main__":
55
bandwidth = 1
66
center = 0
77
grid_values = np.linspace(-6, 6, 500)
88
shannon = (
99
np.sqrt(bandwidth)
10-
* (
11-
np.sin(np.pi * bandwidth * grid_values)
12-
/ (np.pi * bandwidth * grid_values)
13-
)
10+
* (np.sin(np.pi * bandwidth * grid_values) / (np.pi * bandwidth * grid_values))
1411
* np.exp(1j * 2 * np.pi * center * grid_values)
1512
)
1613
plt.plot(shannon, linewidth=20.0)
1714
plt.axis("off")
1815
plt.savefig("shannon.png")
19-
plt.show()
16+
plt.show()

examples/continuous_signal_analysis/cwt_chirp_analysis.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import torch
2-
import numpy as np
3-
import ptwt
41
import matplotlib.pyplot as plt
2+
import numpy as np
53
import scipy.signal as signal
4+
import torch
5+
6+
import ptwt
67

78
if __name__ == "__main__":
89
t = np.linspace(-2, 2, 800, endpoint=False)

examples/deepfake_analysis/packet_plot.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import os
22
from itertools import product
3-
from tqdm import tqdm
43

5-
from PIL import Image
4+
import matplotlib.pyplot as plt
65
import numpy as np
7-
import torch
8-
import ptwt
96
import pywt
7+
import torch
8+
from PIL import Image
9+
from tqdm import tqdm
1010

11-
import matplotlib.pyplot as plt
11+
import ptwt
1212

1313

1414
def get_freq_order(level: int):
@@ -78,13 +78,13 @@ def generate_frequency_packet_image(packet_array: np.ndarray, degree: int):
7878

7979

8080
def load_image(path_to_file: str) -> torch.Tensor:
81-
8281
image = Image.open(path_to_file)
8382
tensor = torch.from_numpy(np.nan_to_num(np.array(image), posinf=255, neginf=0))
8483
return tensor
8584

85+
8686
def process_images(tensor: torch.Tensor, paths: list) -> torch.Tensor:
87-
tensor = torch.mean(tensor/255., -1)
87+
tensor = torch.mean(tensor / 255.0, -1)
8888
packets = ptwt.WaveletPacket2D(tensor, pywt.Wavelet("Haar"))
8989

9090
packet_list = []
@@ -103,10 +103,10 @@ def load_images(path: str) -> list:
103103
path = os.path.join(root, name)
104104
packets = load_image(path)
105105
image_list.append(packets)
106-
return image_list
106+
return image_list
107107

108108

109-
if __name__ == '__main__':
109+
if __name__ == "__main__":
110110
frequency_path, natural_path = get_freq_order(level=3)
111111
print("Loading ffhq images:")
112112
ffhq_images = load_images("./ffhq_style_gan/source_data/A_ffhq")
@@ -120,7 +120,6 @@ def load_images(path: str) -> list:
120120
del ffhq_images
121121
del ffhq_packets
122122

123-
124123
print("Loading style-gan images")
125124
gan_images = load_images("./ffhq_style_gan/source_data/B_stylegan")
126125
print("processing style-gan")
@@ -136,7 +135,7 @@ def load_images(path: str) -> list:
136135
plot_ffhq = generate_frequency_packet_image(mean_packets_ffhq, 3)
137136
plot_gan = generate_frequency_packet_image(mean_packets_gan, 3)
138137

139-
fig = plt.figure(figsize=(9,3))
138+
fig = plt.figure(figsize=(9, 3))
140139
fig.add_subplot(1, 2, 1)
141140
plt.imshow(plot_ffhq, vmax=1.5, vmin=-7)
142141
plt.title("real")
@@ -151,9 +150,9 @@ def load_images(path: str) -> list:
151150
plt.colorbar()
152151
plt.show()
153152

154-
plt.plot(torch.mean(mean_packets_ffhq, (1, 2)).flatten().numpy(), label='real')
155-
plt.plot(torch.mean(mean_packets_gan, (1, 2)).flatten().numpy(), label='fake')
156-
plt.xlabel('mean packets')
157-
plt.ylabel('magnitude')
153+
plt.plot(torch.mean(mean_packets_ffhq, (1, 2)).flatten().numpy(), label="real")
154+
plt.plot(torch.mean(mean_packets_gan, (1, 2)).flatten().numpy(), label="fake")
155+
plt.xlabel("mean packets")
156+
plt.ylabel("magnitude")
158157
plt.legend()
159158
plt.show()

examples/network_compression/mnist_compression.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
# based on https://github.com/pytorch/examples/blob/master/mnist/main.py
44

55
import argparse
6+
import collections
7+
8+
import matplotlib.pyplot as plt
69
import numpy as np
710
import torch
811
import torch.nn as nn
912
import torch.nn.functional as F
1013
import torch.optim as optim
11-
import collections
12-
import matplotlib.pyplot as plt
13-
from torchvision import datasets, transforms
1414
from torch.optim.lr_scheduler import StepLR
15-
from wavelet_linear import WaveletLayer
1615
from torch.utils.tensorboard.writer import SummaryWriter
16+
from torchvision import datasets, transforms
17+
from wavelet_linear import WaveletLayer
18+
1719
from ptwt.wavelets_learnable import ProductFilter
1820

1921

examples/network_compression/wavelet_linear.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# Originally created by moritz ([email protected])
22
# at https://github.com/v0lta/Wavelet-network-compression/blob/master/wavelet_learning/wavelet_linear.py
3-
import torch
43
import numpy as np
5-
from torch.nn.parameter import Parameter
64
import pywt
5+
import torch
6+
from torch.nn.parameter import Parameter
7+
78
from ptwt.conv_transform import wavedec, waverec
89

910

examples/speed_tests/timeitconv_1d.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import pywt
2-
import ptwt
3-
import torch
4-
import numpy as np
51
import time
62
from typing import NamedTuple
73

84
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import pywt
7+
import torch
8+
9+
import ptwt
10+
911

1012
class WaveletTuple(NamedTuple):
1113
"""Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
@@ -24,15 +26,15 @@ def _set_up_wavelet_tuple(wavelet, dtype):
2426
torch.tensor(wavelet.rec_hi).type(dtype),
2527
)
2628

29+
2730
def _jit_wavedec_fun(data, wavelet):
2831
return ptwt.wavedec(data, wavelet, "periodic", level=10)
2932

3033

31-
if __name__ == '__main__':
34+
if __name__ == "__main__":
3235
length = 1e6
3336
repetitions = 100
3437

35-
3638
pywt_time_cpu = []
3739
ptwt_time_cpu = []
3840
ptwt_time_gpu = []
@@ -56,10 +58,10 @@ def _jit_wavedec_fun(data, wavelet):
5658

5759
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
5860
jit_wavedec = torch.jit.trace(
59-
_jit_wavedec_fun,
60-
(data, wavelet),
61-
strict=False,
62-
)
61+
_jit_wavedec_fun,
62+
(data, wavelet),
63+
strict=False,
64+
)
6365

6466
for _ in range(repetitions):
6567
data = np.random.randn(32, int(length)).astype(np.float32)
@@ -69,7 +71,6 @@ def _jit_wavedec_fun(data, wavelet):
6971
end = time.perf_counter()
7072
ptwt_time_cpu_jit.append(end - start)
7173

72-
7374
for _ in range(repetitions):
7475
data = np.random.randn(32, int(length)).astype(np.float32)
7576
data = torch.from_numpy(data).cuda()
@@ -82,10 +83,10 @@ def _jit_wavedec_fun(data, wavelet):
8283

8384
wavelet = _set_up_wavelet_tuple(pywt.Wavelet("db5"), torch.float32)
8485
jit_wavedec = torch.jit.trace(
85-
_jit_wavedec_fun,
86-
(data.cuda(), wavelet),
87-
strict=False,
88-
)
86+
_jit_wavedec_fun,
87+
(data.cuda(), wavelet),
88+
strict=False,
89+
)
8990

9091
for _ in range(repetitions):
9192
data = np.random.randn(32, int(length)).astype(np.float32)
@@ -95,14 +96,24 @@ def _jit_wavedec_fun(data, wavelet):
9596
res = jit_wavedec(data, wavelet)
9697
torch.cuda.synchronize()
9798
end = time.perf_counter()
98-
ptwt_time_gpu_jit.append(end-start)
99+
ptwt_time_gpu_jit.append(end - start)
99100

100101
print("1d fwt results")
101-
print(f"1d-pywt-cpu :{np.mean(pywt_time_cpu):5.5f} +- {np.std(pywt_time_cpu):5.5f}")
102-
print(f"1d-ptwt-cpu :{np.mean(ptwt_time_cpu):5.5f} +- {np.std(ptwt_time_cpu):5.5f}")
103-
print(f"1d-ptwt-cpu-jit:{np.mean(ptwt_time_cpu_jit):5.5f} +- {np.std(ptwt_time_cpu_jit):5.5f}")
104-
print(f"1d-ptwt-gpu :{np.mean(ptwt_time_gpu):5.5f} +- {np.std(ptwt_time_gpu):5.5f}")
105-
print(f"1d-ptwt-gpu-jit:{np.mean(ptwt_time_gpu_jit):5.5f} +- {np.std(ptwt_time_gpu_jit):5.5f}")
102+
print(
103+
f"1d-pywt-cpu :{np.mean(pywt_time_cpu):5.5f} +- {np.std(pywt_time_cpu):5.5f}"
104+
)
105+
print(
106+
f"1d-ptwt-cpu :{np.mean(ptwt_time_cpu):5.5f} +- {np.std(ptwt_time_cpu):5.5f}"
107+
)
108+
print(
109+
f"1d-ptwt-cpu-jit:{np.mean(ptwt_time_cpu_jit):5.5f} +- {np.std(ptwt_time_cpu_jit):5.5f}"
110+
)
111+
print(
112+
f"1d-ptwt-gpu :{np.mean(ptwt_time_gpu):5.5f} +- {np.std(ptwt_time_gpu):5.5f}"
113+
)
114+
print(
115+
f"1d-ptwt-gpu-jit:{np.mean(ptwt_time_gpu_jit):5.5f} +- {np.std(ptwt_time_gpu_jit):5.5f}"
116+
)
106117
# plt.semilogy(pywt_time_cpu, label='pywt-cpu')
107118
# plt.semilogy(ptwt_time_cpu, label='ptwt-cpu')
108119
# plt.semilogy(ptwt_time_cpu_jit, label='ptwt-cpu-jit')
@@ -112,12 +123,24 @@ def _jit_wavedec_fun(data, wavelet):
112123
# plt.xlabel('repetition')
113124
# plt.ylabel('runtime [s]')
114125
# plt.show()
115-
time_stack = np.stack([pywt_time_cpu, ptwt_time_cpu, ptwt_time_cpu_jit, ptwt_time_gpu, ptwt_time_gpu_jit], -1)
126+
time_stack = np.stack(
127+
[
128+
pywt_time_cpu,
129+
ptwt_time_cpu,
130+
ptwt_time_cpu_jit,
131+
ptwt_time_gpu,
132+
ptwt_time_gpu_jit,
133+
],
134+
-1,
135+
)
116136
plt.boxplot(time_stack)
117-
plt.yscale('log')
118-
plt.xticks([1,2,3,4,5], ["pywt-cpu", "ptwt-cpu", "ptwt-cpu-jit", "ptwt-gpu", "ptwt-gpu-jit"])
137+
plt.yscale("log")
138+
plt.xticks(
139+
[1, 2, 3, 4, 5],
140+
["pywt-cpu", "ptwt-cpu", "ptwt-cpu-jit", "ptwt-gpu", "ptwt-gpu-jit"],
141+
)
119142
plt.xticks(rotation=20)
120-
plt.ylabel('runtime [s]')
121-
plt.title('DWT-1D')
122-
plt.savefig('./figs/timeitconv1d.png')
143+
plt.ylabel("runtime [s]")
144+
plt.title("DWT-1D")
145+
plt.savefig("./figs/timeitconv1d.png")
123146
# plt.show()

0 commit comments

Comments
 (0)