@@ -26,18 +26,18 @@ def _compare_trees1(
2626
2727 if transform_mode :
2828 twp = WaveletPacket (
29- None , wavelet , mode = ptwt_boundary , max_level = max_lev
30- ).transform (torch .from_numpy (data ), max_level = max_lev )
29+ None , wavelet , mode = ptwt_boundary , maxlevel = max_lev
30+ ).transform (torch .from_numpy (data ), maxlevel = max_lev )
3131 else :
3232 twp = WaveletPacket (
33- torch .from_numpy (data ), wavelet , mode = ptwt_boundary , max_level = max_lev
33+ torch .from_numpy (data ), wavelet , mode = ptwt_boundary , maxlevel = max_lev
3434 )
3535
3636 # if multiple_transform flag is set, recalculcate the packets
3737 if multiple_transforms :
38- twp .transform (torch .from_numpy (data ), max_level = max_lev )
38+ twp .transform (torch .from_numpy (data ), maxlevel = max_lev )
3939
40- nodes = twp .get_level (twp .max_level )
40+ nodes = twp .get_level (twp .maxlevel )
4141 twp_lst = []
4242 for node in nodes :
4343 twp_lst .append (twp [node ])
@@ -58,7 +58,7 @@ def _compare_trees1(
5858 np_res = np .concatenate (np_lst , - 1 )
5959 np_batches .append (np_res )
6060 np_batches = np .stack (np_batches , 0 )
61- assert wp .maxlevel == twp .max_level
61+ assert wp .maxlevel == twp .maxlevel
6262 assert np .allclose (torch_res , np_batches )
6363
6464
@@ -101,22 +101,22 @@ def _compare_trees2(
101101 if transform_mode :
102102 ptwt_wp_tree = WaveletPacket2D (
103103 None , wavelet = wavelet , mode = ptwt_boundary
104- ).transform (pt_data , max_level = max_lev )
104+ ).transform (pt_data , maxlevel = max_lev )
105105 else :
106106 ptwt_wp_tree = WaveletPacket2D (
107- pt_data , wavelet = wavelet , mode = ptwt_boundary , max_level = max_lev
107+ pt_data , wavelet = wavelet , mode = ptwt_boundary , maxlevel = max_lev
108108 )
109109
110110 # if multiple_transform flag is set, recalculcate the packets
111111 if multiple_transforms :
112- ptwt_wp_tree .transform (pt_data , max_level = max_lev )
112+ ptwt_wp_tree .transform (pt_data , maxlevel = max_lev )
113113
114114 packets = []
115115 for node in wp_keys :
116116 packet = ptwt_wp_tree ["" .join (node )]
117117 packets .append (packet )
118118 packets_pt = torch .stack (packets , 1 ).numpy ()
119- assert wp_tree .maxlevel == ptwt_wp_tree .max_level
119+ assert wp_tree .maxlevel == ptwt_wp_tree .maxlevel
120120 assert np .allclose (packets_pt , batch_np_packets )
121121
122122
@@ -292,11 +292,12 @@ def test_access_errors_2d():
292292
293293@pytest .mark .parametrize ("level" , [1 , 2 , 3 ])
294294@pytest .mark .parametrize ("base_key" , ["a" , "d" ])
295- @pytest .mark .parametrize ("length" , [64 , 128 ])
295+ @pytest .mark .parametrize ("length" , [63 , 64 , 128 ])
296+ @pytest .mark .parametrize ("batch_size" , [1 , 2 ])
296297@pytest .mark .parametrize ("wavelet" , ["db1" , "db2" , "sym4" ])
297- def test_inverse_packet_1d (level , base_key , length , wavelet ):
298+ def test_inverse_packet_1d (level , base_key , length , batch_size , wavelet ):
298299 """Test the 1d reconstruction code."""
299- signal = np .random .randn (1 , length )
300+ signal = np .random .randn (batch_size , length )
300301 mode = "reflect"
301302 wp = pywt .WaveletPacket (signal , wavelet , mode = mode , maxlevel = level )
302303 ptwp = WaveletPacket (torch .from_numpy (signal ), wavelet , mode = mode , maxlevel = level )
@@ -309,22 +310,22 @@ def test_inverse_packet_1d(level, base_key, length, wavelet):
309310
310311@pytest .mark .parametrize ("level" , [1 , 3 ])
311312@pytest .mark .parametrize ("base_key" , ["a" , "h" , "d" ])
312- @pytest .mark .parametrize ("size" , [(32 , 32 ), (32 , 64 )])
313+ @pytest .mark .parametrize ("size" , [(1 , 32 , 32 ), (2 , 31 , 64 )])
313314@pytest .mark .parametrize ("wavelet" , ["db1" , "db2" , "sym4" ])
314315def test_inverse_packet_2d (level , base_key , size , wavelet ):
315316 """Test the 2d reconstruction code."""
316- signal = np .random .randn (1 , size [0 ], size [1 ])
317+ signal = np .random .randn (size [0 ], size [1 ], size [ 2 ])
317318 mode = "reflect"
318319 wp = pywt .WaveletPacket2D (signal , wavelet , mode = mode , maxlevel = level )
319320 ptwp = WaveletPacket2D (torch .from_numpy (signal ), wavelet , mode = mode , maxlevel = level )
320321 wp [base_key * level ].data *= 0
321322 ptwp [base_key * level ].data *= 0
322323 wp .reconstruct (update = True )
323324 ptwp .reconstruct ()
324- assert np .allclose (wp ["" ].data , ptwp ["" ].numpy ()[:, : size [0 ], : size [1 ]])
325+ assert np .allclose (wp ["" ].data , ptwp ["" ].numpy ()[:, : size [1 ], : size [2 ]])
325326
326327
327- def test_boundary_packet_1d ():
328+ def test_inverse_boundary_packet_1d ():
328329 """Test the 2d boundary reconstruction code."""
329330 signal = np .random .randn (1 , 16 )
330331 wp = pywt .WaveletPacket (signal , "haar" , mode = "zero" , maxlevel = 2 )
@@ -336,7 +337,7 @@ def test_boundary_packet_1d():
336337 assert np .allclose (wp ["" ].data , ptwp ["" ].numpy ()[:, :16 ])
337338
338339
339- def test_boundary_packet_2d ():
340+ def test_inverse_boundary_packet_2d ():
340341 """Test the 2d boundary reconstruction code."""
341342 size = (16 , 16 )
342343 level = 2
0 commit comments