Skip to content

Commit 1a577f8

Browse files
authored
Introduce Embeddings index, CompleteAndGrounded metric, Unbatchify utils (#1843)
* Introduce Embeddings (faiss NN index), CompleteAndGrounded metric, and Unbatchify utils * adjust faiss import * adjust tests * adjust to dspy.Embedder
1 parent 44b3331 commit 1a577f8

File tree

9 files changed

+364
-52
lines changed

9 files changed

+364
-52
lines changed

dspy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from .retrieve import *
77
from .signatures import *
88

9+
import dspy.retrievers
10+
911
# Functional must be imported after primitives, predict and signatures
1012
from .functional import * # isort: skip
1113
from dspy.evaluate import Evaluate # isort: skip

dspy/clients/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .lm import LM
22
from .provider import Provider, TrainingJob
33
from .base_lm import BaseLM, inspect_history
4-
from .embedding import Embedding
4+
from .embedding import Embedder
55
import litellm
66
import os
77
from pathlib import Path

dspy/clients/embedding.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
import numpy as np
33

44

5-
class Embedding:
5+
class Embedder:
66
"""DSPy embedding class.
77
88
The class for computing embeddings for text inputs. This class provides a unified interface for both:
99
1010
1. Hosted embedding models (e.g. OpenAI's text-embedding-3-small) via litellm integration
1111
2. Custom embedding functions that you provide
1212
13-
For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use
13+
For hosted models, simply pass the model name as a string (e.g., "openai/text-embedding-3-small"). The class will use
1414
litellm to handle the API calls and caching.
1515
1616
For custom embedding models, pass a callable function that:
@@ -24,14 +24,17 @@ class Embedding:
2424
model: The embedding model to use. This can be either a string (representing the name of the hosted embedding
2525
model, must be an embedding model supported by litellm) or a callable that represents a custom embedding
2626
model.
27+
batch_size (int, optional): The default batch size for processing inputs in batches. Defaults to 200.
28+
caching (bool, optional): Whether to cache the embedding response when using a hosted model. Defaults to True.
29+
**kwargs: Additional default keyword arguments to pass to the embedding model.
2730
2831
Examples:
2932
Example 1: Using a hosted model.
3033
3134
```python
3235
import dspy
3336
34-
embedder = dspy.Embedding("openai/text-embedding-3-small")
37+
embedder = dspy.Embedder("openai/text-embedding-3-small", batch_size=100)
3538
embeddings = embedder(["hello", "world"])
3639
3740
assert embeddings.shape == (2, 1536)
@@ -41,37 +44,78 @@ class Embedding:
4144
4245
```python
4346
import dspy
47+
import numpy as np
4448
4549
def my_embedder(texts):
4650
return np.random.rand(len(texts), 10)
4751
48-
embedder = dspy.Embedding(my_embedder)
49-
embeddings = embedder(["hello", "world"])
52+
embedder = dspy.Embedder(my_embedder)
53+
embeddings = embedder(["hello", "world"], batch_size=1)
5054
5155
assert embeddings.shape == (2, 10)
5256
```
5357
"""
5458

55-
def __init__(self, model):
59+
def __init__(self, model, batch_size=200, caching=True, **kwargs):
5660
self.model = model
61+
self.batch_size = batch_size
62+
self.caching = caching
63+
self.default_kwargs = kwargs
5764

58-
def __call__(self, inputs, caching=True, **kwargs):
65+
def __call__(self, inputs, batch_size=None, caching=None, **kwargs):
5966
"""Compute embeddings for the given inputs.
6067
6168
Args:
6269
inputs: The inputs to compute embeddings for, can be a single string or a list of strings.
63-
caching: Whether to cache the embedding response, only valid when using a hosted embedding model.
64-
kwargs: Additional keyword arguments to pass to the embedding model.
70+
batch_size (int, optional): The batch size for processing inputs. If None, defaults to the batch_size set during initialization.
71+
caching (bool, optional): Whether to cache the embedding response when using a hosted model. If None, defaults to the caching setting from initialization.
72+
**kwargs: Additional keyword arguments to pass to the embedding model. These will override the default kwargs provided during initialization.
6573
6674
Returns:
67-
A 2-D numpy array of embeddings, one embedding per row.
75+
numpy.ndarray: If the input is a single string, returns a 1D numpy array representing the embedding.
76+
If the input is a list of strings, returns a 2D numpy array of embeddings, one embedding per row.
6877
"""
78+
6979
if isinstance(inputs, str):
80+
is_single_input = True
7081
inputs = [inputs]
71-
if isinstance(self.model, str):
72-
embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs)
73-
return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32)
74-
elif callable(self.model):
75-
return np.array(self.model(inputs, **kwargs), dtype=np.float32)
7682
else:
77-
raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.")
83+
is_single_input = False
84+
85+
assert all(isinstance(inp, str) for inp in inputs), "All inputs must be strings."
86+
87+
if batch_size is None:
88+
batch_size = self.batch_size
89+
if caching is None:
90+
caching = self.caching
91+
92+
merged_kwargs = self.default_kwargs.copy()
93+
merged_kwargs.update(kwargs)
94+
95+
embeddings_list = []
96+
97+
def chunk(inputs_list, size):
98+
for i in range(0, len(inputs_list), size):
99+
yield inputs_list[i : i + size]
100+
101+
for batch_inputs in chunk(inputs, batch_size):
102+
if isinstance(self.model, str):
103+
embedding_response = litellm.embedding(
104+
model=self.model, input=batch_inputs, caching=caching, **merged_kwargs
105+
)
106+
batch_embeddings = [data["embedding"] for data in embedding_response.data]
107+
elif callable(self.model):
108+
batch_embeddings = self.model(batch_inputs, **merged_kwargs)
109+
else:
110+
raise ValueError(
111+
f"`model` in `dspy.Embedder` must be a string or a callable, but got {type(self.model)}."
112+
)
113+
114+
embeddings_list.extend(batch_embeddings)
115+
116+
embeddings = np.array(embeddings_list, dtype=np.float32)
117+
118+
if is_single_input:
119+
return embeddings[0]
120+
else:
121+
return embeddings

