Skip to content

Commit 685fd9b

Browse files
authored
new: use cuda if available (#537)
* new: use cuda if available * fix: fix warning msg * fix: add missing import
1 parent b304a2a commit 685fd9b

23 files changed: +98 −71 lines changed

fastembed/common/onnx_model.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from numpy.typing import NDArray
1010
from tokenizers import Tokenizer
1111

12-
from fastembed.common.types import OnnxProvider, NumpyArray
12+
from fastembed.common.types import OnnxProvider, NumpyArray, Device
1313
from fastembed.parallel_processor import Worker
1414

1515
# Holds type of the embedding result
@@ -60,31 +60,35 @@ def _load_onnx_model(
6060
model_file: str,
6161
threads: int | None,
6262
providers: Sequence[OnnxProvider] | None = None,
63-
cuda: bool = False,
63+
cuda: bool | Device = Device.AUTO,
6464
device_id: int | None = None,
6565
extra_session_options: dict[str, Any] | None = None,
6666
) -> None:
6767
model_path = model_dir / model_file
6868
# List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
69+
available_providers = ort.get_available_providers()
70+
cuda_available = "CUDAExecutionProvider" in available_providers
71+
explicit_cuda = cuda is True or cuda == Device.CUDA
6972

70-
if cuda and providers is not None:
73+
if explicit_cuda and providers is not None:
7174
warnings.warn(
72-
f"`cuda` and `providers` are mutually exclusive parameters, cuda: {cuda}, providers: {providers}",
75+
f"`cuda` and `providers` are mutually exclusive parameters, "
76+
f"cuda: {cuda}, providers: {providers}. If you'd like to use providers, cuda should be one of "
77+
f"[False, Device.CPU, Device.AUTO].",
7378
category=UserWarning,
7479
stacklevel=6,
7580
)
7681

7782
if providers is not None:
7883
onnx_providers = list(providers)
79-
elif cuda:
84+
elif explicit_cuda or (cuda == Device.AUTO and cuda_available):
8085
if device_id is None:
8186
onnx_providers = ["CUDAExecutionProvider"]
8287
else:
8388
onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})]
8489
else:
8590
onnx_providers = ["CPUExecutionProvider"]
8691

87-
available_providers = ort.get_available_providers()
8892
requested_provider_names: list[str] = []
8993
for provider in onnx_providers:
9094
# check providers available

fastembed/common/types.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1+
from enum import Enum
12
from pathlib import Path
2-
33
from typing import Any, TypeAlias
4+
45
import numpy as np
56
from numpy.typing import NDArray
67
from PIL import Image
78

89

10+
class Device(str, Enum):
11+
CPU = "cpu"
12+
CUDA = "cuda"
13+
AUTO = "auto"
14+
15+
916
PathInput: TypeAlias = str | Path
1017
ImageInput: TypeAlias = PathInput | Image.Image
1118

fastembed/image/image_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Iterable, Sequence, Type
22
from dataclasses import asdict
33

