Skip to content

Commit 469d037

Browse files
Add docs_uri_column_name to DatabricksRM (#1929)
1 parent 6246301 commit 469d037

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

dspy/retrieve/databricks_rm.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def to_dict(self) -> Dict[str, Any]:
2525
"type": self.type,
2626
}
2727

28+
2829
class DatabricksRM(dspy.Retrieve):
2930
"""
3031
A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
@@ -89,6 +90,7 @@ def __init__(
8990
filters_json: Optional[str] = None,
9091
k: int = 3,
9192
docs_id_column_name: str = "id",
93+
docs_uri_column_name: Optional[str] = None,
9294
text_column_name: str = "text",
9395
use_with_databricks_agent_framework: bool = False,
9496
):
@@ -113,6 +115,8 @@ def __init__(
113115
k (int): The number of documents to retrieve.
114116
docs_id_column_name (str): The name of the column in the Databricks Vector Search Index
115117
containing document IDs.
118+
docs_uri_column_name (Optional[str]): The name of the column in the Databricks Vector Search Index
119+
containing document URI.
116120
text_column_name (str): The name of the column in the Databricks Vector Search Index
117121
containing document text to retrieve.
118122
use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way that is
@@ -135,11 +139,13 @@ def __init__(
135139
self.filters_json = filters_json
136140
self.k = k
137141
self.docs_id_column_name = docs_id_column_name
142+
self.docs_uri_column_name = docs_uri_column_name
138143
self.text_column_name = text_column_name
139144
self.use_with_databricks_agent_framework = use_with_databricks_agent_framework
140145
if self.use_with_databricks_agent_framework:
141146
try:
142147
import mlflow
148+
143149
mlflow.models.set_retriever_schema(
144150
primary_key="doc_id",
145151
text_column="page_content",
@@ -170,9 +176,13 @@ def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]:
170176
Args:
171177
item: Dict[str, Any]: a record from the search results.
172178
Returns:
173-
Dict[str, Any]: Search result column values, excluding the "text" and not "id" columns.
179+
Dict[str, Any]: Search result column values, excluding the "text", "id" and "uri" columns.
174180
"""
175-
extra_columns = {k: v for k, v in item.items() if k not in [self.docs_id_column_name, self.text_column_name]}
181+
extra_columns = {
182+
k: v
183+
for k, v in item.items()
184+
if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name]
185+
}
176186
if self.docs_id_column_name == "metadata":
177187
extra_columns = {
178188
**extra_columns,
@@ -273,24 +283,27 @@ def forward(
273283
sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k]
274284

275285
if self.use_with_databricks_agent_framework:
276-
return [Document(
277-
page_content=doc[self.text_column_name],
278-
metadata={
279-
"doc_id": self._extract_doc_ids(doc),
280-
"doc_uri": f"index/{self.databricks_index_name}/id/{self._extract_doc_ids(doc)}",
281-
}
282-
| self._get_extra_columns(doc),
283-
type="Document",
284-
).to_dict() for doc in sorted_docs]
286+
return [
287+
Document(
288+
page_content=doc[self.text_column_name],
289+
metadata={
290+
"doc_id": self._extract_doc_ids(doc),
291+
"doc_uri": doc[self.docs_uri_column_name],
292+
}
293+
| self._get_extra_columns(doc),
294+
type="Document",
295+
).to_dict()
296+
for doc in sorted_docs
297+
]
285298
else:
286299
# Returning the prediction
287300
return Prediction(
288301
docs=[doc[self.text_column_name] for doc in sorted_docs],
289302
doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
303+
doc_uris=[doc[self.docs_uri_column_name] for doc in sorted_docs],
290304
extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
291305
)
292306

293-
294307
@staticmethod
295308
def _query_via_databricks_sdk(
296309
index_name: str,

0 commit comments

Comments
 (0)