Skip to content

Commit f6b9d8b

Browse files
authored
Merge pull request #1621 from xiayouran/dspy_fix
Fix: 'LM' object has no attribute 'copy'
2 parents 08c7b9a + 0840f8a commit f6b9d8b

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

dsp/modules/sentence_vectorizer.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import numpy as np
55
import openai
6+
import math
7+
import requests
68

79

810
class BaseSentenceVectorizer(abc.ABC):
@@ -306,3 +308,54 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
306308

307309
embeddings = np.array(embedding_list, dtype=np.float32)
308310
return embeddings
311+
312+
313+
class TEIVectorizer(BaseSentenceVectorizer):
314+
"""The TEIVectorizer class utilizes the TEI(Text Embeddings Inference) Embeddings API to
315+
convert text into embeddings.
316+
317+
For detailed information on the supported models, visit: https://github.com/huggingface/text-embeddings-inference.
318+
319+
`model` is embedding model name.
320+
`embed_batch_size` is the maximum batch size for a single request.
321+
`api_key` request authorization.
322+
`api_url` custom inference endpoint url.
323+
324+
To learn more about getting started with TEI, visit: https://github.com/huggingface/text-embeddings-inference.
325+
"""
326+
327+
def __init__(
328+
self,
329+
model: Optional[str] = "bge-base-en-v1.5",
330+
embed_batch_size: int = 256,
331+
api_key: Optional[str] = None,
332+
api_url: str = "",
333+
):
334+
self.model = model
335+
self.embed_batch_size = embed_batch_size
336+
self.api_key = api_key
337+
self.api_url = api_url
338+
339+
@property
340+
def _headers(self) -> dict:
341+
return {"Authorization": f"Bearer {self.api_key}"}
342+
343+
def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
344+
text_to_vectorize = self._extract_text_from_examples(inp_examples)
345+
embeddings_list = []
346+
347+
n = math.ceil(len(text_to_vectorize) / self.embed_batch_size)
348+
for i in range(n):
349+
response = requests.post(
350+
self.api_url,
351+
headers=self._headers,
352+
json={
353+
"inputs": text_to_vectorize[i * self.embed_batch_size:(i + 1) * self.embed_batch_size],
354+
"normalize": True,
355+
"truncate": True
356+
},
357+
)
358+
embeddings_list.extend(response.json())
359+
360+
embeddings = np.array(embeddings_list, dtype=np.float32)
361+
return embeddings

dspy/clients/lm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ def __call__(self, prompt=None, messages=None, **kwargs):
6363
def inspect_history(self, n: int = 1):
6464
_inspect_history(self, n)
6565

66+
def copy(self, **kwargs):
67+
"""Returns a copy of the language model with the same parameters."""
68+
kwargs = {**self.__dict__, **kwargs}
69+
return self.__class__(**kwargs)
70+
6671

6772
@functools.lru_cache(maxsize=None)
6873
def cached_litellm_completion(request):

0 commit comments

Comments
 (0)