-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathvad.py
More file actions
53 lines (41 loc) · 1.44 KB
/
vad.py
File metadata and controls
53 lines (41 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import functools
import numpy as np
@functools.lru_cache
def get_vad_model(model_dir="pretrained_models"):
    """Return a memoized ``SileroVADModel`` loaded from *model_dir*.

    The model file is expected at ``<model_dir>/silero_vad.onnx``. Because
    of ``functools.lru_cache``, calling this again with the same argument
    returns the already-built instance instead of creating a new one.
    """
    # The bundled ONNX file is currently the Silero VAD v5 model.
    onnx_path = f"{model_dir}/silero_vad.onnx"
    return SileroVADModel(onnx_path)
class SileroVADModel:
    """Callable wrapper around an ONNX Runtime session for a Silero VAD model.

    The inference session is created once in ``__init__``. Calling the
    instance runs one inference step on an audio chunk together with a
    state array and returns the model output plus the updated state.
    """

    def __init__(self, path):
        """Create a CPU-only inference session for the ONNX model at *path*.

        Raises:
            RuntimeError: If the ``onnxruntime`` package is not installed.
        """
        # Imported here rather than at module level so that importing this
        # module does not itself require onnxruntime.
        try:
            import onnxruntime
        except ImportError as e:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from e

        opts = onnxruntime.SessionOptions()
        # Restrict the session to a single thread for both inter-op and
        # intra-op parallelism.
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        # Severity level 4 is "fatal" in ONNX Runtime, so lower-severity
        # log messages are suppressed.
        opts.log_severity_level = 4

        self.session = onnxruntime.InferenceSession(
            path,
            providers=["CPUExecutionProvider"],
            sess_options=opts,
        )

    def get_initial_state(self, batch_size: int):
        """Return an all-zero float32 state of shape ``(2, batch_size, 128)``."""
        return np.zeros((2, batch_size, 128), dtype=np.float32)

    def __call__(self, x, state, sr: int):
        """Run the model on one audio chunk.

        Args:
            x: NumPy array of audio samples. A 1-D array is treated as a
                single chunk and promoted to shape ``(1, n)``; a 2-D array
                is used as-is, with axis 1 taken as the chunk length.
            state: Array fed to the model's ``"state"`` input; use
                ``get_initial_state`` for the first chunk and the returned
                state for subsequent ones.
            sr: Sample rate of ``x``; passed to the model as an int64 scalar.

        Returns:
            A ``(out, state)`` tuple holding the two outputs of the session
            run: the model output for this chunk (presumably speech
            probabilities; confirm against the model) and the new state.

        Raises:
            ValueError: If ``x`` has more than two dimensions, or if the
                chunk is empty or shorter than ``1 / 31.25`` s (32 ms) at
                sample rate ``sr``.
        """
        if x.ndim == 1:
            x = np.expand_dims(x, 0)
        if x.ndim > 2:
            raise ValueError(
                f"Too many dimensions for input audio chunk {x.ndim}"
            )

        num_samples = x.shape[1]
        # Test for an empty chunk before dividing: otherwise a zero-length
        # input would surface as ZeroDivisionError rather than the intended
        # "too short" ValueError.
        if num_samples == 0 or sr / num_samples > 31.25:
            raise ValueError("Input audio chunk is too short")

        ort_inputs = {
            "input": x,
            "state": state,
            "sr": np.array(sr, dtype="int64"),
        }
        out, state = self.session.run(None, ort_inputs)
        return out, state