|
1 | 1 | (ns simulflow.vad.silero |
2 | 2 | (:require |
| 3 | + [clojure.java.io :as io] |
3 | 4 | [simulflow.utils.audio :as audio] |
4 | 5 | [simulflow.vad.core :as vad] |
5 | 6 | [taoensso.telemere :as t]) |
|
10 | 11 | ;; Constants |
(def ^:private model-reset-time-ms
  "Interval in milliseconds; presumably how often the model state is reset
  (usage is outside this excerpt — TODO confirm)."
  5000)

(def ^:private supported-sample-rates
  "Sample rates (Hz) accepted when creating a VAD analyzer."
  #{8000 16000})

(def ^:private default-model-resource
  "Classpath resource name of the bundled Silero ONNX model."
  "silero_vad.onnx")
| 15 | + |
(defn- load-model-resource
  "Read a classpath resource fully into memory and return its bytes.

  Called with no arguments it reads `default-model-resource`.
  Throws an ex-info carrying `{:resource resource-name}` when the
  resource cannot be located on the classpath."
  ([]
   (load-model-resource default-model-resource))
  ([resource-name]
   (let [url (io/resource resource-name)]
     ;; Guard clause: fail early with a descriptive error
     (when-not url
       (throw (ex-info (str "Model resource not found: " resource-name)
                       {:resource resource-name})))
     (with-open [source (io/input-stream url)
                 sink (java.io.ByteArrayOutputStream.)]
       (io/copy source sink)
       (.toByteArray sink)))))
13 | 28 |
|
14 | 29 | ;; ONNX Model wrapper |
15 | 30 | (defn- create-session-options [] |
|
94 | 109 |
|
(defn- create-silero-onnx-model
  "Create a new Silero ONNX model instance.

  With no arguments the bundled `default-model-resource` is used.
  `model-path` may be:
    - a path to an existing file on disk, handed directly to ONNX Runtime;
    - a classpath resource name, read into memory as a byte array;
    - nil, which falls back to `default-model-resource`.

  Returns a `SileroOnnxModel`. The recurrent state, the trailing audio
  context and the last sample rate / batch size are kept in an atom
  between `call-model` invocations.

  Any exception raised while building the session (including the ex-info
  from `load-model-resource` for a missing resource) is logged and
  rethrown."
  ([]
   (create-silero-onnx-model default-model-resource))
  ([model-path]
   (try
     (let [env (OrtEnvironment/getEnvironment)
           opts (create-session-options)
           session (if (and model-path (.exists (io/file model-path)))
                     ;; An existing file on disk takes precedence
                     (.createSession env model-path opts)
                     ;; Otherwise load the model bytes from the classpath
                     (let [model-bytes (load-model-resource (or model-path default-model-resource))]
                       (.createSession env model-bytes opts)))
           state (atom {:state (make-array Float/TYPE 2 1 128)
                        :context (make-array Float/TYPE 0 0)
                        :last-sr 0
                        :last-batch-size 0})]
       (reify SileroOnnxModel
         (reset-states! [this]
           (reset-states! this 1))
         (reset-states! [_ batch-size]
           (reset! state {:state (make-array Float/TYPE 2 batch-size 128)
                          :context (make-array Float/TYPE batch-size 0)
                          :last-sr 0
                          :last-batch-size 0}))
         (call-model [this x sr]
           (let [{:keys [x sr]} (validate-input x sr)
                 number-samples (if (= sr 16000) 512 256)
                 batch-size (count x)
                 context-size (if (= sr 16000) 64 32)]
             ;; Validate exact sample count
             (when (not= (count (first x)) number-samples)
               (throw (IllegalArgumentException.
                       (str "Provided number of samples is " (count (first x))
                            " (Supported values: 256 for 8000 sample rate, 512 for 16000)"))))
             ;; Reset on first use, or when sample rate / batch size changed
             (let [current-state @state
                   need-reset? (or (zero? (:last-batch-size current-state))
                                   (and (not= (:last-sr current-state) 0)
                                        (not= (:last-sr current-state) sr))
                                   (and (not= (:last-batch-size current-state) 0)
                                        (not= (:last-batch-size current-state) batch-size)))]
               (when need-reset?
                 (reset-states! this batch-size))
               ;; Re-read state after potential reset
               (let [current-state @state
                     stored-context (:context current-state)
                     ;; A freshly reset context is a `batch-size x 0` array, so
                     ;; its outer length is non-zero. Test the column count
                     ;; instead, otherwise the zero padding is never applied
                     ;; and the first frame reaches the model without context.
                     context (if (zero? (count (first stored-context)))
                               (make-array Float/TYPE batch-size context-size)
                               stored-context)
                     x-with-context (concatenate-arrays context (to-float-array-2d x))
                     env (OrtEnvironment/getEnvironment)]
                 (with-open [input-tensor (OnnxTensor/createTensor env x-with-context)
                             state-tensor (OnnxTensor/createTensor env (:state current-state))
                             sr-tensor (OnnxTensor/createTensor env (long-array [sr]))]
                   (let [inputs (doto (HashMap.)
                                  (.put "input" input-tensor)
                                  (.put "state" state-tensor)
                                  (.put "sr" sr-tensor))]
                     (with-open [outputs (.run session inputs)]
                       (let [output (.getValue (.get outputs 0))
                             new-state (.getValue (.get outputs 1))
                             ;; Keep the trailing columns as context for the next frame
                             new-context (get-last-columns x-with-context context-size)]
                         (swap! state assoc
                                :state new-state
                                :context new-context
                                :last-sr sr
                                :last-batch-size batch-size)
                         ;; Return confidence for first batch item
                         (aget output 0 0)))))))))
         (close-model! [_]
           (try
             (try
               (.close session)
               (finally
                 ;; Session options hold native resources and may only be
                 ;; released once the session that used them is closed.
                 (.close opts)))
             (catch Exception e
               (t/log! {:level :warn
                        :id :silero-onnx
                        :error e} "Error closing Silero ONNX model"))))))
     (catch Exception e
       (t/log! {:level :error
                :error e} "Failed to create Silero ONNX model")
       (throw e)))))
169 | 191 |
|
170 | 192 | (comment |
171 | 193 |
|
172 | 194 | (defn test-model [] |
173 | | - (let [model (create-silero-onnx-model "resources/silero_vad.onnx") |
| 195 | + (let [model (create-silero-onnx-model nil) ; Use bundled resource |
174 | 196 | ;; Create 512 samples of silence |
175 | 197 | silence (float-array 512 0.0) |
176 | 198 | ;; Create 512 samples of noise |
|
223 | 245 | "Create a Silero VAD analyzer instance with audio accumulation. |
224 | 246 | |
225 | 247 | Options: |
226 | | - - :model-path - Path to the silero_vad.onnx model file |
| 248 | + - :model-path - Path to a custom silero_vad.onnx model file or java resource (default: uses bundled model) |
227 | 249 | - :sample-rate - Audio sample rate (8000 or 16000 Hz, default: 16000) |
228 | 250 | - :vad/min-confidence - Threshold for voice activity detection (default: 0.7) |
229 | 251 | - :vad/min-speech-duration-ms - Minimum speech duration in ms (default: 200) |
|
232 | 254 | ([{:keys [model-path sample-rate] |
233 | 255 | :vad/keys [min-confidence min-speech-duration-ms min-silence-duration-ms] |
234 | 256 | :or {sample-rate 16000 |
235 | | - model-path "resources/silero_vad.onnx" |
236 | 257 | min-speech-duration-ms (:vad/min-speech-duration-ms vad/default-params) |
237 | 258 | min-silence-duration-ms (:vad/min-silence-duration-ms vad/default-params) |
238 | 259 | min-confidence (:vad/min-confidence vad/default-params)}}] |
239 | 260 | (when-not (contains? supported-sample-rates sample-rate) |
240 | 261 | (throw (IllegalArgumentException. |
241 | 262 | (str "Sampling rate not supported, only available for " supported-sample-rates)))) |
242 | | - (let [model (create-silero-onnx-model (or model-path "resources/silero_vad.onnx")) |
| 263 | + (let [model (create-silero-onnx-model model-path) |
243 | 264 | frames-required (if (= sample-rate 16000) 512 256) |
244 | 265 | bytes-per-frame 2 ; 16-bit PCM |
245 | 266 | bytes-required (* frames-required bytes-per-frame) |
|
305 | 326 | 0.0))) |
306 | 327 | (cleanup [_] |
307 | 328 | (close-model! model)))))) |
| 329 | + |
(comment
  ;; REPL scratch: build an analyzer without passing any options.
  (def vad (create-silero-vad))

  ;; Release the analyzer's resources once finished.
  (vad/cleanup vad))
0 commit comments