dspy/evaluate/auto_evaluation.py

Lines changed: 100 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,35 @@ class SemanticRecallPrecision(dspy.Signature):
1414
precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth")
1515

1616

17+
class DecompositionalSemanticRecallPrecision(dspy.Signature):
18+
"""
19+
Compare a system's response to the ground truth to compute recall and precision of key ideas.
20+
You will first enumerate key ideas in each response, discuss their overlap, and then report recall and precision.
21+
"""
22+
23+
question: str = dspy.InputField()
24+
ground_truth: str = dspy.InputField()
25+
system_response: str = dspy.InputField()
26+
ground_truth_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the ground truth")
27+
system_response_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the system response")
28+
discussion: str = dspy.OutputField(desc="discussion of the overlap between ground truth and system response")
29+
recall: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")
30+
precision: float = dspy.OutputField(desc="fraction (out of 1.0) of system response covered by the ground truth")
31+
32+
1733
def f1_score(precision, recall):
34+
precision, recall = max(0.0, min(1.0, precision)), max(0.0, min(1.0, recall))
1835
return 0.0 if precision + recall == 0 else 2 * (precision * recall) / (precision + recall)
1936

2037

2138
class SemanticF1(dspy.Module):
22-
def __init__(self, threshold=0.66):
39+
def __init__(self, threshold=0.66, decompositional=False):
2340
self.threshold = threshold
24-
self.module = dspy.ChainOfThought(SemanticRecallPrecision)
41+
42+
if decompositional:
43+
self.module = dspy.ChainOfThought(DecompositionalSemanticRecallPrecision)
44+
else:
45+
self.module = dspy.ChainOfThought(SemanticRecallPrecision)
2546

2647
def forward(self, example, pred, trace=None):
2748
scores = self.module(question=example.question, ground_truth=example.response, system_response=pred.response)
@@ -30,42 +51,92 @@ def forward(self, example, pred, trace=None):
3051
return score if trace is None else score >= self.threshold
3152

3253

33-
"""
34-
Soon-to-be deprecated Signatures & Modules Below.
35-
"""
54+
55+
###########
56+
57+
58+
class DecompositionalSemanticRecall(dspy.Signature):
59+
"""
60+
Estimate the completeness of a system's responses, against the ground truth.
61+
You will first enumerate key ideas in each response, discuss their overlap, and then report completeness.
62+
"""
63+
64+
question: str = dspy.InputField()
65+
ground_truth: str = dspy.InputField()
66+
system_response: str = dspy.InputField()
67+
ground_truth_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the ground truth")
68+
system_response_key_ideas: str = dspy.OutputField(desc="enumeration of key ideas in the system response")
69+
discussion: str = dspy.OutputField(desc="discussion of the overlap between ground truth and system response")
70+
completeness: float = dspy.OutputField(desc="fraction (out of 1.0) of ground truth covered by the system response")
71+
72+
73+
74+
class DecompositionalGroundedness(dspy.Signature):
75+
"""
76+
Estimate the groundedness of a system's responses, against real retrieved documents written by people.
77+
You will first enumerate whatever non-trivial or check-worthy claims are made in the system response, and then
78+
discuss the extent to which some or all of them can be deduced from the retrieved context and basic commonsense.
79+
"""
80+
81+
question: str = dspy.InputField()
82+
retrieved_context: str = dspy.InputField()
83+
system_response: str = dspy.InputField()
84+
system_response_claims: str = dspy.OutputField(desc="enumeration of non-trivial or check-worthy claims in the system response")
85+
discussion: str = dspy.OutputField(desc="discussion of how supported the claims are by the retrieved context")
86+
groundedness: float = dspy.OutputField(desc="fraction (out of 1.0) of system response supported by the retrieved context")
87+
88+
89+
class CompleteAndGrounded(dspy.Module):
90+
def __init__(self, threshold=0.66):
91+
self.threshold = threshold
92+
self.completeness_module = dspy.ChainOfThought(DecompositionalSemanticRecall)
93+
self.groundedness_module = dspy.ChainOfThought(DecompositionalGroundedness)
94+
95+
def forward(self, example, pred, trace=None):
96+
completeness = self.completeness_module(question=example.question, ground_truth=example.response, system_response=pred.response)
97+
groundedness = self.groundedness_module(question=example.question, retrieved_context=pred.context, system_response=pred.response)
98+
score = f1_score(groundedness.groundedness, completeness.completeness)
99+
100+
return score if trace is None else score >= self.threshold
101+
102+
103+
104+
# """
105+
# Soon-to-be deprecated Signatures & Modules Below.
106+
# """
36107

