Skip to content

Commit 256443e

Browse files
committed
new: use cuda if available
1 parent 428381c commit 256443e

23 files changed

+104
-77
lines changed

fastembed/common/onnx_model.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,15 @@
11
import warnings
22
from dataclasses import dataclass
33
from pathlib import Path
4-
from typing import Any, Generic, Iterable, Optional, Sequence, Type, TypeVar
4+
from typing import Any, Generic, Iterable, Optional, Sequence, Type, TypeVar, Union
55

66
import numpy as np
77
import onnxruntime as ort
88

99
from numpy.typing import NDArray
1010
from tokenizers import Tokenizer
1111

12-
from fastembed.common.types import OnnxProvider, NumpyArray
12+
from fastembed.common.types import OnnxProvider, NumpyArray, Device
1313
from fastembed.parallel_processor import Worker
1414

1515
# Holds type of the embedding result
@@ -60,31 +60,35 @@ def _load_onnx_model(
6060
model_file: str,
6161
threads: Optional[int],
6262
providers: Optional[Sequence[OnnxProvider]] = None,
63-
cuda: bool = False,
63+
cuda: Union[bool, Device] = Device.AUTO,
6464
device_id: Optional[int] = None,
6565
extra_session_options: Optional[dict[str, Any]] = None,
6666
) -> None:
6767
model_path = model_dir / model_file
6868
# List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
69+
available_providers = ort.get_available_providers()
70+
cuda_available = "CUDAExecutionProvider" in available_providers
71+
explicit_cuda = cuda is True or cuda == Device.CUDA
6972

70-
if cuda and providers is not None:
73+
if explicit_cuda and providers is not None:
7174
warnings.warn(
72-
f"`cuda` and `providers` are mutually exclusive parameters, cuda: {cuda}, providers: {providers}",
75+
f"`cuda` and `providers` are mutually exclusive parameters, "
76+
f"cuda: {cuda}, providers: {providers}. If you'd like to use providers, cuda should be one of "
77+
f"[False, Device.CPU, Device.AUTO].",
7378
category=UserWarning,
7479
stacklevel=6,
7580
)
7681

7782
if providers is not None:
7883
onnx_providers = list(providers)
79-
elif cuda:
84+
elif explicit_cuda or (cuda == Device.AUTO and cuda_available):
8085
if device_id is None:
8186
onnx_providers = ["CUDAExecutionProvider"]
8287
else:
8388
onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})]
8489
else:
8590
onnx_providers = ["CPUExecutionProvider"]
8691

87-
available_providers = ort.get_available_providers()
8892
requested_provider_names: list[str] = []
8993
for provider in onnx_providers:
9094
# check providers available

fastembed/common/types.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from pathlib import Path
23
import sys
34
from PIL import Image
@@ -23,3 +24,9 @@
2324
NDArray[np.int64],
2425
NDArray[np.int32],
2526
]
27+
28+
29+
class Device(str, Enum):
30+
CPU = "cpu"
31+
CUDA = "cuda"
32+
AUTO = "auto"

fastembed/image/image_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Iterable, Optional, Sequence, Type, Union
22
from dataclasses import asdict
33

