Skip to content

Commit b7e349d

Browse files
committed
use cuda with onnx if available
1 parent 9991303 commit b7e349d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

basic_pitch/inference.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ def __init__(self, model_path: Union[pathlib.Path, str]):
129129
present.append("ONNX")
130130
try:
131131
self.model_type = Model.MODEL_TYPES.ONNX
132-
self.model = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
132+
providers = ["CPUExecutionProvider"]
133+
if "CUDAExecutionProvider" in ort.get_available_providers():
134+
providers.insert(0, "CUDAExecutionProvider")
135+
self.model = ort.InferenceSession(str(model_path), providers=providers)
133136
return
134137
except Exception as e:
135138
if str(model_path).endswith(".onnx"):

0 commit comments

Comments
 (0)