Skip to content

Commit 1df86fa

Browse files
xucailiangxucai
andauthored
neo4j_rm support litellm embedding model (#1771)
Co-authored-by: xucai <[email protected]>
1 parent 08d9ece commit 1df86fa

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

dspy/retrieve/neo4j_rm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Any, List, Optional, Union
2+
from typing import Any, List, Optional, Union, Callable
33

44
import backoff
55
from openai import (
@@ -108,6 +108,7 @@ def __init__(
108108
retrieval_query: str = None,
109109
embedding_provider: str = "openai",
110110
embedding_model: str = "text-embedding-ada-002",
111+
embedding_function: Optional[Callable] = None,
111112
):
112113
super().__init__(k=k)
113114
self.index_name = index_name
@@ -136,7 +137,7 @@ def __init__(
136137
) as e:
137138
raise ConnectionError("Failed to connect to Neo4j database") from e
138139

139-
self.embedder = Embedder(provider=embedding_provider, model=embedding_model)
140+
self.embedder = embedding_function or Embedder(provider=embedding_provider, model=embedding_model)
140141

141142
def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> Prediction:
142143
if not isinstance(query_or_queries, list):

0 commit comments

Comments
 (0)