 
 import numpy as np
 import openai
+import math
+import requests
 
 
 class BaseSentenceVectorizer(abc.ABC):
@@ -306,3 +308,54 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
 
         embeddings = np.array(embedding_list, dtype=np.float32)
         return embeddings
+
+
+class TEIVectorizer(BaseSentenceVectorizer):
+    """The TEIVectorizer class uses the TEI (Text Embeddings Inference) Embeddings API to
+    convert text into embeddings.
+
+    For detailed information on the supported models, visit: https://github.com/huggingface/text-embeddings-inference.
+
+    `model` is the name of the embedding model served by the endpoint.
+    `embed_batch_size` is the maximum batch size for a single request.
+    `api_key` is the key used for request authorization.
+    `api_url` is the URL of the custom inference endpoint.
+
+    To learn more about getting started with TEI, visit: https://github.com/huggingface/text-embeddings-inference.
+    """
+
+    def __init__(
+        self,
+        model: Optional[str] = "bge-base-en-v1.5",
+        embed_batch_size: int = 256,
+        api_key: Optional[str] = None,
+        api_url: str = "",
+    ):
+        self.model = model
+        self.embed_batch_size = embed_batch_size
+        self.api_key = api_key
+        self.api_url = api_url
+
+    @property
+    def _headers(self) -> dict:
+        return {"Authorization": f"Bearer {self.api_key}"}
+
+    def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
+        text_to_vectorize = self._extract_text_from_examples(inp_examples)
+        embeddings_list = []
+
+        n = math.ceil(len(text_to_vectorize) / self.embed_batch_size)
+        for i in range(n):
+            response = requests.post(
+                self.api_url,
+                headers=self._headers,
+                json={
+                    "inputs": text_to_vectorize[i * self.embed_batch_size:(i + 1) * self.embed_batch_size],
+                    "normalize": True,
+                    "truncate": True,
+                },
+            )
+            embeddings_list.extend(response.json())
+
+        embeddings = np.array(embeddings_list, dtype=np.float32)
+        return embeddings
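For context, here is a minimal sketch of the HTTP exchange the new class performs for each batch, assuming a TEI container is reachable at a local `/embed` route (the `http://localhost:8080/embed` address and the bearer token are illustrative, not something this PR pins down). It mirrors the `inputs`/`normalize`/`truncate` payload used in `__call__` above.

```python
import numpy as np
import requests

# Illustrative endpoint; point this at wherever your TEI container listens.
API_URL = "http://localhost:8080/embed"

texts = [
    "What is Text Embeddings Inference?",
    "TEI serves embedding models over HTTP.",
]

# Same payload shape TEIVectorizer sends for each batch of texts.
response = requests.post(
    API_URL,
    headers={"Authorization": "Bearer <api-key-if-any>"},  # only needed if the server requires auth
    json={"inputs": texts, "normalize": True, "truncate": True},
)
response.raise_for_status()

# TEI returns one embedding (a list of floats) per input, in order.
embeddings = np.array(response.json(), dtype=np.float32)
print(embeddings.shape)  # (2, hidden_size), e.g. (2, 768) for bge-base-en-v1.5
```

Constructing `TEIVectorizer(api_url=API_URL)` and calling it on a list of examples issues the same request in slices of `embed_batch_size`, then stacks the results into a single `float32` array.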