1- import pywt
2- import ptwt
3- import torch
4- import numpy as np
51import time
62from typing import NamedTuple
73
84import matplotlib .pyplot as plt
5+ import numpy as np
6+ import pywt
7+ import torch
8+
9+ import ptwt
10+
911
1012class WaveletTuple (NamedTuple ):
1113 """Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
@@ -24,15 +26,15 @@ def _set_up_wavelet_tuple(wavelet, dtype):
2426 torch .tensor (wavelet .rec_hi ).type (dtype ),
2527 )
2628
29+
2730def _jit_wavedec_fun (data , wavelet ):
2831 return ptwt .wavedec (data , wavelet , "periodic" , level = 10 )
2932
3033
31- if __name__ == ' __main__' :
34+ if __name__ == " __main__" :
3235 length = 1e6
3336 repetitions = 100
3437
35-
3638 pywt_time_cpu = []
3739 ptwt_time_cpu = []
3840 ptwt_time_gpu = []
@@ -56,10 +58,10 @@ def _jit_wavedec_fun(data, wavelet):
5658
5759 wavelet = _set_up_wavelet_tuple (pywt .Wavelet ("db5" ), torch .float32 )
5860 jit_wavedec = torch .jit .trace (
59- _jit_wavedec_fun ,
60- (data , wavelet ),
61- strict = False ,
62- )
61+ _jit_wavedec_fun ,
62+ (data , wavelet ),
63+ strict = False ,
64+ )
6365
6466 for _ in range (repetitions ):
6567 data = np .random .randn (32 , int (length )).astype (np .float32 )
@@ -69,7 +71,6 @@ def _jit_wavedec_fun(data, wavelet):
6971 end = time .perf_counter ()
7072 ptwt_time_cpu_jit .append (end - start )
7173
72-
7374 for _ in range (repetitions ):
7475 data = np .random .randn (32 , int (length )).astype (np .float32 )
7576 data = torch .from_numpy (data ).cuda ()
@@ -82,10 +83,10 @@ def _jit_wavedec_fun(data, wavelet):
8283
8384 wavelet = _set_up_wavelet_tuple (pywt .Wavelet ("db5" ), torch .float32 )
8485 jit_wavedec = torch .jit .trace (
85- _jit_wavedec_fun ,
86- (data .cuda (), wavelet ),
87- strict = False ,
88- )
86+ _jit_wavedec_fun ,
87+ (data .cuda (), wavelet ),
88+ strict = False ,
89+ )
8990
9091 for _ in range (repetitions ):
9192 data = np .random .randn (32 , int (length )).astype (np .float32 )
@@ -95,14 +96,24 @@ def _jit_wavedec_fun(data, wavelet):
9596 res = jit_wavedec (data , wavelet )
9697 torch .cuda .synchronize ()
9798 end = time .perf_counter ()
98- ptwt_time_gpu_jit .append (end - start )
99+ ptwt_time_gpu_jit .append (end - start )
99100
100101 print ("1d fwt results" )
101- print (f"1d-pywt-cpu :{ np .mean (pywt_time_cpu ):5.5f} +- { np .std (pywt_time_cpu ):5.5f} " )
102- print (f"1d-ptwt-cpu :{ np .mean (ptwt_time_cpu ):5.5f} +- { np .std (ptwt_time_cpu ):5.5f} " )
103- print (f"1d-ptwt-cpu-jit:{ np .mean (ptwt_time_cpu_jit ):5.5f} +- { np .std (ptwt_time_cpu_jit ):5.5f} " )
104- print (f"1d-ptwt-gpu :{ np .mean (ptwt_time_gpu ):5.5f} +- { np .std (ptwt_time_gpu ):5.5f} " )
105- print (f"1d-ptwt-gpu-jit:{ np .mean (ptwt_time_gpu_jit ):5.5f} +- { np .std (ptwt_time_gpu_jit ):5.5f} " )
102+ print (
103+ f"1d-pywt-cpu :{ np .mean (pywt_time_cpu ):5.5f} +- { np .std (pywt_time_cpu ):5.5f} "
104+ )
105+ print (
106+ f"1d-ptwt-cpu :{ np .mean (ptwt_time_cpu ):5.5f} +- { np .std (ptwt_time_cpu ):5.5f} "
107+ )
108+ print (
109+ f"1d-ptwt-cpu-jit:{ np .mean (ptwt_time_cpu_jit ):5.5f} +- { np .std (ptwt_time_cpu_jit ):5.5f} "
110+ )
111+ print (
112+ f"1d-ptwt-gpu :{ np .mean (ptwt_time_gpu ):5.5f} +- { np .std (ptwt_time_gpu ):5.5f} "
113+ )
114+ print (
115+ f"1d-ptwt-gpu-jit:{ np .mean (ptwt_time_gpu_jit ):5.5f} +- { np .std (ptwt_time_gpu_jit ):5.5f} "
116+ )
106117 # plt.semilogy(pywt_time_cpu, label='pywt-cpu')
107118 # plt.semilogy(ptwt_time_cpu, label='ptwt-cpu')
108119 # plt.semilogy(ptwt_time_cpu_jit, label='ptwt-cpu-jit')
@@ -112,12 +123,24 @@ def _jit_wavedec_fun(data, wavelet):
112123 # plt.xlabel('repetition')
113124 # plt.ylabel('runtime [s]')
114125 # plt.show()
115- time_stack = np .stack ([pywt_time_cpu , ptwt_time_cpu , ptwt_time_cpu_jit , ptwt_time_gpu , ptwt_time_gpu_jit ], - 1 )
126+ time_stack = np .stack (
127+ [
128+ pywt_time_cpu ,
129+ ptwt_time_cpu ,
130+ ptwt_time_cpu_jit ,
131+ ptwt_time_gpu ,
132+ ptwt_time_gpu_jit ,
133+ ],
134+ - 1 ,
135+ )
116136 plt .boxplot (time_stack )
117- plt .yscale ('log' )
118- plt .xticks ([1 ,2 ,3 ,4 ,5 ], ["pywt-cpu" , "ptwt-cpu" , "ptwt-cpu-jit" , "ptwt-gpu" , "ptwt-gpu-jit" ])
137+ plt .yscale ("log" )
138+ plt .xticks (
139+ [1 , 2 , 3 , 4 , 5 ],
140+ ["pywt-cpu" , "ptwt-cpu" , "ptwt-cpu-jit" , "ptwt-gpu" , "ptwt-gpu-jit" ],
141+ )
119142 plt .xticks (rotation = 20 )
120- plt .ylabel (' runtime [s]' )
121- plt .title (' DWT-1D' )
122- plt .savefig (' ./figs/timeitconv1d.png' )
143+ plt .ylabel (" runtime [s]" )
144+ plt .title (" DWT-1D" )
145+ plt .savefig (" ./figs/timeitconv1d.png" )
123146 # plt.show()
0 commit comments