Skip to content

Commit cf9793d

Browse files
committed
labels for the frequency axis.
1 parent fcc4845 commit cf9793d

File tree

3 files changed

+36
-6
lines changed

3 files changed

+36
-6
lines changed
-17.4 KB
Loading

examples/wavelet_packet_chirp_analysis/chirp_analysis.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,36 @@
1111
t = np.linspace(0, 2, int(2//(1/fs)))
1212
w = np.sin(256*np.pi*t**2)
1313

14-
wavelet = pywt.Wavelet("sym6")
14+
wavelet = pywt.Wavelet("sym8")
1515
wp = WaveletPacket(
1616
data=torch.from_numpy(w.astype(np.float32)), wavelet=wavelet, mode="boundary"
1717
)
18-
nodes = wp.get_level(5)
18+
level = 5
19+
nodes = wp.get_level(level)
1920
np_lst = []
2021
for node in nodes:
2122
np_lst.append(wp[node])
2223
viz = np.stack(np_lst).squeeze()
2324

25+
n = list(range(int(np.power(2, level))))
26+
freqs = (fs/2)*(n/(np.power(2, level)))
27+
28+
xticks = list(range(viz.shape[-1]))[::6]
29+
xlabels = np.round(np.linspace(min(t), max(t), viz.shape[-1]), 2)[::6]
30+
2431
fig, axs = plt.subplots(2)
2532
axs[0].plot(t, w)
2633
axs[0].set_title("Analyzed signal")
27-
axs[0].set_xlabel("t [s]")
34+
axs[0].set_xlabel("time [s]")
35+
axs[0].set_ylabel("magnitude")
2836

29-
axs[1].set_title("Wavelet analysis")
37+
axs[1].set_title("Wavelet packet analysis")
3038
axs[1].imshow(np.abs(viz))
31-
axs[1].set_xlabel("time")
32-
axs[1].set_ylabel("frequency")
39+
axs[1].set_xlabel("time [s]")
40+
axs[1].set_xticks(xticks)
41+
axs[1].set_xticklabels(xlabels)
42+
axs[1].set_ylabel("frequency [Hz]")
43+
axs[1].set_yticks(n[::4])
44+
axs[1].set_yticklabels(freqs[::4])
3345
axs[1].invert_yaxis()
3446
plt.show()

src/ptwt/packets.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pywt
1010
import torch
11+
import numpy as np
1112

1213
from ._util import Wavelet, _as_wavelet
1314
from .conv_transform import wavedec, waverec
@@ -22,6 +23,23 @@
2223
BaseDict = collections.UserDict
2324

2425

26+
def _wpfreq(fs: float, level: int) -> np.ndarray:
27+
"""Compute the frequencies for a fully decomposed single dimensional
28+
packet tree. The packet transform linearly subdivides all frequencies
29+
from zero up to the Nyquist frequency.
30+
31+
Args:
32+
fs (float): The sampling frequency
33+
level (int): The decomposition level
34+
35+
Returns:
36+
np.ndarray: The frequency bins of the packets in frequency order.
37+
"""
38+
n = list(range(int(np.power(2., level))))
39+
freqs = (fs/2.)*(n/(np.power(2., level)))
40+
return freqs
41+
42+
2543
class WaveletPacket(BaseDict):
2644
"""One dimensional wavelet packets."""
2745

0 commit comments

Comments
 (0)