|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import re |
14 | 15 |
|
15 | 16 | from llama_index.core import ( |
16 | 17 | Document, |
@@ -40,7 +41,19 @@ class OpensearchKnowledgeBackend(BaseKnowledgebaseBackend): |
40 | 41 | embedding_config: EmbeddingModelConfig = Field(default_factory=EmbeddingModelConfig) |
41 | 42 | """Embedding model configs""" |
42 | 43 |
|
def precheck_index_naming(self) -> None:
    """Validate that ``self.index`` is a legal OpenSearch index name.

    OpenSearch index names must be lowercase, may contain only
    ``a-z``, ``0-9``, ``_``, ``-`` and ``.``, must not begin with
    ``_`` or ``-``, and must not be the reserved path names ``.``
    or ``..``.

    Raises:
        ValueError: If the name violates any of these rules.
    """
    valid = (
        isinstance(self.index, str)
        # The character class already forbids uppercase letters, so a
        # separate islower() check is unnecessary — and would wrongly
        # reject all-digit names such as date-based indices "2024.01.01"
        # (str.islower() is False when a string has no cased characters).
        and re.fullmatch(r"[a-z0-9_\-.]+", self.index) is not None
        and not self.index.startswith(("_", "-"))
        # "." and ".." are reserved and cannot be used as index names.
        and self.index not in (".", "..")
    )
    if not valid:
        raise ValueError(
            "The index name does not conform to the naming rules of OpenSearch"
        )
43 | 55 | def model_post_init(self, __context: Any) -> None: |
| 56 | + self.precheck_index_naming() |
44 | 57 | self._opensearch_client = OpensearchVectorClient( |
45 | 58 | endpoint=self.opensearch_config.host, |
46 | 59 | port=self.opensearch_config.port, |
@@ -71,7 +84,6 @@ def model_post_init(self, __context: Any) -> None: |
71 | 84 | storage_context=self._storage_context, |
72 | 85 | embed_model=self._embed_model, |
73 | 86 | ) |
74 | | - self._retriever = self._vector_index.as_retriever() |
75 | 87 |
|
76 | 88 | @override |
77 | 89 | def add_from_directory(self, directory: str) -> bool: |
@@ -99,12 +111,8 @@ def add_from_text(self, text: str | list[str]) -> bool: |
99 | 111 |
|
@override
def search(self, query: str, top_k: int = 5) -> list[str]:
    """Return the text of the ``top_k`` nodes most similar to ``query``.

    A retriever is built fresh for each call with the requested
    ``similarity_top_k``, so no shared retriever state is mutated
    between queries.
    """
    per_call_retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
    return [hit.text for hit in per_call_retriever.retrieve(query)]
109 | 117 |
|
110 | 118 | def _split_documents(self, documents: list[Document]) -> list[BaseNode]: |
|
0 commit comments