4-
from fastembed.common.types import NumpyArray
4+
from fastembed.common.types import NumpyArray, Device
55
from fastembed.common import ImageInput, OnnxProvider
66
from fastembed.image.image_embedding_base import ImageEmbeddingBase
77
from fastembed.image.onnx_embedding import OnnxImageEmbedding
@@ -51,7 +51,7 @@ def __init__(
5151
cache_dir: str | None = None,
5252
threads: int | None = None,
5353
providers: Sequence[OnnxProvider] | None = None,
54-
cuda: bool = False,
54+
cuda: bool | Device = Device.AUTO,
5555
device_ids: list[int] | None = None,
5656
lazy_load: bool = False,
5757
**kwargs: Any,

fastembed/image/onnx_embedding.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Iterable, Sequence, Type
22

33

4-
from fastembed.common.types import NumpyArray
4+
from fastembed.common.types import NumpyArray, Device
55
from fastembed.common import ImageInput, OnnxProvider
66
from fastembed.common.onnx_model import OnnxOutputContext
77
from fastembed.common.utils import define_cache_dir, normalize
@@ -63,10 +63,11 @@ class OnnxImageEmbedding(ImageEmbeddingBase, OnnxImageModel[NumpyArray]):
6363
def __init__(
6464
self,
6565
model_name: str,
66+
6667
cache_dir: str | None = None,
6768
threads: int | None = None,
6869
providers: Sequence[OnnxProvider] | None = None,
69-
cuda: bool = False,
70+
cuda: bool | Device = Device.AUTO,
7071
device_ids: list[int] | None = None,
7172
lazy_load: bool = False,
7273
device_id: int | None = None,
@@ -82,10 +83,11 @@ def __init__(
8283
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
8384
providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
8485
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
85-
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
86-
Defaults to False.
86+
cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`
87+
Defaults to Device.AUTO.
8788
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
88-
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
89+
workers. Should be used with `cuda` equals to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
90+
with `providers`. Defaults to None.
8991
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
9092
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
9193
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.

fastembed/image/onnx_image_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from PIL import Image
99

1010
from fastembed.image.transform.operators import Compose
11-
from fastembed.common.types import NumpyArray
11+
from fastembed.common.types import NumpyArray, Device
1212
from fastembed.common import ImageInput, OnnxProvider
1313
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1414
from fastembed.common.preprocessor_utils import load_preprocessor
@@ -53,7 +53,7 @@ def _load_onnx_model(
5353
model_file: str,
5454
threads: int | None,
5555
providers: Sequence[OnnxProvider] | None = None,
56-
cuda: bool = False,
56+
cuda: bool | Device = Device.AUTO,
5757
device_id: int | None = None,
5858
extra_session_options: dict[str, Any] | None = None,
5959
) -> None:
@@ -97,7 +97,7 @@ def _embed_images(
9797
batch_size: int = 256,
9898
parallel: int | None = None,
9999
providers: Sequence[OnnxProvider] | None = None,
100-
cuda: bool = False,
100+
cuda: bool | Device = Device.AUTO,
101101
device_ids: list[int] | None = None,
102102
local_files_only: bool = False,
103103
specific_model_path: str | None = None,

fastembed/late_interaction/colbert.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from tokenizers import Encoding, Tokenizer
66

77
from fastembed.common.preprocessor_utils import load_tokenizer
8-
from fastembed.common.types import NumpyArray
8+
from fastembed.common.types import NumpyArray, Device
99
from fastembed.common import OnnxProvider
1010
from fastembed.common.onnx_model import OnnxOutputContext
1111
from fastembed.common.utils import define_cache_dir, iter_batch
@@ -143,7 +143,7 @@ def __init__(
143143
cache_dir: str | None = None,
144144
threads: int | None = None,
145145
providers: Sequence[OnnxProvider] | None = None,
146-
cuda: bool = False,
146+
cuda: bool | Device = Device.AUTO,
147147
device_ids: list[int] | None = None,
148148
lazy_load: bool = False,
149149
device_id: int | None = None,
@@ -159,10 +159,11 @@ def __init__(
159159
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
160160
providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
161161
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
162-
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
163-
Defaults to False.
162+
cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`
163+
Defaults to Device.AUTO.
164164
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
165-
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
165+
workers. Should be used with `cuda` equals to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
166+
with `providers`. Defaults to None.
166167
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
167168
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
168169
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.

fastembed/late_interaction/late_interaction_text_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import asdict
33

44
from fastembed.common.model_description import DenseModelDescription
5-
from fastembed.common.types import NumpyArray
5+
from fastembed.common.types import NumpyArray, Device
66
from fastembed.common import OnnxProvider
77
from fastembed.late_interaction.colbert import Colbert
88
from fastembed.late_interaction.jina_colbert import JinaColbert
@@ -54,7 +54,7 @@ def __init__(
5454
cache_dir: str | None = None,
5555
threads: int | None = None,
5656
providers: Sequence[OnnxProvider] | None = None,
57-
cuda: bool = False,
57+
cuda: bool | Device = Device.AUTO,
5858
device_ids: list[int] | None = None,
5959
lazy_load: bool = False,
6060
**kwargs: Any,

fastembed/late_interaction_multimodal/colpali.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from fastembed.common import OnnxProvider, ImageInput
77
from fastembed.common.onnx_model import OnnxOutputContext
8-
from fastembed.common.types import NumpyArray
8+
from fastembed.common.types import NumpyArray, Device
99
from fastembed.common.utils import define_cache_dir, iter_batch
1010
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
1111
LateInteractionMultimodalEmbeddingBase,
@@ -49,7 +49,7 @@ def __init__(
4949
cache_dir: str | None = None,
5050
threads: int | None = None,
5151
providers: Sequence[OnnxProvider] | None = None,
52-
cuda: bool = False,
52+
cuda: bool | Device = Device.AUTO,
5353
device_ids: list[int] | None = None,
5454
lazy_load: bool = False,
5555
device_id: int | None = None,
@@ -65,10 +65,11 @@ def __init__(
6565
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
6666
providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
6767
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
68-
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
69-
Defaults to False.
68+
cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`
69+
Defaults to Device.AUTO.
7070
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
71-
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
71+
workers. Should be used with `cuda` equals to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
72+
with `providers`. Defaults to None.
7273
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
7374
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
7475
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import asdict
33

44
from fastembed.common import OnnxProvider, ImageInput
5-
from fastembed.common.types import NumpyArray
5+
from fastembed.common.types import NumpyArray, Device
66
from fastembed.late_interaction_multimodal.colpali import ColPali
77

88
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
@@ -57,7 +57,7 @@ def __init__(
5757
cache_dir: str | None = None,
5858
threads: int | None = None,
5959
providers: Sequence[OnnxProvider] | None = None,
60-
cuda: bool = False,
60+
cuda: bool | Device = Device.AUTO,
6161
device_ids: list[int] | None = None,
6262
lazy_load: bool = False,
6363
**kwargs: Any,

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from fastembed.common import OnnxProvider, ImageInput
1212
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1313
from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor
14-
from fastembed.common.types import NumpyArray
14+
from fastembed.common.types import NumpyArray, Device
1515
from fastembed.common.utils import iter_batch
1616
from fastembed.image.transform.operators import Compose
1717
from fastembed.parallel_processor import ParallelWorkerPool
@@ -62,7 +62,7 @@ def _load_onnx_model(
6262
model_file: str,
6363
threads: int | None,
6464
providers: Sequence[OnnxProvider] | None = None,
65-
cuda: bool = False,
65+
cuda: bool | Device = Device.AUTO,
6666
device_id: int | None = None,
6767
extra_session_options: dict[str, Any] | None = None,
6868
) -> None:
@@ -120,7 +120,7 @@ def _embed_documents(
120120
batch_size: int = 256,
121121
parallel: int | None = None,
122122
providers: Sequence[OnnxProvider] | None = None,
123-
cuda: bool = False,
123+
cuda: bool | Device = Device.AUTO,
124124
device_ids: list[int] | None = None,
125125
local_files_only: bool = False,
126126
specific_model_path: str | None = None,
@@ -191,7 +191,7 @@ def _embed_images(
191191
batch_size: int = 256,
192192
parallel: int | None = None,
193193
providers: Sequence[OnnxProvider] | None = None,
194-
cuda: bool = False,
194+
cuda: bool | Device = Device.AUTO,
195195
device_ids: list[int] | None = None,
196196
local_files_only: bool = False,
197197
specific_model_path: str | None = None,

0 commit comments

Comments (0)