Skip to content

Commit 993dcd5

Browse files
chore: Add missing type hints in functions (#453)
* chore: Add missing type hints in functions * add missing import, small type refactor --------- Co-authored-by: George Panchuk <[email protected]>
1 parent 73e1e5e commit 993dcd5

13 files changed

+41
-33
lines changed

fastembed/common/model_management.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import shutil
55
import tarfile
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Any, Optional
88

99
import requests
1010
from huggingface_hub import snapshot_download, model_info, list_repo_tree

fastembed/common/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os
1+
from pathlib import Path
22
import sys
33
from PIL import Image
44
from typing import Any, Iterable, Union
@@ -9,7 +9,7 @@
99
from typing_extensions import TypeAlias
1010

1111

12-
PathInput: TypeAlias = Union[str, os.PathLike]
12+
PathInput: TypeAlias = Union[str, Path]
1313
PilInput: TypeAlias = Union[Image.Image, Iterable[Image.Image]]
1414
ImageInput: TypeAlias = Union[PathInput, Iterable[PathInput], PilInput]
1515

fastembed/common/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,22 @@
55
import unicodedata
66
from pathlib import Path
77
from itertools import islice
8-
from typing import Generator, Iterable, Optional, Union
8+
from typing import Iterable, Optional, TypeVar
99

1010
import numpy as np
1111

12+
T = TypeVar("T")
1213

13-
def normalize(input_array, p=2, dim=1, eps=1e-12) -> np.ndarray:
14+
15+
def normalize(input_array: np.ndarray, p: int = 2, dim: int = 1, eps: float = 1e-12) -> np.ndarray:
1416
# Calculate the Lp norm along the specified dimension
1517
norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
1618
norm = np.maximum(norm, eps) # Avoid division by zero
1719
normalized_array = input_array / norm
1820
return normalized_array
1921

2022

21-
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
23+
def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
2224
"""
2325
>>> list(iter_batch([1,2,3,4,5], 3))
2426
[[1, 2, 3], [4, 5]]

fastembed/image/transform/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def resize(
114114
return image.resize(new_size, resample)
115115

116116

117-
def rescale(image: np.ndarray, scale: float, dtype=np.float32) -> np.ndarray:
117+
def rescale(image: np.ndarray, scale: float, dtype: type = np.float32) -> np.ndarray:
118118
return (image * scale).astype(dtype)
119119

120120

tests/test_attention_embeddings.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
11-
def test_attention_embeddings(model_name) -> None:
11+
def test_attention_embeddings(model_name: str) -> None:
1212
is_ci = os.getenv("CI")
1313
model = SparseTextEmbedding(model_name=model_name)
1414

@@ -71,7 +71,7 @@ def test_attention_embeddings(model_name) -> None:
7171

7272

7373
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
74-
def test_parallel_processing(model_name) -> None:
74+
def test_parallel_processing(model_name: str) -> None:
7575
is_ci = os.getenv("CI")
7676

7777
model = SparseTextEmbedding(model_name=model_name)
@@ -96,7 +96,7 @@ def test_parallel_processing(model_name) -> None:
9696

9797

9898
@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
99-
def test_multilanguage(model_name) -> None:
99+
def test_multilanguage(model_name: str) -> None:
100100
is_ci = os.getenv("CI")
101101

102102
docs = ["Mangez-vous vraiment des grenouilles?", "Je suis au lit"]
@@ -122,7 +122,7 @@ def test_multilanguage(model_name) -> None:
122122

123123

124124
@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
125-
def test_special_characters(model_name) -> None:
125+
def test_special_characters(model_name: str) -> None:
126126
is_ci = os.getenv("CI")
127127

128128
docs = [
@@ -145,7 +145,7 @@ def test_special_characters(model_name) -> None:
145145

146146

147147
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions"])
148-
def test_lazy_load(model_name) -> None:
148+
def test_lazy_load(model_name: str) -> None:
149149
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
150150
assert not hasattr(model.model, "model")
151151
docs = ["hello world", "flag embedding"]

tests/test_image_onnx_embeddings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_embedding() -> None:
6161

6262

6363
@pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
64-
def test_batch_embedding(n_dims, model_name) -> None:
64+
def test_batch_embedding(n_dims: int, model_name: str) -> None:
6565
is_ci = os.getenv("CI")
6666
model = ImageEmbedding(model_name=model_name)
6767
n_images = 32
@@ -81,7 +81,7 @@ def test_batch_embedding(n_dims, model_name) -> None:
8181

8282

8383
@pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")])
84-
def test_parallel_processing(n_dims, model_name) -> None:
84+
def test_parallel_processing(n_dims: int, model_name: str) -> None:
8585
is_ci = os.getenv("CI")
8686
model = ImageEmbedding(model_name=model_name)
8787

@@ -109,7 +109,7 @@ def test_parallel_processing(n_dims, model_name) -> None:
109109

110110

111111
@pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"])
112-
def test_lazy_load(model_name) -> None:
112+
def test_lazy_load(model_name: str) -> None:
113113
is_ci = os.getenv("CI")
114114
model = ImageEmbedding(model_name=model_name, lazy_load=True)
115115
assert not hasattr(model.model, "model")

tests/test_late_interaction_embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_parallel_processing():
226226
"model_name",
227227
["colbert-ir/colbertv2.0"],
228228
)
229-
def test_lazy_load(model_name):
229+
def test_lazy_load(model_name: str):
230230
is_ci = os.getenv("CI")
231231

232232
model = LateInteractionTextEmbedding(model_name=model_name, lazy_load=True)

tests/test_multi_gpu.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from typing import Optional
23
from fastembed import (
34
TextEmbedding,
45
SparseTextEmbedding,
@@ -13,7 +14,7 @@
1314

1415
@pytest.mark.skip(reason="Requires a multi-gpu server")
1516
@pytest.mark.parametrize("device_id", [None, 0, 1])
16-
def test_gpu_via_providers(device_id) -> None:
17+
def test_gpu_via_providers(device_id: Optional[int]) -> None:
1718
docs = ["hello world", "flag embedding"]
1819

1920
device_id = device_id if device_id is not None else 0
@@ -85,7 +86,7 @@ def test_gpu_via_providers(device_id) -> None:
8586

8687
@pytest.mark.skip(reason="Requires a multi-gpu server")
8788
@pytest.mark.parametrize("device_ids", [None, [0], [1], [0, 1]])
88-
def test_gpu_cuda_device_ids(device_ids) -> None:
89+
def test_gpu_cuda_device_ids(device_ids: Optional[list[int]]) -> None:
8990
docs = ["hello world", "flag embedding"]
9091
device_id = device_ids[0] if device_ids else 0
9192
embedding_model = TextEmbedding(
@@ -170,7 +171,7 @@ def test_gpu_cuda_device_ids(device_ids) -> None:
170171
@pytest.mark.parametrize(
171172
"device_ids,parallel", [(None, None), (None, 2), ([1], None), ([1], 1), ([1], 2), ([0, 1], 2)]
172173
)
173-
def test_multi_gpu_parallel_inference(device_ids, parallel) -> None:
174+
def test_multi_gpu_parallel_inference(device_ids: Optional[list[int]], parallel: int) -> None:
174175
docs = ["hello world", "flag embedding"] * 100
175176
batch_size = 5
176177

tests/test_sparse_embeddings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def bm25_instance() -> None:
119119
delete_model_cache(model._model_dir)
120120

121121

122-
def test_stem_with_stopwords_and_punctuation(bm25_instance) -> None:
122+
def test_stem_with_stopwords_and_punctuation(bm25_instance: Bm25) -> None:
123123
# Setup
124124
bm25_instance.stopwords = {"the", "is", "a"}
125125
bm25_instance.punctuation = {".", ",", "!"}
@@ -135,7 +135,7 @@ def test_stem_with_stopwords_and_punctuation(bm25_instance) -> None:
135135
assert result == expected, f"Expected {expected}, but got {result}"
136136

137137

138-
def test_stem_case_insensitive_stopwords(bm25_instance) -> None:
138+
def test_stem_case_insensitive_stopwords(bm25_instance: Bm25) -> None:
139139
# Setup
140140
bm25_instance.stopwords = {"the", "is", "a"}
141141
bm25_instance.punctuation = {".", ",", "!"}
@@ -152,7 +152,7 @@ def test_stem_case_insensitive_stopwords(bm25_instance) -> None:
152152

153153

154154
@pytest.mark.parametrize("disable_stemmer", [True, False])
155-
def test_disable_stemmer_behavior(disable_stemmer) -> None:
155+
def test_disable_stemmer_behavior(disable_stemmer: bool) -> None:
156156
# Setup
157157
model = Bm25("Qdrant/bm25", language="english", disable_stemmer=disable_stemmer)
158158
model.stopwords = {"the", "is", "a"}
@@ -176,7 +176,7 @@ def test_disable_stemmer_behavior(disable_stemmer) -> None:
176176
"model_name",
177177
["prithivida/Splade_PP_en_v1"],
178178
)
179-
def test_lazy_load(model_name) -> None:
179+
def test_lazy_load(model_name: str) -> None:
180180
is_ci = os.getenv("CI")
181181
model = SparseTextEmbedding(model_name=model_name, lazy_load=True)
182182
assert not hasattr(model.model, "model")

tests/test_text_cross_encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"model_name",
2727
[model_name for model_name in CANONICAL_SCORE_VALUES],
2828
)
29-
def test_rerank(model_name) -> None:
29+
def test_rerank(model_name: str) -> None:
3030
is_ci = os.getenv("CI")
3131

3232
model = TextCrossEncoder(model_name=model_name)
@@ -53,7 +53,7 @@ def test_rerank(model_name) -> None:
5353
"model_name",
5454
[model_name for model_name in SELECTED_MODELS.values()],
5555
)
56-
def test_batch_rerank(model_name) -> None:
56+
def test_batch_rerank(model_name: str) -> None:
5757
is_ci = os.getenv("CI")
5858

5959
model = TextCrossEncoder(model_name=model_name)
@@ -82,7 +82,7 @@ def test_batch_rerank(model_name) -> None:
8282
"model_name",
8383
["Xenova/ms-marco-MiniLM-L-6-v2"],
8484
)
85-
def test_lazy_load(model_name) -> None:
85+
def test_lazy_load(model_name: str) -> None:
8686
is_ci = os.getenv("CI")
8787
model = TextCrossEncoder(model_name=model_name, lazy_load=True)
8888
assert not hasattr(model.model, "model")
@@ -99,7 +99,7 @@ def test_lazy_load(model_name) -> None:
9999
"model_name",
100100
[model_name for model_name in SELECTED_MODELS.values()],
101101
)
102-
def test_rerank_pairs_parallel(model_name) -> None:
102+
def test_rerank_pairs_parallel(model_name: str) -> None:
103103
is_ci = os.getenv("CI")
104104

105105
model = TextCrossEncoder(model_name=model_name)

0 commit comments

Comments
 (0)