|
1 | 1 | import warnings |
2 | 2 | from dataclasses import dataclass |
3 | 3 | from pathlib import Path |
4 | | -from typing import Any, Generic, Iterable, Optional, Sequence, Type, TypeVar |
| 4 | +from typing import Any, Generic, Iterable, Optional, Sequence, Type, TypeVar, Union |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import onnxruntime as ort |
8 | 8 |
|
9 | 9 | from numpy.typing import NDArray |
10 | 10 | from tokenizers import Tokenizer |
11 | 11 |
|
12 | | -from fastembed.common.types import OnnxProvider, NumpyArray |
| 12 | +from fastembed.common.types import OnnxProvider, NumpyArray, Device |
13 | 13 | from fastembed.parallel_processor import Worker |
14 | 14 |
|
15 | 15 | # Holds type of the embedding result |
@@ -60,31 +60,35 @@ def _load_onnx_model( |
60 | 60 | model_file: str, |
61 | 61 | threads: Optional[int], |
62 | 62 | providers: Optional[Sequence[OnnxProvider]] = None, |
63 | | - cuda: bool = False, |
| 63 | + cuda: Union[bool, Device] = Device.AUTO, |
64 | 64 | device_id: Optional[int] = None, |
65 | 65 | extra_session_options: Optional[dict[str, Any]] = None, |
66 | 66 | ) -> None: |
67 | 67 | model_path = model_dir / model_file |
68 | 68 | # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers |
| 69 | + available_providers = ort.get_available_providers() |
| 70 | + cuda_available = "CUDAExecutionProvider" in available_providers |
| 71 | + explicit_cuda = cuda is True or cuda == Device.CUDA |
69 | 72 |
|
70 | | - if cuda and providers is not None: |
| 73 | + if explicit_cuda and providers is not None: |
71 | 74 | warnings.warn( |
72 | | - f"`cuda` and `providers` are mutually exclusive parameters, cuda: {cuda}, providers: {providers}", |
| 75 | + f"`cuda` and `providers` are mutually exclusive parameters, " |
| 76 | + f"cuda: {cuda}, providers: {providers}. If you'd like to use providers, cuda should be one of " |
| 77 | + f"[False, Device.CPU, Device.AUTO].",
73 | 78 | category=UserWarning, |
74 | 79 | stacklevel=6, |
75 | 80 | ) |
76 | 81 |
|
77 | 82 | if providers is not None: |
78 | 83 | onnx_providers = list(providers) |
79 | | - elif cuda: |
| 84 | + elif explicit_cuda or (cuda == Device.AUTO and cuda_available): |
80 | 85 | if device_id is None: |
81 | 86 | onnx_providers = ["CUDAExecutionProvider"] |
82 | 87 | else: |
83 | 88 | onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})] |
84 | 89 | else: |
85 | 90 | onnx_providers = ["CPUExecutionProvider"] |
86 | 91 |
|
87 | | - available_providers = ort.get_available_providers() |
88 | 92 | requested_provider_names: list[str] = [] |
89 | 93 | for provider in onnx_providers: |
90 | 94 | # check providers available |
|
0 commit comments