Skip to content

Commit 82bec7d

Browse files
rillianjmvalin
authored andcommitted
Remove trailing whitespace from the dnn torch modules
This is general best practice, but we also have a failing github action complaining about these new files. Signed-off-by: Jean-Marc Valin <[email protected]>
1 parent b3ed2bb commit 82bec7d

File tree

3 files changed

+22
-27
lines changed

3 files changed

+22
-27
lines changed

dnn/torch/osce/losses/td_lowpass.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,23 @@
77
class TDLowpass(torch.nn.Module):
88
def __init__(self, numtaps, cutoff, power=2):
99
super().__init__()
10-
10+
1111
self.b = scipy.signal.firwin(numtaps, cutoff)
1212
self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
1313
self.power = power
14-
14+
1515
def forward(self, y_true, y_pred):
16-
16+
1717
assert len(y_true.shape) == 3 and len(y_pred.shape) == 3
18-
18+
1919
diff = y_true - y_pred
2020
diff_lp = torch.nn.functional.conv1d(diff, self.weight)
21-
21+
2222
loss = torch.mean(torch.abs(diff_lp ** self.power))
23-
23+
2424
return loss, diff_lp
25-
25+
2626
def get_freqz(self):
2727
freq, response = scipy.signal.freqz(self.b)
28-
28+
2929
return freq, response
30-
31-
32-
33-
34-

dnn/torch/osce/silk_16_to_48.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212

1313
if __name__ == "__main__":
1414
args = parser.parse_args()
15-
15+
1616
fs, x = wavfile.read(args.input)
1717

1818
# being lazy for now
1919
assert fs == 16000 and x.dtype == np.int16
20-
20+
2121
x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)
22-
22+
2323
upsampler = SilkUpsampler()
2424
y = upsampler(x)
25-
25+
2626
y = y.squeeze().numpy().astype(np.int16)
27-
28-
wavfile.write(args.output, 48000, y[13:])
27+
28+
wavfile.write(args.output, 48000, y[13:])

dnn/torch/osce/utils/layers/fir.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,20 @@
88
class FIR(nn.Module):
99
def __init__(self, numtaps, bands, desired, fs=2):
1010
super().__init__()
11-
11+
1212
if numtaps % 2 == 0:
1313
print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
1414
numtaps += 1
15-
15+
1616
a = scipy.signal.firls(numtaps, bands, desired, fs=fs)
17-
17+
1818
self.weight = torch.from_numpy(a.astype(np.float32))
19-
19+
2020
def forward(self, x):
2121
num_channels = x.size(1)
22-
22+
2323
weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)
24-
24+
2525
y = F.conv1d(x, weight, groups=num_channels)
26-
27-
return y
26+
27+
return y

0 commit comments

Comments
 (0)