36
36
BasePydanticVectorStore ,
37
37
VectorStoreQuery ,
38
38
VectorStoreQueryResult ,
39
+ FilterOperator ,
40
+ MetadataFilters ,
41
+ MetadataFilter ,
39
42
)
40
43
41
44
if TYPE_CHECKING :
@@ -86,6 +89,31 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
86
89
return cast (T , wrapper )
87
90
88
91
92
+ def _get_connection (client : Any ) -> Connection | None :
93
+ # Dynamically import oracledb and the required classes
94
+ try :
95
+ import oracledb
96
+ except ImportError as e :
97
+ raise ImportError (
98
+ "Unable to import oracledb, please install with `pip install -U oracledb`."
99
+ ) from e
100
+
101
+ # check if ConnectionPool exists
102
+ connection_pool_class = getattr (oracledb , "ConnectionPool" , None )
103
+
104
+ if isinstance (client , oracledb .Connection ):
105
+ return client
106
+ elif connection_pool_class and isinstance (client , connection_pool_class ):
107
+ return client .acquire ()
108
+ else :
109
+ valid_types = "oracledb.Connection"
110
+ if connection_pool_class :
111
+ valid_types += " or oracledb.ConnectionPool"
112
+ raise TypeError (
113
+ f"Expected client of type { valid_types } , got { type (client ).__name__ } "
114
+ )
115
+
116
+
89
117
def _escape_str (value : str ) -> str :
90
118
BS = "\\ "
91
119
must_escape = (BS , "'" )
@@ -103,7 +131,7 @@ def _escape_str(value: str) -> str:
103
131
},
104
132
"node_info" : {
105
133
"type" : "JSON" ,
106
- "extract_func" : lambda x : json .dumps (x .node_info ),
134
+ "extract_func" : lambda x : json .dumps (x .get_node_info () ),
107
135
},
108
136
"metadata" : {
109
137
"type" : "JSON" ,
@@ -195,10 +223,11 @@ def _create_table(connection: Connection, table_name: str) -> None:
195
223
196
224
@_handle_exceptions
197
225
def create_index (
198
- connection : Connection ,
226
+ client : Any ,
199
227
vector_store : OraLlamaVS ,
200
228
params : Optional [dict [str , Any ]] = None ,
201
229
) -> None :
230
+ connection = _get_connection (client )
202
231
if params :
203
232
if params ["idx_type" ] == "HNSW" :
204
233
_create_hnsw_index (
@@ -350,7 +379,8 @@ def _create_ivf_index(
350
379
351
380
352
381
@_handle_exceptions
353
- def drop_table_purge (connection : Connection , table_name : str ) -> None :
382
+ def drop_table_purge (client : Any , table_name : str ) -> None :
383
+ connection = _get_connection (client )
354
384
if _table_exists (connection , table_name ):
355
385
cursor = connection .cursor ()
356
386
with cursor :
@@ -427,9 +457,10 @@ def __init__(
427
457
batch_size = batch_size ,
428
458
params = params ,
429
459
)
460
+ connection = _get_connection (_client )
430
461
# Assign _client to PrivateAttr after the Pydantic initialization
431
462
object .__setattr__ (self , "_client" , _client )
432
- _create_table (_client , table_name )
463
+ _create_table (connection , table_name )
433
464
434
465
except oracledb .DatabaseError as db_err :
435
466
logger .exception (f"Database error occurred while create table: { db_err } " )
@@ -456,26 +487,82 @@ def client(self) -> Any:
456
487
def class_name (cls ) -> str :
457
488
return "OraLlamaVS"
458
489
490
def _convert_oper_to_sql(
    self,
    oper: FilterOperator,
    metadata_column: str,
    filter_key: str,
    value_bind: str,
) -> str:
    """Render a single metadata filter as a SQL predicate on the JSON column.

    Args:
        oper: The filter operator to translate.
        metadata_column: Name of the column holding the JSON metadata.
        filter_key: Metadata key, interpolated into the JSON path ``$.<key>``.
            It is embedded verbatim, so the caller must have validated it.
        value_bind: Bind placeholder(s) for the value, e.g. ``:value0`` or a
            comma-separated list for IN / NOT IN. Ignored for IS_EMPTY.

    Returns:
        A SQL boolean expression as a string.

    Raises:
        ValueError: If ``oper`` has no SQL translation here.
    """
    json_path = f"'$.{filter_key}'"

    if oper == FilterOperator.IS_EMPTY:
        # The whole OR chain is parenthesised so it stays one predicate when
        # the caller joins it with AND/OR next to other filters.
        return (
            f"(NOT JSON_EXISTS({metadata_column}, {json_path})"
            f" OR JSON_EQUAL(JSON_QUERY({metadata_column}, {json_path}), '[]')"
            f" OR JSON_EQUAL(JSON_QUERY({metadata_column}, {json_path}), 'null'))"
        )

    if oper == FilterOperator.CONTAINS:
        # Matches when any element of the JSON array equals the bound value.
        return (
            f"JSON_EXISTS({metadata_column}, '$.{filter_key}[*]?(@ == $val)' "
            f'PASSING {value_bind} AS "val")'
        )

    # {0} is the extracted JSON scalar, {1} the bind placeholder(s).
    oper_map = {
        FilterOperator.EQ: "{0} = {1}",  # default operator (string, int, float)
        FilterOperator.GT: "{0} > {1}",  # greater than (int, float)
        FilterOperator.LT: "{0} < {1}",  # less than (int, float)
        FilterOperator.NE: "{0} != {1}",  # not equal to (string, int, float)
        FilterOperator.GTE: "{0} >= {1}",  # greater than or equal to (int, float)
        FilterOperator.LTE: "{0} <= {1}",  # less than or equal to (int, float)
        FilterOperator.IN: "{0} IN ({1})",  # in array (string or number)
        FilterOperator.NIN: "{0} NOT IN ({1})",  # not in array (string or number)
        FilterOperator.TEXT_MATCH: "{0} LIKE '%' || {1} || '%'",  # substring match
    }

    template = oper_map.get(oper)
    if template is None:
        raise ValueError(
            f"FilterOperation {oper} cannot be used with this vector store."
        )

    return template.format(
        f"JSON_VALUE({metadata_column}, {json_path})", value_bind
    )
524
+
525
def _get_filter_string(
    self, filter: MetadataFilters | MetadataFilter, bind_variables: list
) -> str:
    """Recursively translate a metadata filter (or filter group) into SQL.

    Args:
        filter: A single ``MetadataFilter`` or a ``MetadataFilters`` group
            whose ``filters`` may themselves be groups.
        bind_variables: Accumulator that is mutated in place: every value
            referenced by the returned SQL is appended, and its placeholder
            is named ``:value<index>`` after its position in this list.

    Returns:
        The SQL predicate. A group with more than one member is wrapped in
        parentheses so that its AND/OR condition cannot be re-associated by
        operator precedence once it is embedded in a larger expression.

    Raises:
        ValueError: If a filter key contains anything other than ASCII
            letters, digits and underscores.
    """
    if isinstance(filter, MetadataFilter):
        # The key is interpolated into a JSON path literal, so it is held to
        # a strict whitelist; fullmatch also rejects a trailing newline.
        if not re.fullmatch(r"[a-zA-Z0-9_]+", filter.key):
            raise ValueError(f"Invalid metadata key format: {filter.key}")

        if filter.operator == FilterOperator.IS_EMPTY:
            # IS_EMPTY takes no value, hence no bind variable.
            value_bind = ""
        elif isinstance(filter.value, list):
            # One bind per element, see
            # https://python-oracledb.readthedocs.io/en/latest/user_guide/bind.html#binding-multiple-values-to-a-sql-where-in-clause
            placeholders = []
            for item in filter.value:
                placeholders.append(f":value{len(bind_variables)}")
                bind_variables.append(item)
            value_bind = ",".join(placeholders)
        else:
            value_bind = f":value{len(bind_variables)}"
            bind_variables.append(filter.value)

        return self._convert_oper_to_sql(
            filter.operator, self.metadata_column, filter.key, value_bind
        )

    # Group: translate the members in order (order fixes the bind numbering).
    clauses = [
        self._get_filter_string(child, bind_variables) for child in filter.filters
    ]
    combined = f" {filter.condition.value.upper()} ".join(clauses)
    return f"({combined})" if len(clauses) > 1 else combined
557
+
459
558
def _append_meta_filter_condition (
460
- self , where_str : Optional [str ], exact_match_filter : list
559
+ self , where_str : Optional [str ], filters : Optional [ MetadataFilters ]
461
560
) -> Tuple [str , list ]:
462
561
bind_variables = []
463
- filter_conditions = []
464
-
465
- # Validate metadata keys (only allow alphanumeric and underscores)
466
- for filter_item in exact_match_filter :
467
- # Validate the key - only allow safe characters for JSON path
468
- if not re .match (r"^[a-zA-Z0-9_]+$" , filter_item .key ):
469
- raise ValueError (f"Invalid metadata key format: { filter_item .key } " )
470
- # Use JSON_VALUE with parameterized values
471
- filter_conditions .append (
472
- f"JSON_VALUE({ self .metadata_column } , '$.{ filter_item .key } ') = :value{ len (bind_variables )} "
473
- )
474
- bind_variables .append (filter_item .value )
475
562
476
- # Convert filter conditions to a single string
477
- filter_str = " AND " .join (filter_conditions )
563
+ filter_str = self ._get_filter_string (filters , bind_variables )
478
564
565
+ # Convert filter conditions to a single string
479
566
if where_str is None :
480
567
where_str = filter_str
481
568
else :
@@ -534,22 +621,25 @@ def add(self, nodes: list[BaseNode], **kwargs: Any) -> list[str]:
534
621
if not nodes :
535
622
return []
536
623
624
+ connection = _get_connection (self ._client )
625
+
537
626
for result_batch in iter_batch (nodes , self .batch_size ):
538
627
dml , bind_values = self ._build_insert (values = result_batch )
539
628
540
- with self . _client .cursor () as cursor :
629
+ with connection .cursor () as cursor :
541
630
# Use executemany to insert the batch
542
631
cursor .executemany (dml , bind_values )
543
- self . _client .commit ()
632
+ connection .commit ()
544
633
545
634
return [node .node_id for node in nodes ]
546
635
547
636
@_handle_exceptions
def delete(self, ref_doc_id: str, **kwargs: Any) -> None:
    """Delete every row whose ``doc_id`` equals ``ref_doc_id`` and commit.

    Args:
        ref_doc_id: Document id to remove; passed as a bind variable.
        **kwargs: Accepted for interface compatibility; unused here.
    """
    connection = _get_connection(self._client)
    try:
        with connection.cursor() as cursor:
            ddl = f"DELETE FROM {self.table_name} WHERE doc_id = :ref_doc_id"
            cursor.execute(ddl, [ref_doc_id])
        connection.commit()
    finally:
        # _get_connection returns the client itself for a plain connection;
        # any other object was acquired from the pool and is handed back
        # explicitly instead of waiting for garbage collection.
        if connection is not self._client:
            self._client.release(connection)
553
643
554
644
@_handle_exceptions
555
645
def _get_clob_value (self , result : Any ) -> str :
@@ -595,7 +685,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
595
685
bind_vars = []
596
686
if query .filters is not None :
597
687
where_str , bind_vars = self ._append_meta_filter_condition (
598
- where_str , query .filters . filters
688
+ where_str , query .filters
599
689
)
600
690
601
691
# build query sql
@@ -625,7 +715,9 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
625
715
params = {"embedding" : embedding }
626
716
for i , value in enumerate (bind_vars ):
627
717
params [f"value{ i } " ] = value
628
- with self ._client .cursor () as cursor :
718
+
719
+ connection = _get_connection (self ._client )
720
+ with connection .cursor () as cursor :
629
721
cursor .execute (query_sql , ** params )
630
722
results = cursor .fetchall ()
631
723
0 commit comments