Skip to content

Commit 0a551bf

Browse files
committed
last change prolly, rm magic numbers here too
1 parent 9e66c64 commit 0a551bf

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

basic_pitch/layers/nnaudio.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,26 @@
2626

2727
import scipy.signal
2828

29+
DEFAULT_BAND_CENTER = 0.5
30+
DEFAULT_KERNEL_LENGTH = 256
31+
DEFAULT_TRANSITION_BANDWIDTH = 0.03
32+
DEFAULT_DTYPE = tf.float32
33+
DEFAULT_WINDOW_BANDWIDTH = 1.5
34+
DEFAULT_CQT_HOP_LENGTH = 512
35+
DEFAULT_CQT_FMIN = 32.70
36+
DEFAULT_CQT_N_BINS = 84
37+
DEFAULT_CQT_BINS_PER_OCTAVE = 12
38+
DEFAULT_CQT_BASIS_NORM = 1
39+
DEFAULT_CQT_WINDOW = "hann"
40+
DEFAULT_CQT_PAD_MODE = "reflect"
41+
DEFAULT_CQT_OUTPUT_FORMAT = "Magnitude"
42+
DEFAULT_LOW_PASS_TRANSITION_BANDWIDTH = 0.001
2943

3044
def create_lowpass_filter(
31-
band_center: float = 0.5,
32-
kernel_length: int = 256,
33-
transition_bandwidth: float = 0.03,
34-
dtype: tf.dtypes.DType = tf.float32,
45+
band_center: float = DEFAULT_BAND_CENTER,
46+
kernel_length: int = DEFAULT_KERNEL_LENGTH,
47+
transition_bandwidth: float = DEFAULT_TRANSITION_BANDWIDTH,
48+
dtype: tf.dtypes.DType = DEFAULT_DTYPE,
3549
) -> np.ndarray:
3650
"""
3751
Calculate the highest frequency we need to preserve and the lowest frequency we allow
@@ -106,15 +120,15 @@ def get_early_downsample_params(
106120
) -> Tuple[Union[float, int], int, float, np.array, bool]:
107121
"""Compute downsampling parameters used for early downsampling"""
108122

109-
window_bandwidth = 1.5 # for hann window
123+
window_bandwidth = DEFAULT_WINDOW_BANDWIDTH # for hann window
110124
filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q)
111125
sr, hop_length, downsample_factor = early_downsample(sr, hop_length, n_octaves, sr // 2, filter_cutoff)
112126
if downsample_factor != 1:
113127
earlydownsample = True
114128
early_downsample_filter = create_lowpass_filter(
115129
band_center=1 / downsample_factor,
116-
kernel_length=256,
117-
transition_bandwidth=0.03,
130+
kernel_length=DEFAULT_KERNEL_LENGTH,
131+
transition_bandwidth=DEFAULT_TRANSITION_BANDWIDTH,
118132
dtype=dtype,
119133
)
120134
else:
@@ -455,19 +469,19 @@ class CQT2010v2(tf.keras.layers.Layer):
455469
def __init__(
456470
self,
457471
sr: int = 22050,
458-
hop_length: int = 512,
459-
fmin: float = 32.70,
472+
hop_length: int = DEFAULT_CQT_HOP_LENGTH,
473+
fmin: float = DEFAULT_CQT_FMIN,
460474
fmax: Optional[float] = None,
461-
n_bins: int = 84,
475+
n_bins: int = DEFAULT_CQT_N_BINS,
462476
filter_scale: int = 1,
463-
bins_per_octave: int = 12,
477+
bins_per_octave: int = DEFAULT_CQT_BINS_PER_OCTAVE,
464478
norm: bool = True,
465-
basis_norm: int = 1,
466-
window: str = "hann",
467-
pad_mode: str = "reflect",
479+
basis_norm: int = DEFAULT_CQT_BASIS_NORM,
480+
window: str = DEFAULT_CQT_WINDOW,
481+
pad_mode: str = DEFAULT_CQT_PAD_MODE,
468482
earlydownsample: bool = True,
469483
trainable: bool = False,
470-
output_format: str = "Magnitude",
484+
output_format: str = DEFAULT_CQT_OUTPUT_FORMAT,
471485
match_torch_exactly: bool = True,
472486
):
473487
super().__init__()
@@ -516,7 +530,11 @@ def build(self, input_shape: tf.TensorShape) -> None:
516530
# This will be used to calculate filter_cutoff and creating CQT kernels
517531
Q = float(self.filter_scale) / (2 ** (1 / self.bins_per_octave) - 1)
518532

519-
self.lowpass_filter = create_lowpass_filter(band_center=0.5, kernel_length=256, transition_bandwidth=0.001)
533+
self.lowpass_filter = create_lowpass_filter(
534+
band_center=DEFAULT_BAND_CENTER,
535+
kernel_length=DEFAULT_KERNEL_LENGTH,
536+
transition_bandwidth=DEFAULT_LOW_PASS_TRANSITION_BANDWIDTH,
537+
)
520538

521539
# Calculate num of filter requires for the kernel
522540
# n_octaves determines how many resampling requires for the CQT

0 commit comments

Comments
 (0)