@@ -25,6 +25,7 @@ def to_dict(self) -> Dict[str, Any]:
2525 "type" : self .type ,
2626 }
2727
28+
2829class DatabricksRM (dspy .Retrieve ):
2930 """
3031 A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
@@ -89,6 +90,7 @@ def __init__(
8990 filters_json : Optional [str ] = None ,
9091 k : int = 3 ,
9192 docs_id_column_name : str = "id" ,
93+ docs_uri_column_name : Optional [str ] = None ,
9294 text_column_name : str = "text" ,
9395 use_with_databricks_agent_framework : bool = False ,
9496 ):
@@ -113,6 +115,8 @@ def __init__(
113115 k (int): The number of documents to retrieve.
114116 docs_id_column_name (str): The name of the column in the Databricks Vector Search Index
115117 containing document IDs.
118+ docs_uri_column_name (Optional[str]): The name of the column in the Databricks Vector Search Index
119+ containing document URIs.
116120 text_column_name (str): The name of the column in the Databricks Vector Search Index
117121 containing document text to retrieve.
118122 use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way that is
@@ -135,11 +139,13 @@ def __init__(
135139 self .filters_json = filters_json
136140 self .k = k
137141 self .docs_id_column_name = docs_id_column_name
142+ self .docs_uri_column_name = docs_uri_column_name
138143 self .text_column_name = text_column_name
139144 self .use_with_databricks_agent_framework = use_with_databricks_agent_framework
140145 if self .use_with_databricks_agent_framework :
141146 try :
142147 import mlflow
148+
143149 mlflow .models .set_retriever_schema (
144150 primary_key = "doc_id" ,
145151 text_column = "page_content" ,
@@ -170,9 +176,13 @@ def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]:
170176 Args:
171177 item: Dict[str, Any]: a record from the search results.
172178 Returns:
173- Dict[str, Any]: Search result column values, excluding the "text" and not "id " columns.
179+ Dict[str, Any]: Search result column values, excluding the "text", "id", and "uri" columns.
174180 """
175- extra_columns = {k : v for k , v in item .items () if k not in [self .docs_id_column_name , self .text_column_name ]}
181+ extra_columns = {
182+ k : v
183+ for k , v in item .items ()
184+ if k not in [self .docs_id_column_name , self .text_column_name , self .docs_uri_column_name ]
185+ }
176186 if self .docs_id_column_name == "metadata" :
177187 extra_columns = {
178188 ** extra_columns ,
@@ -273,24 +283,27 @@ def forward(
273283 sorted_docs = sorted (items , key = lambda x : x ["score" ], reverse = True )[: self .k ]
274284
275285 if self .use_with_databricks_agent_framework :
276- return [Document (
277- page_content = doc [self .text_column_name ],
278- metadata = {
279- "doc_id" : self ._extract_doc_ids (doc ),
280- "doc_uri" : f"index/{ self .databricks_index_name } /id/{ self ._extract_doc_ids (doc )} " ,
281- }
282- | self ._get_extra_columns (doc ),
283- type = "Document" ,
284- ).to_dict () for doc in sorted_docs ]
286+ return [
287+ Document (
288+ page_content = doc [self .text_column_name ],
289+ metadata = {
290+ "doc_id" : self ._extract_doc_ids (doc ),
291+ "doc_uri" : doc [self .docs_uri_column_name ],
292+ }
293+ | self ._get_extra_columns (doc ),
294+ type = "Document" ,
295+ ).to_dict ()
296+ for doc in sorted_docs
297+ ]
285298 else :
286299 # Returning the prediction
287300 return Prediction (
288301 docs = [doc [self .text_column_name ] for doc in sorted_docs ],
289302 doc_ids = [self ._extract_doc_ids (doc ) for doc in sorted_docs ],
303+ doc_uris = [doc [self .docs_uri_column_name ] for doc in sorted_docs ],
290304 extra_columns = [self ._get_extra_columns (item ) for item in sorted_docs ],
291305 )
292306
293-
294307 @staticmethod
295308 def _query_via_databricks_sdk (
296309 index_name : str ,
0 commit comments