Skip to content

Commit 0d609d0

Browse files
committed
fix: fix types
1 parent 2d254c0 commit 0d609d0

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

fastembed/common/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Iterable, Optional, TypeVar
99

1010
import numpy as np
11+
from numpy.typing import NDArray
1112

1213
from fastembed.common.types import NumpyArray
1314

@@ -22,10 +23,9 @@ def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e
2223
return normalized_array
2324

2425

25-
def mean_pooling(input_array: NumpyArray, attention_mask: NumpyArray) -> NumpyArray:
26-
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
26+
def mean_pooling(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
27+
input_mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(np.int64)
2728
input_mask_expanded = np.tile(input_mask_expanded, (1, 1, input_array.shape[-1]))
28-
input_mask_expanded = input_mask_expanded.astype(np.float32)
2929
sum_embeddings = np.sum(input_array * input_mask_expanded, axis=1)
3030
sum_mask = np.sum(input_mask_expanded, axis=1)
3131
pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)

fastembed/text/custom_text_embedding.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from typing import Optional, Sequence, Any, Iterable
22

33
from dataclasses import dataclass
4+
5+
import numpy as np
6+
from numpy.typing import NDArray
7+
48
from fastembed.common import OnnxProvider
59
from fastembed.common.model_description import (
610
PoolingType,
@@ -58,7 +62,7 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[Numpy
5862
return self._normalize(self._pool(output.model_output, output.attention_mask))
5963

6064
def _pool(
61-
self, embeddings: NumpyArray, attention_mask: Optional[NumpyArray] = None
65+
self, embeddings: NumpyArray, attention_mask: Optional[NDArray[np.int64]] = None
6266
) -> NumpyArray:
6367
if self._pooling == PoolingType.CLS:
6468
return embeddings[:, 0]

fastembed/text/pooled_embedding.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Iterable, Type
22

33
import numpy as np
4+
from numpy.typing import NDArray
45

56
from fastembed.common.types import NumpyArray
67
from fastembed.common.onnx_model import OnnxOutputContext
@@ -94,10 +95,12 @@ def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
9495
return PooledEmbeddingWorker
9596

9697
@classmethod
97-
def mean_pooling(cls, model_output: NumpyArray, attention_mask: NumpyArray) -> NumpyArray:
98-
token_embeddings = model_output.astype(np.float32)
99-
attention_mask = attention_mask.astype(np.float32)
100-
return mean_pooling(token_embeddings, attention_mask)
98+
def mean_pooling(
99+
cls, model_output: NumpyArray, attention_mask: NDArray[np.int64]
100+
) -> NumpyArray:
101+
# token_embeddings = model_output.astype(np.float32)
102+
# attention_mask = attention_mask.astype(np.float32)
103+
return mean_pooling(model_output, attention_mask)
101104

102105
@classmethod
103106
def _list_supported_models(cls) -> list[DenseModelDescription]:

0 commit comments

Comments
 (0)