37108

38-
class AnswerCorrectnessSignature(dspy.Signature):
39-
"""Verify that the predicted answer matches the gold answer."""
109+
# class AnswerCorrectnessSignature(dspy.Signature):
110+
# """Verify that the predicted answer matches the gold answer."""
40111

41-
question = dspy.InputField()
42-
gold_answer = dspy.InputField(desc="correct answer for question")
43-
predicted_answer = dspy.InputField(desc="predicted answer for question")
44-
is_correct = dspy.OutputField(desc="True or False")
112+
# question = dspy.InputField()
113+
# gold_answer = dspy.InputField(desc="correct answer for question")
114+
# predicted_answer = dspy.InputField(desc="predicted answer for question")
115+
# is_correct = dspy.OutputField(desc="True or False")
45116

46117

47-
class AnswerCorrectness(dspy.Module):
48-
def __init__(self):
49-
super().__init__()
50-
self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature)
118+
# class AnswerCorrectness(dspy.Module):
119+
# def __init__(self):
120+
# super().__init__()
121+
# self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature)
51122

52-
def forward(self, question, gold_answer, predicted_answer):
53-
return self.evaluate_correctness(question=question, gold_answer=gold_answer, predicted_answer=predicted_answer)
123+
# def forward(self, question, gold_answer, predicted_answer):
124+
# return self.evaluate_correctness(question=question, gold_answer=gold_answer, predicted_answer=predicted_answer)
54125

55126

56-
class AnswerFaithfulnessSignature(dspy.Signature):
57-
"""Verify that the predicted answer is based on the provided context."""
127+
# class AnswerFaithfulnessSignature(dspy.Signature):
128+
# """Verify that the predicted answer is based on the provided context."""
58129

59-
context = dspy.InputField(desc="relevant facts for producing answer")
60-
question = dspy.InputField()
61-
answer = dspy.InputField(desc="often between 1 and 5 words")
62-
is_faithful = dspy.OutputField(desc="True or False")
130+
# context = dspy.InputField(desc="relevant facts for producing answer")
131+
# question = dspy.InputField()
132+
# answer = dspy.InputField(desc="often between 1 and 5 words")
133+
# is_faithful = dspy.OutputField(desc="True or False")
63134

64135

65-
class AnswerFaithfulness(dspy.Module):
66-
def __init__(self):
67-
super().__init__()
68-
self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature)
136+
# class AnswerFaithfulness(dspy.Module):
137+
# def __init__(self):
138+
# super().__init__()
139+
# self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature)
69140

70-
def forward(self, context, question, answer):
71-
return self.evaluate_faithfulness(context=context, question=question, answer=answer)
141+
# def forward(self, context, question, answer):
142+
# return self.evaluate_faithfulness(context=context, question=question, answer=answer)

dspy/predict/knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self, k: int, trainset: List[dsp.Example], vectorizer=None):
1313
Args:
1414
k: Number of nearest neighbors to retrieve
1515
trainset: List of training examples to search through
16-
vectorizer: Optional dspy.Embedding for computing embeddings. If None, uses sentence-transformers.
16+
vectorizer: Optional dspy.Embedder for computing embeddings. If None, uses sentence-transformers.
1717
1818
Example:
1919
>>> trainset = [dsp.Example(input="hello", output="world"), ...]
@@ -24,7 +24,7 @@ def __init__(self, k: int, trainset: List[dsp.Example], vectorizer=None):
2424

2525
self.k = k
2626
self.trainset = trainset
27-
self.embedding = vectorizer or dspy.Embedding(dsp.SentenceTransformersVectorizer())
27+
self.embedding = vectorizer or dspy.Embedder(dsp.SentenceTransformersVectorizer())
2828
trainset_casted_to_vectorize = [
2929
" | ".join([f"{key}: {value}" for key, value in example.items() if key in example._input_keys])
3030
for example in self.trainset

dspy/retrievers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .embeddings import Embeddings

0 commit comments

Comments
 (0)