11from typing import Any , Iterable , Sequence , Type
22
33
4- from fastembed .common .types import NumpyArray
4+ from fastembed .common .types import NumpyArray , Device
55from fastembed .common import ImageInput , OnnxProvider
66from fastembed .common .onnx_model import OnnxOutputContext
77from fastembed .common .utils import define_cache_dir , normalize
@@ -63,10 +63,11 @@ class OnnxImageEmbedding(ImageEmbeddingBase, OnnxImageModel[NumpyArray]):
6363 def __init__ (
6464 self ,
6565 model_name : str ,
66+
6667 cache_dir : str | None = None ,
6768 threads : int | None = None ,
6869 providers : Sequence [OnnxProvider ] | None = None ,
69- cuda : bool = False ,
70+ cuda : bool | Device = Device . AUTO ,
7071 device_ids : list [int ] | None = None ,
7172 lazy_load : bool = False ,
7273 device_id : int | None = None ,
@@ -82,10 +83,11 @@ def __init__(
8283 threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
8384 providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
8485 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
85- cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
86- Defaults to False .
86+ cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
87+ Defaults to Device.AUTO.
8788 device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
88- workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
89+ workers. Should be used when `cuda` is `True`, `Device.AUTO` or `Device.CUDA`; mutually exclusive
90+ with `providers`. Defaults to None.
8991 lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
9092 Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
9193 device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
0 commit comments