Skip to content

Commit a5b266e

Browse files
Image type hints (#457)
* chore: Added type hints * new: Add type hints for parallel processor * new: Add image type hints * fix: NdArray -> NumpyArray * fix: remove redundant property * refactoring: remove redundant new lines * refactoring: remove redundant new line * fix: fix image input types * fix: remove redundant import * fix: remove mp subscriptions due to mac os issues * chore: Update type hints * chore: Added type gints for functional * refactor --------- Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
1 parent 877d963 commit a5b266e

File tree

9 files changed

+90
-93
lines changed

9 files changed

+90
-93
lines changed

fastembed/common/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from fastembed.common.types import ImageInput, OnnxProvider, PathInput, PilInput
1+
from fastembed.common.types import ImageInput, OnnxProvider, PathInput
22

3-
__all__ = ["OnnxProvider", "ImageInput", "PathInput", "PilInput"]
3+
__all__ = ["OnnxProvider", "ImageInput", "PathInput"]

fastembed/common/model_management.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def download_files_from_huggingface(
114114
extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
115115
includes the required model files.
116116
local_files_only (bool, optional): Whether to only use local files. Defaults to False.
117-
specific_model_path (Optional[str], optional): The path to the model dir already pooled from external source
118117
Returns:
119118
Path: The path to the model directory.
120119
"""
@@ -161,9 +160,7 @@ def _collect_file_metadata(
161160
}
162161
return meta
163162

164-
def _save_file_metadata(
165-
model_dir: Path, meta: dict[str, dict[str, Union[int, str]]]
166-
) -> None:
163+
def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, Union[int, str]]]) -> None:
167164
try:
168165
if not model_dir.exists():
169166
model_dir.mkdir(parents=True, exist_ok=True)

fastembed/common/types.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from pathlib import Path
22
import sys
33
from PIL import Image
4-
from typing import Any, Iterable, Union
5-
4+
from typing import Any, Union
65
import numpy as np
76
from numpy.typing import NDArray
87

@@ -13,11 +12,9 @@
1312

1413

1514
PathInput: TypeAlias = Union[str, Path]
16-
PilInput: TypeAlias = Union[Image.Image, Iterable[Image.Image]]
17-
ImageInput: TypeAlias = Union[PathInput, Iterable[PathInput], PilInput]
15+
ImageInput: TypeAlias = Union[PathInput, Image.Image]
1816

1917
OnnxProvider: TypeAlias = Union[str, tuple[str, dict[Any, Any]]]
20-
2118
NumpyArray = Union[
2219
NDArray[np.float32],
2320
NDArray[np.float16],

fastembed/image/image_embedding.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import Any, Iterable, Optional, Sequence, Type
2-
3-
import numpy as np
1+
from typing import Any, Iterable, Optional, Sequence, Type, Union
42

3+
from fastembed.common.types import NumpyArray
54
from fastembed.common import ImageInput, OnnxProvider
65
from fastembed.image.image_embedding_base import ImageEmbeddingBase
76
from fastembed.image.onnx_embedding import OnnxImageEmbedding
@@ -35,7 +34,7 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
3534
]
3635
```
3736
"""
38-
result = []
37+
result: list[dict[str, Any]] = []
3938
for embedding in cls.EMBEDDINGS_REGISTRY:
4039
result.extend(embedding.list_supported_models())
4140
return result
@@ -74,11 +73,11 @@ def __init__(
7473

7574
def embed(
7675
self,
77-
images: ImageInput,
76+
images: Union[ImageInput, Iterable[ImageInput]],
7877
batch_size: int = 16,
7978
parallel: Optional[int] = None,
8079
**kwargs: Any,
81-
) -> Iterable[np.ndarray]:
80+
) -> Iterable[NumpyArray]:
8281
"""
8382
Encode a list of documents into list of embeddings.
8483
We use mean pooling with attention so that the model can handle variable-length inputs.

fastembed/image/image_embedding_base.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import Iterable, Optional, Any
2-
3-
import numpy as np
1+
from typing import Iterable, Optional, Any, Union
42

3+
from fastembed.common.types import NumpyArray
54
from fastembed.common.model_management import ModelManagement
65
from fastembed.common.types import ImageInput
76

@@ -21,11 +20,11 @@ def __init__(
2120

2221
def embed(
2322
self,
24-
images: ImageInput,
23+
images: Union[ImageInput, Iterable[ImageInput]],
2524
batch_size: int = 16,
2625
parallel: Optional[int] = None,
2726
**kwargs: Any,
28-
) -> Iterable[np.ndarray]:
27+
) -> Iterable[NumpyArray]:
2928
"""
3029
Embeds a list of images into a list of embeddings.
3130
@@ -39,6 +38,6 @@ def embed(
3938
**kwargs: Additional keyword argument to pass to the embed method.
4039
4140
Yields:
42-
Iterable[np.ndarray]: The embeddings.
41+
Iterable[NdArray]: The embeddings.
4342
"""
4443
raise NotImplementedError()

fastembed/image/onnx_embedding.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import Any, Iterable, Optional, Sequence, Type
1+
from typing import Any, Iterable, Optional, Sequence, Type, Union
22

33
import numpy as np
4+
5+
from fastembed.common.types import NumpyArray
46
from fastembed.common import ImageInput, OnnxProvider
57
from fastembed.common.onnx_model import OnnxOutputContext
68
from fastembed.common.utils import define_cache_dir, normalize
@@ -66,7 +68,7 @@
6668
]
6769

6870

69-
class OnnxImageEmbedding(ImageEmbeddingBase, OnnxImageModel[np.ndarray]):
71+
class OnnxImageEmbedding(ImageEmbeddingBase, OnnxImageModel[NumpyArray]):
7072
def __init__(
7173
self,
7274
model_name: str,
@@ -111,15 +113,14 @@ def __init__(
111113
self.cuda = cuda
112114

113115
# This device_id will be used if we need to load model in current process
116+
self.device_id: Optional[int] = None
114117
if device_id is not None:
115118
self.device_id = device_id
116119
elif self.device_ids is not None:
117120
self.device_id = self.device_ids[0]
118-
else:
119-
self.device_id = None
120121

121122
self.model_description = self._get_model_description(model_name)
122-
self.cache_dir = define_cache_dir(cache_dir)
123+
self.cache_dir = str(define_cache_dir(cache_dir))
123124
self._model_dir = self.download_model(
124125
self.model_description,
125126
self.cache_dir,
@@ -155,11 +156,11 @@ def list_supported_models(cls) -> list[dict[str, Any]]:
155156

156157
def embed(
157158
self,
158-
images: ImageInput,
159+
images: Union[ImageInput, Iterable[ImageInput]],
159160
batch_size: int = 16,
160161
parallel: Optional[int] = None,
161162
**kwargs: Any,
162-
) -> Iterable[np.ndarray]:
163+
) -> Iterable[NumpyArray]:
163164
"""
164165
Encode a list of images into list of embeddings.
165166
We use mean pooling with attention so that the model can handle variable-length inputs.
@@ -189,23 +190,23 @@ def embed(
189190
)
190191

191192
@classmethod
192-
def _get_worker_class(cls) -> Type["ImageEmbeddingWorker"]:
193+
def _get_worker_class(cls) -> Type["ImageEmbeddingWorker[NumpyArray]"]:
193194
return OnnxImageEmbeddingWorker
194195

195196
def _preprocess_onnx_input(
196-
self, onnx_input: dict[str, np.ndarray], **kwargs: Any
197-
) -> dict[str, np.ndarray]:
197+
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
198+
) -> dict[str, NumpyArray]:
198199
"""
199200
Preprocess the onnx input.
200201
"""
201202

202203
return onnx_input
203204

204-
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
205+
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[NumpyArray]:
205206
return normalize(output.model_output).astype(np.float32)
206207

207208

208-
class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):
209+
class OnnxImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
209210
def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> OnnxImageEmbedding:
210211
return OnnxImageEmbedding(
211212
model_name=model_name,

fastembed/image/onnx_image_model.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import os
33
from multiprocessing import get_all_start_methods
44
from pathlib import Path
5-
from typing import Any, Iterable, Optional, Sequence, Type
5+
from typing import Any, Iterable, Optional, Sequence, Type, Union
66

77
import numpy as np
88
from PIL import Image
99

10+
from fastembed.image.transform.operators import Compose
11+
from fastembed.common.types import NumpyArray
1012
from fastembed.common import ImageInput, OnnxProvider
1113
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
1214
from fastembed.common.preprocessor_utils import load_preprocessor
@@ -18,19 +20,19 @@
1820

1921
class OnnxImageModel(OnnxModel[T]):
2022
@classmethod
21-
def _get_worker_class(cls) -> Type["ImageEmbeddingWorker"]:
23+
def _get_worker_class(cls) -> Type["ImageEmbeddingWorker[T]"]:
2224
raise NotImplementedError("Subclasses must implement this method")
2325

2426
def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[T]:
2527
raise NotImplementedError("Subclasses must implement this method")
2628

27-
def __init__(self):
29+
def __init__(self) -> None:
2830
super().__init__()
29-
self.processor = None
31+
self.processor: Optional[Compose] = None
3032

3133
def _preprocess_onnx_input(
32-
self, onnx_input: dict[str, np.ndarray], **kwargs: Any
33-
) -> dict[str, np.ndarray]:
34+
self, onnx_input: dict[str, NumpyArray], **kwargs: Any
35+
) -> dict[str, NumpyArray]:
3436
"""
3537
Preprocess the onnx input.
3638
"""
@@ -58,16 +60,18 @@ def _load_onnx_model(
5860
def load_onnx_model(self) -> None:
5961
raise NotImplementedError("Subclasses must implement this method")
6062

61-
def _build_onnx_input(self, encoded: np.ndarray) -> dict[str, np.ndarray]:
62-
return {node.name: encoded for node in self.model.get_inputs()}
63+
def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
64+
input_name = self.model.get_inputs()[0].name
65+
return {input_name: encoded}
6366

6467
def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
6568
with contextlib.ExitStack():
6669
image_files = [
6770
Image.open(image) if not isinstance(image, Image.Image) else image
6871
for image in images
6972
]
70-
encoded = self.processor(image_files)
73+
assert self.processor is not None, "Processor is not initialized"
74+
encoded = np.array(self.processor(image_files))
7175
onnx_input = self._build_onnx_input(encoded)
7276
onnx_input = self._preprocess_onnx_input(onnx_input)
7377
model_output = self.model.run(None, onnx_input)
@@ -78,7 +82,7 @@ def _embed_images(
7882
self,
7983
model_name: str,
8084
cache_dir: str,
81-
images: ImageInput,
85+
images: Union[ImageInput, Iterable[ImageInput]],
8286
batch_size: int = 256,
8387
parallel: Optional[int] = None,
8488
providers: Optional[Sequence[OnnxProvider]] = None,
@@ -124,7 +128,7 @@ def _embed_images(
124128
yield from self._post_process_onnx_output(batch)
125129

126130

127-
class ImageEmbeddingWorker(EmbeddingWorker):
131+
class ImageEmbeddingWorker(EmbeddingWorker[T]):
128132
def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
129133
for idx, batch in items:
130134
embeddings = self.model.onnx_embed(batch)

fastembed/image/transform/functional.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Sized, Union
1+
from typing import Union
22

33
import numpy as np
44
from PIL import Image
55

6+
from fastembed.common.types import NumpyArray
7+
68

79
def convert_to_rgb(image: Image.Image) -> Image.Image:
810
if image.mode == "RGB":
@@ -13,9 +15,9 @@ def convert_to_rgb(image: Image.Image) -> Image.Image:
1315

1416

1517
def center_crop(
16-
image: Union[Image.Image, np.ndarray],
18+
image: Union[Image.Image, NumpyArray],
1719
size: tuple[int, int],
18-
) -> np.ndarray:
20+
) -> NumpyArray:
1921
if isinstance(image, np.ndarray):
2022
_, orig_height, orig_width = image.shape
2123
else:
@@ -40,7 +42,7 @@ def center_crop(
4042
new_height = max(crop_height, orig_height)
4143
new_width = max(crop_width, orig_width)
4244
new_shape = image.shape[:-2] + (new_height, new_width)
43-
new_image = np.zeros_like(image, shape=new_shape)
45+
new_image = np.zeros_like(image, shape=new_shape, dtype=np.float32)
4446

4547
top_pad = (new_height - orig_height) // 2
4648
bottom_pad = top_pad + orig_height
@@ -61,37 +63,34 @@ def center_crop(
6163

6264

6365
def normalize(
64-
image: np.ndarray,
65-
mean: Union[float, np.ndarray],
66-
std: Union[float, np.ndarray],
67-
) -> np.ndarray:
68-
if not isinstance(image, np.ndarray):
69-
raise ValueError("image must be a numpy array")
70-
66+
image: NumpyArray,
67+
mean: Union[float, list[float]],
68+
std: Union[float, list[float]],
69+
) -> NumpyArray:
7170
num_channels = image.shape[1] if len(image.shape) == 4 else image.shape[0]
7271

7372
if not np.issubdtype(image.dtype, np.floating):
7473
image = image.astype(np.float32)
7574

76-
if isinstance(mean, Sized):
77-
if len(mean) != num_channels:
78-
raise ValueError(
79-
f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}"
80-
)
81-
else:
82-
mean = [mean] * num_channels
83-
mean = np.array(mean, dtype=image.dtype)
84-
85-
if isinstance(std, Sized):
86-
if len(std) != num_channels:
87-
raise ValueError(
88-
f"std must have {num_channels} elements if it is an iterable, got {len(std)}"
89-
)
90-
else:
91-
std = [std] * num_channels
92-
std = np.array(std, dtype=image.dtype)
75+
mean = mean if isinstance(mean, list) else [mean] * num_channels
76+
77+
if len(mean) != num_channels:
78+
raise ValueError(
79+
f"mean must have the same number of channels as the image, image has {num_channels} channels, got "
80+
f"{len(mean)}"
81+
)
82+
83+
mean_arr = np.array(mean, dtype=np.float32)
84+
85+
std = std if isinstance(std, list) else [std] * num_channels
86+
if len(std) != num_channels:
87+
raise ValueError(
88+
f"std must have the same number of channels as the image, image has {num_channels} channels, got {len(std)}"
89+
)
90+
91+
std_arr = np.array(std, dtype=np.float32)
9392

94-
image = ((image.T - mean) / std).T
93+
image = ((image.T - mean_arr) / std_arr).T
9594
return image
9695

9796

@@ -114,11 +113,11 @@ def resize(
114113
return image.resize(new_size, resample)
115114

116115

117-
def rescale(image: np.ndarray, scale: float, dtype: type = np.float32) -> np.ndarray:
116+
def rescale(image: NumpyArray, scale: float, dtype: type = np.float32) -> NumpyArray:
118117
return (image * scale).astype(dtype)
119118

120119

121-
def pil2ndarray(image: Union[Image.Image, np.ndarray]) -> np.ndarray:
120+
def pil2ndarray(image: Union[Image.Image, NumpyArray]) -> NumpyArray:
122121
if isinstance(image, Image.Image):
123122
return np.asarray(image).transpose((2, 0, 1))
124123
return image

0 commit comments

Comments
 (0)