Skip to content

Commit 09d3c39

Browse files
authored
OraLlamaVS Connection Pool Support + Filtering (#19412)
1 parent 3c39596 commit 09d3c39

File tree

4 files changed

+127
-60
lines changed

4 files changed

+127
-60
lines changed

docs/docs/examples/vector_stores/orallamavs.ipynb

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -305,15 +305,6 @@
305305
" - vector_stores (list): A list of OracleVS instances.\n",
306306
" \"\"\"\n",
307307
" for i, vs in enumerate(vector_stores, start=1):\n",
308-
" # Adding texts\n",
309-
" try:\n",
310-
" vs.add_texts(text_nodes, metadata)\n",
311-
" print(f\"\\n\\n\\nAdd texts complete for vector store {i}\\n\\n\\n\")\n",
312-
" except Exception as ex:\n",
313-
" print(\n",
314-
" f\"\\n\\n\\nExpected error on duplicate add for vector store {i}\\n\\n\\n\"\n",
315-
" )\n",
316-
"\n",
317308
" # Deleting texts using the value of 'id'\n",
318309
" vs.delete(\"test-1\")\n",
319310
" print(f\"\\n\\n\\nDelete texts complete for vector store {i}\\n\\n\\n\")\n",

llama-index-integrations/vector_stores/llama-index-vector-stores-oracledb/README.md

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -123,34 +123,18 @@ orallamavs.create_index(
123123

124124
print("Index created.")
125125

126-
127126
# Perform Semantic Search
128-
query = "What is Oracle AI Vector Store?"
129-
filter = {"document_id": ["1"]}
130-
127+
embedding = embedder._get_text_embedding("What is Oracle AI Vector Store?")
128+
query = VectorStoreQuery(query_embedding=embedding, similarity_top_k=1)
131129
# Similarity search without a filter
132-
print(vectorstore.similarity_search(query, 1))
133-
134-
# Similarity search with a filter
135-
print(vectorstore.similarity_search(query, 1, filter=filter))
130+
print(vectorstore.query(query))
136131

137-
# Similarity search with relevance score
138-
print(vectorstore.similarity_search_with_score(query, 1))
139-
140-
# Similarity search with relevance score with filter
141-
print(vectorstore.similarity_search_with_score(query, 1, filter=filter))
142-
143-
# Max marginal relevance search
144-
print(
145-
vectorstore.max_marginal_relevance_search(
146-
query, 1, fetch_k=20, lambda_mult=0.5
147-
)
132+
filters = MetadataFilters(
133+
filters=[ExactMatchFilter(key="document_id", value="1")]
148134
)
149-
150-
# Max marginal relevance search with filter
151-
print(
152-
vectorstore.max_marginal_relevance_search(
153-
query, 1, fetch_k=20, lambda_mult=0.5, filter=filter
154-
)
135+
query = VectorStoreQuery(
136+
query_embedding=embedding, filters=filters, similarity_top_k=1
155137
)
138+
# Similarity search with a filter
139+
print(vectorstore.query(query))
156140
```

llama-index-integrations/vector_stores/llama-index-vector-stores-oracledb/llama_index/vector_stores/oracledb/base.py

Lines changed: 117 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
BasePydanticVectorStore,
3737
VectorStoreQuery,
3838
VectorStoreQueryResult,
39+
FilterOperator,
40+
MetadataFilters,
41+
MetadataFilter,
3942
)
4043

4144
if TYPE_CHECKING:
@@ -86,6 +89,31 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
8689
return cast(T, wrapper)
8790

8891

92+
def _get_connection(client: Any) -> Connection | None:
93+
# Dynamically import oracledb and the required classes
94+
try:
95+
import oracledb
96+
except ImportError as e:
97+
raise ImportError(
98+
"Unable to import oracledb, please install with `pip install -U oracledb`."
99+
) from e
100+
101+
# check if ConnectionPool exists
102+
connection_pool_class = getattr(oracledb, "ConnectionPool", None)
103+
104+
if isinstance(client, oracledb.Connection):
105+
return client
106+
elif connection_pool_class and isinstance(client, connection_pool_class):
107+
return client.acquire()
108+
else:
109+
valid_types = "oracledb.Connection"
110+
if connection_pool_class:
111+
valid_types += " or oracledb.ConnectionPool"
112+
raise TypeError(
113+
f"Expected client of type {valid_types}, got {type(client).__name__}"
114+
)
115+
116+
89117
def _escape_str(value: str) -> str:
90118
BS = "\\"
91119
must_escape = (BS, "'")
@@ -103,7 +131,7 @@ def _escape_str(value: str) -> str:
103131
},
104132
"node_info": {
105133
"type": "JSON",
106-
"extract_func": lambda x: json.dumps(x.node_info),
134+
"extract_func": lambda x: json.dumps(x.get_node_info()),
107135
},
108136
"metadata": {
109137
"type": "JSON",
@@ -195,10 +223,11 @@ def _create_table(connection: Connection, table_name: str) -> None:
195223

196224
@_handle_exceptions
197225
def create_index(
198-
connection: Connection,
226+
client: Any,
199227
vector_store: OraLlamaVS,
200228
params: Optional[dict[str, Any]] = None,
201229
) -> None:
230+
connection = _get_connection(client)
202231
if params:
203232
if params["idx_type"] == "HNSW":
204233
_create_hnsw_index(
@@ -350,7 +379,8 @@ def _create_ivf_index(
350379

351380

352381
@_handle_exceptions
353-
def drop_table_purge(connection: Connection, table_name: str) -> None:
382+
def drop_table_purge(client: Any, table_name: str) -> None:
383+
connection = _get_connection(client)
354384
if _table_exists(connection, table_name):
355385
cursor = connection.cursor()
356386
with cursor:
@@ -427,9 +457,10 @@ def __init__(
427457
batch_size=batch_size,
428458
params=params,
429459
)
460+
connection = _get_connection(_client)
430461
# Assign _client to PrivateAttr after the Pydantic initialization
431462
object.__setattr__(self, "_client", _client)
432-
_create_table(_client, table_name)
463+
_create_table(connection, table_name)
433464

434465
except oracledb.DatabaseError as db_err:
435466
logger.exception(f"Database error occurred while create table: {db_err}")
@@ -456,26 +487,82 @@ def client(self) -> Any:
456487
def class_name(cls) -> str:
457488
return "OraLlamaVS"
458489

490+
def _convert_oper_to_sql(
491+
self,
492+
oper: FilterOperator,
493+
metadata_column: str,
494+
filter_key: str,
495+
value_bind: str,
496+
) -> str:
497+
if oper == FilterOperator.IS_EMPTY:
498+
return f"NOT JSON_EXISTS({metadata_column}, '$.{filter_key}') OR JSON_EQUAL(JSON_QUERY({metadata_column}, '$.{filter_key}'), '[]') OR JSON_EQUAL(JSON_QUERY({metadata_column}, '$.{filter_key}'), 'null')"
499+
elif oper == FilterOperator.CONTAINS:
500+
return f"JSON_EXISTS({metadata_column}, '$.{filter_key}[*]?(@ == $val)' PASSING {value_bind} AS \"val\")"
501+
else:
502+
oper_map = {
503+
FilterOperator.EQ: "{0} = {1}", # default operator (string, int, float)
504+
FilterOperator.GT: "{0} > {1}", # greater than (int, float)
505+
FilterOperator.LT: "{0} < {1}", # less than (int, float)
506+
FilterOperator.NE: "{0} != {1}", # not equal to (string, int, float)
507+
FilterOperator.GTE: "{0} >= {1}", # greater than or equal to (int, float)
508+
FilterOperator.LTE: "{0} <= {1}", # less than or equal to (int, float)
509+
FilterOperator.IN: "{0} IN ({1})", # In array (string or number)
510+
FilterOperator.NIN: "{0} NOT IN ({1})", # Not in array (string or number)
511+
FilterOperator.TEXT_MATCH: "{0} LIKE '%' || {1} || '%'", # full text match (allows you to search for a specific substring, token or phrase within the text field)
512+
}
513+
514+
if oper not in oper_map:
515+
raise ValueError(
516+
f"FilterOperation {oper} cannot be used with this vector store."
517+
)
518+
519+
operation_f = oper_map.get(oper)
520+
521+
return operation_f.format(
522+
f"JSON_VALUE({metadata_column}, '$.{filter_key}')", value_bind
523+
)
524+
525+
def _get_filter_string(
526+
self, filter: MetadataFilters | MetadataFilter, bind_variables: list
527+
) -> str:
528+
if isinstance(filter, MetadataFilter):
529+
if not re.match(r"^[a-zA-Z0-9_]+$", filter.key):
530+
raise ValueError(f"Invalid metadata key format: {filter.key}")
531+
532+
value_bind = f""
533+
if filter.operator == FilterOperator.IS_EMPTY:
534+
# No values needed
535+
pass
536+
elif isinstance(filter.value, List):
537+
# Needs multiple binds for a list https://python-oracledb.readthedocs.io/en/latest/user_guide/bind.html#binding-multiple-values-to-a-sql-where-in-clause
538+
value_binds = []
539+
for val in filter.value:
540+
value_binds.append(f":value{len(bind_variables)}")
541+
bind_variables.append(val)
542+
value_bind = ",".join(value_binds)
543+
else:
544+
value_bind = f":value{len(bind_variables)}"
545+
bind_variables.append(filter.value)
546+
547+
return self._convert_oper_to_sql(
548+
filter.operator, self.metadata_column, filter.key, value_bind
549+
)
550+
551+
# Combine all sub filters
552+
filter_strings = [
553+
self._get_filter_string(f_, bind_variables) for f_ in filter.filters
554+
]
555+
556+
return f" {filter.condition.value.upper()} ".join(filter_strings)
557+
459558
def _append_meta_filter_condition(
460-
self, where_str: Optional[str], exact_match_filter: list
559+
self, where_str: Optional[str], filters: Optional[MetadataFilters]
461560
) -> Tuple[str, list]:
462561
bind_variables = []
463-
filter_conditions = []
464-
465-
# Validate metadata keys (only allow alphanumeric and underscores)
466-
for filter_item in exact_match_filter:
467-
# Validate the key - only allow safe characters for JSON path
468-
if not re.match(r"^[a-zA-Z0-9_]+$", filter_item.key):
469-
raise ValueError(f"Invalid metadata key format: {filter_item.key}")
470-
# Use JSON_VALUE with parameterized values
471-
filter_conditions.append(
472-
f"JSON_VALUE({self.metadata_column}, '$.{filter_item.key}') = :value{len(bind_variables)}"
473-
)
474-
bind_variables.append(filter_item.value)
475562

476-
# Convert filter conditions to a single string
477-
filter_str = " AND ".join(filter_conditions)
563+
filter_str = self._get_filter_string(filters, bind_variables)
478564

565+
# Convert filter conditions to a single string
479566
if where_str is None:
480567
where_str = filter_str
481568
else:
@@ -534,22 +621,25 @@ def add(self, nodes: list[BaseNode], **kwargs: Any) -> list[str]:
534621
if not nodes:
535622
return []
536623

624+
connection = _get_connection(self._client)
625+
537626
for result_batch in iter_batch(nodes, self.batch_size):
538627
dml, bind_values = self._build_insert(values=result_batch)
539628

540-
with self._client.cursor() as cursor:
629+
with connection.cursor() as cursor:
541630
# Use executemany to insert the batch
542631
cursor.executemany(dml, bind_values)
543-
self._client.commit()
632+
connection.commit()
544633

545634
return [node.node_id for node in nodes]
546635

547636
@_handle_exceptions
548637
def delete(self, ref_doc_id: str, **kwargs: Any) -> None:
549-
with self._client.cursor() as cursor:
638+
connection = _get_connection(self._client)
639+
with connection.cursor() as cursor:
550640
ddl = f"DELETE FROM {self.table_name} WHERE doc_id = :ref_doc_id"
551641
cursor.execute(ddl, [ref_doc_id])
552-
self._client.commit()
642+
connection.commit()
553643

554644
@_handle_exceptions
555645
def _get_clob_value(self, result: Any) -> str:
@@ -595,7 +685,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
595685
bind_vars = []
596686
if query.filters is not None:
597687
where_str, bind_vars = self._append_meta_filter_condition(
598-
where_str, query.filters.filters
688+
where_str, query.filters
599689
)
600690

601691
# build query sql
@@ -625,7 +715,9 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
625715
params = {"embedding": embedding}
626716
for i, value in enumerate(bind_vars):
627717
params[f"value{i}"] = value
628-
with self._client.cursor() as cursor:
718+
719+
connection = _get_connection(self._client)
720+
with connection.cursor() as cursor:
629721
cursor.execute(query_sql, **params)
630722
results = cursor.fetchall()
631723

llama-index-integrations/vector_stores/llama-index-vector-stores-oracledb/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dev = [
2626

2727
[project]
2828
name = "llama-index-vector-stores-oracledb"
29-
version = "0.3.1"
29+
version = "0.3.2"
3030
description = "llama-index vector_stores oracle database integration"
3131
authors = [{name = "Your Name", email = "[email protected]"}]
3232
requires-python = ">=3.9,<3.13"

0 commit comments

Comments
 (0)