Skip to content

Commit 0dc559f

Browse files
author
Jan Buethe
committed
added some bwe-related stuff
1 parent 5667867 commit 0dc559f

File tree

3 files changed

+89
-0
lines changed

3 files changed

+89
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
import scipy.signal
3+
4+
5+
from utils.layers.fir import FIR
6+
7+
class TDLowpass(torch.nn.Module):
8+
def __init__(self, numtaps, cutoff, power=2):
9+
super().__init__()
10+
11+
self.b = scipy.signal.firwin(numtaps, cutoff)
12+
self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
13+
self.power = power
14+
15+
def forward(self, y_true, y_pred):
16+
17+
assert len(y_true.shape) == 3 and len(y_pred.shape) == 3
18+
19+
diff = y_true - y_pred
20+
diff_lp = torch.nn.functional.conv1d(diff, self.weight)
21+
22+
loss = torch.mean(torch.abs(diff_lp ** self.power))
23+
24+
return loss, diff_lp
25+
26+
def get_freqz(self):
27+
freq, response = scipy.signal.freqz(self.b)
28+
29+
return freq, response
30+
31+
32+
33+
34+

dnn/torch/osce/silk_16_to_48.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import argparse
2+
3+
from scipy.io import wavfile
4+
import torch
5+
import numpy as np
6+
7+
from utils.layers.silk_upsampler import SilkUpsampler
8+
9+
parser = argparse.ArgumentParser()
10+
parser.add_argument("input", type=str, help="input wave file")
11+
parser.add_argument("output", type=str, help="output wave file")
12+
13+
if __name__ == "__main__":
14+
args = parser.parse_args()
15+
16+
fs, x = wavfile.read(args.input)
17+
18+
# being lazy for now
19+
assert fs == 16000 and x.dtype == np.int16
20+
21+
x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)
22+
23+
upsampler = SilkUpsampler()
24+
y = upsampler(x)
25+
26+
y = y.squeeze().numpy().astype(np.int16)
27+
28+
wavfile.write(args.output, 48000, y[13:])

dnn/torch/osce/utils/layers/fir.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
import scipy.signal
3+
import torch
4+
from torch import nn
5+
import torch.nn.functional as F
6+
7+
8+
class FIR(nn.Module):
9+
def __init__(self, numtaps, bands, desired, fs=2):
10+
super().__init__()
11+
12+
if numtaps % 2 == 0:
13+
print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
14+
numtaps += 1
15+
16+
a = scipy.signal.firls(numtaps, bands, desired, fs=fs)
17+
18+
self.weight = torch.from_numpy(a.astype(np.float32))
19+
20+
def forward(self, x):
21+
num_channels = x.size(1)
22+
23+
weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)
24+
25+
y = F.conv1d(x, weight, groups=num_channels)
26+
27+
return y

0 commit comments

Comments
 (0)