Skip to content

Commit 0d9dad7

Browse files
committed
Allow Embeddings to return indices
1 parent 472a9c1 commit 0d9dad7

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

dspy/propose/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def get_dspy_source_code(module):
146146
base_code = ""
147147

148148
# Don't get source code for Predict or ChainOfThought modules (NOTE we will need to extend this list as more DSPy.modules are added)
149+
# TODO: if type(module).__name__ not in ["Predict", "ChainOfThought", "ReAct"]:
149150
if not type(module).__name__ == "Predict" and not type(module).__name__ == "ChainOfThought":
150151
try:
151152
base_code = inspect.getsource(type(module))

dspy/retrievers/embeddings.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def __call__(self, query: str):
3434

3535
def forward(self, query: str):
3636
import dspy
37-
return dspy.Prediction(passages=self.search_fn(query))
37+
38+
passages, indices = self.search_fn(query)
39+
return dspy.Prediction(passages=passages, indices=indices)
3840

3941
def _batch_forward(self, queries: List[str]):
4042
q_embeds = self.embedder(queries)
@@ -76,7 +78,7 @@ def _rerank_and_predict(self, q_embeds: np.ndarray, candidate_indices: np.ndarra
7678
top_k_indices = np.argsort(-scores, axis=1)[:, :self.k]
7779
top_indices = candidate_indices[np.arange(len(q_embeds))[:, None], top_k_indices]
7880

79-
return [[self.corpus[idx] for idx in indices] for indices in top_indices]
81+
return [([self.corpus[idx] for idx in indices], [idx for idx in indices]) for indices in top_indices]
8082

8183
def _normalize(self, embeddings: np.ndarray):
8284
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)

0 commit comments

Comments
 (0)