4-
from fastembed.common.types import NumpyArray
4+
from fastembed.common.types import NumpyArray, Device
55
from fastembed.common import ImageInput, OnnxProvider
66
from fastembed.image.image_embedding_base import ImageEmbeddingBase
77
from fastembed.image.onnx_embedding import OnnxImageEmbedding
@@ -51,7 +51,7 @@ def __init__(
5151
cache_dir: Optional[str] = None,
5252
threads: Optional[int] = None,
5353
providers: Optional[Sequence[OnnxProvider]] = None,
54-
cuda: bool = False,
54+
cuda: Union[bool, Device] = Device.AUTO,
5555
device_ids: Optional[list[int]] = None,
5656
lazy_load: bool = False,
5757
**kwargs: Any,

fastembed/image/onnx_embedding.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Iterable, Optional, Sequence, Type, Union
22

33

4-
from fastembed.common.types import NumpyArray
4+
from fastembed.common.types import NumpyArray, Device
55
from fastembed.common import ImageInput, OnnxProvider
66
from fastembed.common.onnx_model import OnnxOutputContext
77
from fastembed.common.utils import define_cache_dir, normalize
@@ -66,7 +66,7 @@ def __init__(
6666
cache_dir: Optional[str] = None,
6767
threads: Optional[int] = None,
6868
providers: Optional[Sequence[OnnxProvider]] = None,
69-
cuda: bool = False,
69+
cuda: Union[bool, Device] = Device.AUTO,
7070
device_ids: Optional[list[int]] = None,
7171
lazy_load: bool = False,
7272
device_id: Optional[int] = None,
@@ -82,10 +82,11 @@ def __init__(
8282
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
8383
providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
8484
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
85-
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
86-
Defaults to False.
85+
cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
86+
Defaults to Device.AUTO.
8787
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
88-
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
88+
workers. Should be used with `cuda` equal to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
89+
with `providers`. Defaults to None.
8990
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
9091
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
9192
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.

fastembed/image/onnx_image_model.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from PIL import Image
99

1010
from fastembed.image.transform.operators import Compose
11-
from fastembed.common.types import NumpyArray
11+
from fastembed.common.types import NumpyArray, Device
1212
from fastembed.common import ImageInput, OnnxProvider
1313
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1414
from fastembed.common.preprocessor_utils import load_preprocessor
@@ -53,7 +53,7 @@ def _load_onnx_model(
5353
model_file: str,
5454
threads: Optional[int],
5555
providers: Optional[Sequence[OnnxProvider]] = None,
56-
cuda: bool = False,
56+
cuda: Union[bool, Device] = Device.AUTO,
5757
device_id: Optional[int] = None,
5858
extra_session_options: Optional[dict[str, Any]] = None,
5959
) -> None:
@@ -97,7 +97,7 @@ def _embed_images(
9797
batch_size: int = 256,
9898
parallel: Optional[int] = None,
9999
providers: Optional[Sequence[OnnxProvider]] = None,
100-
cuda: bool = False,
100+
cuda: Union[bool, Device] = Device.AUTO,
101101
device_ids: Optional[list[int]] = None,
102102
local_files_only: bool = False,
103103
specific_model_path: Optional[str] = None,

fastembed/late_interaction/colbert.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from tokenizers import Encoding, Tokenizer
66

77
from fastembed.common.preprocessor_utils import load_tokenizer
8-
from fastembed.common.types import NumpyArray
8+
from fastembed.common.types import NumpyArray, Device
99
from fastembed.common import OnnxProvider
1010
from fastembed.common.onnx_model import OnnxOutputContext
1111
from fastembed.common.utils import define_cache_dir, iter_batch
@@ -143,7 +143,7 @@ def __init__(
143143
cache_dir: Optional[str] = None,
144144
threads: Optional[int] = None,
145145
providers: Optional[Sequence[OnnxProvider]] = None,
146-
cuda: bool = False,
146+
cuda: Union[bool, Device] = Device.AUTO,
147147
device_ids: Optional[list[int]] = None,
148148
lazy_load: bool = False,
149149
device_id: Optional[int] = None,
@@ -159,10 +159,11 @@ def __init__(
159159
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
160160
providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
161161
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
162-
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
163-
Defaults to False.
162+
cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
163+
Defaults to Device.AUTO.
164164
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
165-
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
165+
workers. Should be used with `cuda` equal to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
166+
with `providers`. Defaults to None.
166167
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
167168
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
168169
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.

fastembed/late_interaction/late_interaction_text_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from dataclasses import asdict
33

44
from fastembed.common.model_description import DenseModelDescription
5-
from fastembed.common.types import NumpyArray
5+
from fastembed.common.types import NumpyArray, Device
66
from fastembed.common import OnnxProvider
77
from fastembed.late_interaction.colbert import Colbert
88
from fastembed.late_interaction.jina_colbert import JinaColbert
@@ -54,7 +54,7 @@ def __init__(
5454
cache_dir: Optional[str] = None,
5555
threads: Optional[int] = None,
5656
providers: Optional[Sequence[OnnxProvider]] = None,
57-
cuda: bool = False,
57+
cuda: Union[bool, Device] = Device.AUTO,
5858
device_ids: Optional[list[int]] = None,
5959
lazy_load: bool = False,
6060
**kwargs: Any,

fastembed/late_interaction_multimodal/colpali.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@
55

66
from fastembed.common import OnnxProvider, ImageInput
77
from fastembed.common.onnx_model import OnnxOutputContext
8-
from fastembed.common.types import NumpyArray
9-
from fastembed.common.utils import define_cache_dir, iter_batch
8+
from fastembed.common.types import NumpyArray, Device
9+
from fastembed.common.utils import define_cache_dir
1010
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
1111
LateInteractionMultimodalEmbeddingBase,
1212
)
@@ -49,7 +49,7 @@ def __init__(
4949
cache_dir: Optional[str] = None,
5050
threads: Optional[int] = None,
5151
providers: Optional[Sequence[OnnxProvider]] = None,
52-
cuda: bool = False,
52+
cuda: Union[bool, Device] = Device.AUTO,
5353
device_ids: Optional[list[int]] = None,
5454
lazy_load: bool = False,
5555
device_id: Optional[int] = None,
@@ -65,10 +65,11 @@ def __init__(
6565
threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
6666
providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
6767
Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
68-
cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
69-
Defaults to False.
68+
cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
69+
Defaults to Device.AUTO.
7070
device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
71-
workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
71+
workers. Should be used with `cuda` equal to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
72+
with `providers`. Defaults to None.
7273
lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
7374
Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
7475
device_id (Optional[int], optional): The device id to use for loading the model in the worker process.

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from dataclasses import asdict
33

44
from fastembed.common import OnnxProvider, ImageInput
5-
from fastembed.common.types import NumpyArray
5+
from fastembed.common.types import NumpyArray, Device
66
from fastembed.late_interaction_multimodal.colpali import ColPali
77

88
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
@@ -57,7 +57,7 @@ def __init__(
5757
cache_dir: Optional[str] = None,
5858
threads: Optional[int] = None,
5959
providers: Optional[Sequence[OnnxProvider]] = None,
60-
cuda: bool = False,
60+
cuda: Union[bool, Device] = Device.AUTO,
6161
device_ids: Optional[list[int]] = None,
6262
lazy_load: bool = False,
6363
**kwargs: Any,

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from fastembed.common import OnnxProvider, ImageInput
1212
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1313
from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor
14-
from fastembed.common.types import NumpyArray
14+
from fastembed.common.types import NumpyArray, Device
1515
from fastembed.common.utils import iter_batch
1616
from fastembed.image.transform.operators import Compose
1717
from fastembed.parallel_processor import ParallelWorkerPool
@@ -62,7 +62,7 @@ def _load_onnx_model(
6262
model_file: str,
6363
threads: Optional[int],
6464
providers: Optional[Sequence[OnnxProvider]] = None,
65-
cuda: bool = False,
65+
cuda: Union[bool, Device] = Device.AUTO,
6666
device_id: Optional[int] = None,
6767
extra_session_options: Optional[dict[str, Any]] = None,
6868
) -> None:
@@ -120,7 +120,7 @@ def _embed_documents(
120120
batch_size: int = 256,
121121
parallel: Optional[int] = None,
122122
providers: Optional[Sequence[OnnxProvider]] = None,
123-
cuda: bool = False,
123+
cuda: Union[bool, Device] = Device.AUTO,
124124
device_ids: Optional[list[int]] = None,
125125
local_files_only: bool = False,
126126
specific_model_path: Optional[str] = None,
@@ -191,7 +191,7 @@ def _embed_images(
191191
batch_size: int = 256,
192192
parallel: Optional[int] = None,
193193
providers: Optional[Sequence[OnnxProvider]] = None,
194-
cuda: bool = False,
194+
cuda: Union[bool, Device] = Device.AUTO,
195195
device_ids: Optional[list[int]] = None,
196196
local_files_only: bool = False,
197197
specific_model_path: Optional[str] = None,

0 commit comments

Comments
 (0)