|
22 | 22 | _check_if_tensor, |
23 | 23 | _wavedec2d_unfold_channels_2d_list, |
24 | 24 | _waverec2d_fold_channels_2d_list, |
25 | | - construct_2d_filt, |
| 25 | + _construct_2d_filt, |
26 | 26 | ) |
27 | 27 | from .matmul_transform import construct_boundary_a, construct_boundary_s, orthogonalize |
28 | 28 | from .sparse_math import ( |
@@ -67,7 +67,7 @@ def _construct_a_2( |
67 | 67 | dec_lo, dec_hi, _, _ = _get_filter_tensors( |
68 | 68 | wavelet, flip=False, device=device, dtype=dtype |
69 | 69 | ) |
70 | | - dec_filt = construct_2d_filt(lo=dec_lo, hi=dec_hi) |
| 70 | + dec_filt = _construct_2d_filt(lo=dec_lo, hi=dec_hi) |
71 | 71 | ll, lh, hl, hh = dec_filt.squeeze(1) |
72 | 72 | analysis_ll = construct_strided_conv2d_matrix(ll, height, width, mode=mode) |
73 | 73 | analysis_lh = construct_strided_conv2d_matrix(lh, height, width, mode=mode) |
@@ -113,7 +113,7 @@ def _construct_s_2( |
113 | 113 | _, _, rec_lo, rec_hi = _get_filter_tensors( |
114 | 114 | wavelet, flip=True, device=device, dtype=dtype |
115 | 115 | ) |
116 | | - dec_filt = construct_2d_filt(lo=rec_lo, hi=rec_hi) |
| 116 | + dec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi) |
117 | 117 | ll, lh, hl, hh = dec_filt.squeeze(1) |
118 | 118 | synthesis_ll = construct_strided_conv2d_matrix(ll, height, width, mode=mode) |
119 | 119 | synthesis_lh = construct_strided_conv2d_matrix(lh, height, width, mode=mode) |
|
0 commit comments