Skip to content

Commit 14d6fd9

Browse files
committed
make construct_2d_filt private.
1 parent 1ac84d6 commit 14d6fd9

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/ptwt/conv_transform_2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929

3030

31-
def construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
31+
def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
3232
"""Construct two-dimensional filters using outer products.
3333
3434
Args:
@@ -209,7 +209,7 @@ def wavedec2(
209209
dec_lo, dec_hi, _, _ = _get_filter_tensors(
210210
wavelet, flip=True, device=data.device, dtype=data.dtype
211211
)
212-
dec_filt = construct_2d_filt(lo=dec_lo, hi=dec_hi)
212+
dec_filt = _construct_2d_filt(lo=dec_lo, hi=dec_hi)
213213

214214
if level is None:
215215
level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)
@@ -299,7 +299,7 @@ def waverec2(
299299
wavelet, flip=False, device=torch_device, dtype=torch_dtype
300300
)
301301
filt_len = rec_lo.shape[-1]
302-
rec_filt = construct_2d_filt(lo=rec_lo, hi=rec_hi)
302+
rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)
303303

304304
for c_pos, coeff_tuple in enumerate(coeffs[1:]):
305305
if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:

src/ptwt/matmul_transform_2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
_check_if_tensor,
2323
_wavedec2d_unfold_channels_2d_list,
2424
_waverec2d_fold_channels_2d_list,
25-
construct_2d_filt,
25+
_construct_2d_filt,
2626
)
2727
from .matmul_transform import construct_boundary_a, construct_boundary_s, orthogonalize
2828
from .sparse_math import (
@@ -67,7 +67,7 @@ def _construct_a_2(
6767
dec_lo, dec_hi, _, _ = _get_filter_tensors(
6868
wavelet, flip=False, device=device, dtype=dtype
6969
)
70-
dec_filt = construct_2d_filt(lo=dec_lo, hi=dec_hi)
70+
dec_filt = _construct_2d_filt(lo=dec_lo, hi=dec_hi)
7171
ll, lh, hl, hh = dec_filt.squeeze(1)
7272
analysis_ll = construct_strided_conv2d_matrix(ll, height, width, mode=mode)
7373
analysis_lh = construct_strided_conv2d_matrix(lh, height, width, mode=mode)
@@ -113,7 +113,7 @@ def _construct_s_2(
113113
_, _, rec_lo, rec_hi = _get_filter_tensors(
114114
wavelet, flip=True, device=device, dtype=dtype
115115
)
116-
dec_filt = construct_2d_filt(lo=rec_lo, hi=rec_hi)
116+
dec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)
117117
ll, lh, hl, hh = dec_filt.squeeze(1)
118118
synthesis_ll = construct_strided_conv2d_matrix(ll, height, width, mode=mode)
119119
synthesis_lh = construct_strided_conv2d_matrix(lh, height, width, mode=mode)

0 commit comments

Comments
 (0)