 import sqlalchemy
 from llama_index.core.bridge.pydantic import BaseModel, Field
-from llama_index.core.vector_stores.types import VectorStoreQuery
 from sqlalchemy.sql.selectable import Select

 from llama_index.vector_stores.postgres.base import (
@@ -36,7 +35,17 @@ def get_bm25_data_model(
     from pgvector.sqlalchemy import Vector, HALFVEC
     from sqlalchemy import Column
     from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, VARCHAR
-    from sqlalchemy import cast, column, String, Integer, Numeric, Float, Boolean, Date, DateTime
+    from sqlalchemy import (
+        cast,
+        column,
+        String,
+        Integer,
+        Numeric,
+        Float,
+        Boolean,
+        Date,
+        DateTime,
+    )
     from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, UUID
     from sqlalchemy.schema import Index

@@ -54,7 +63,7 @@ def get_bm25_data_model(
     }

     indexed_metadata_keys = indexed_metadata_keys or set()
-
+
     for key, pg_type in indexed_metadata_keys:
         if pg_type not in pg_type_map:
             raise ValueError(
@@ -67,7 +76,9 @@ def get_bm25_data_model(
     indexname = f"{index_name}_idx"

     metadata_dtype = JSONB if use_jsonb else JSON
-    embedding_col = Column(HALFVEC(embed_dim)) if use_halfvec else Column(Vector(embed_dim))
+    embedding_col = (
+        Column(HALFVEC(embed_dim)) if use_halfvec else Column(Vector(embed_dim))
+    )

     metadata_indices = [
         Index(
@@ -107,7 +118,7 @@ class BM25AbstractData(base):
 class ParadeDBVectorStore(PGVectorStore, BaseModel):
     """
     ParadeDB Vector Store with BM25 search support.
-
+
     Inherits from PGVectorStore and adds BM25 full-text search capabilities
     using ParadeDB's pg_search extension.

@@ -130,16 +141,19 @@ class ParadeDBVectorStore(PGVectorStore, BaseModel):
             use_halfvec=True
         )
         ```
+
     """

     connection_string: Optional[Union[str, sqlalchemy.engine.URL]] = Field(default=None)
-    async_connection_string: Optional[Union[str, sqlalchemy.engine.URL]] = Field(default=None)
+    async_connection_string: Optional[Union[str, sqlalchemy.engine.URL]] = Field(
+        default=None
+    )
     table_name: Optional[str] = Field(default=None)
     schema_name: Optional[str] = Field(default="paradedb")
     hybrid_search: bool = Field(default=False)
     text_search_config: str = Field(default="english")
     embed_dim: int = Field(default=1536)
-    cache_ok: bool = Field(default=False)
+    cache_ok: bool = Field(default=False)
     perform_setup: bool = Field(default=True)
     debug: bool = Field(default=False)
     use_jsonb: bool = Field(default=False)
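
For orientation, here is a minimal construction sketch that exercises the BM25 path. It is assembled only from the fields and defaults visible in this diff; the import path, the connection strings, and the exact keyword set accepted by `__init__` are assumptions, not something this PR documents.

```python
# Sketch only: the import path and DSNs below are placeholders/assumptions.
from llama_index.vector_stores.postgres import ParadeDBVectorStore  # assumed location

vector_store = ParadeDBVectorStore(
    connection_string="postgresql://user:pass@localhost:5432/vectordb",
    async_connection_string="postgresql+asyncpg://user:pass@localhost:5432/vectordb",
    table_name="demo_nodes",
    schema_name="paradedb",  # default per the Field declarations above
    embed_dim=1536,          # default per the Field declarations above
    hybrid_search=True,      # combine dense similarity with sparse text search
    use_bm25=True,           # route the sparse path through pg_search BM25
)
```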
@@ -154,7 +168,7 @@ def __init__(
         table_name: Optional[str] = None,
         schema_name: Optional[str] = None,
         hybrid_search: bool = False,
-        text_search_config: str = "english",
+        text_search_config: str = "english",
         embed_dim: int = 1536,
         cache_ok: bool = False,
         perform_setup: bool = True,
@@ -176,7 +190,7 @@ def __init__(
             self,
             connection_string=connection_string,
             async_connection_string=async_connection_string,
-            table_name=table_name,
+            table_name=table_name,
             schema_name=schema_name or "paradedb",
             hybrid_search=hybrid_search,
             text_search_config=text_search_config,
@@ -187,14 +201,16 @@ def __init__(
             use_jsonb=use_jsonb,
             hnsw_kwargs=hnsw_kwargs,
             create_engine_kwargs=create_engine_kwargs,
-            use_bm25=use_bm25
+            use_bm25=use_bm25,
         )
-
+
         # Call parent constructor
         PGVectorStore.__init__(
             self,
             connection_string=str(connection_string) if connection_string else None,
-            async_connection_string=str(async_connection_string) if async_connection_string else None,
+            async_connection_string=str(async_connection_string)
+            if async_connection_string
+            else None,
             table_name=table_name,
             schema_name=self.schema_name,
             hybrid_search=hybrid_search,
@@ -213,10 +229,11 @@ def __init__(
             indexed_metadata_keys=indexed_metadata_keys,
             customize_query_fn=customize_query_fn,
         )
-
+
         # Override table model if using BM25
         if self.use_bm25:
             from sqlalchemy.orm import declarative_base
+
             self._base = declarative_base()
             self._table_class = get_bm25_data_model(
                 self._base,
@@ -270,6 +287,7 @@ def from_params(

         Returns:
             ParadeDBVectorStore: Instance of ParadeDBVectorStore.
+
         """
         conn_str = (
             connection_string
@@ -301,7 +319,7 @@ def from_params(
     def _create_extension(self) -> None:
         """Override to add pg_search extension for BM25."""
         super()._create_extension()
-
+
         if self.use_bm25:
             with self._session() as session, session.begin():
                 try:
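
The body of the `try` block is cut off by this hunk. As a rough sketch of what such an override usually has to do (the exact statement is an assumption, not copied from the PR), enabling ParadeDB's search extension amounts to a single DDL call:

```python
from sqlalchemy import create_engine, text

def ensure_pg_search(connection_string: str) -> None:
    """Assumed equivalent of the extension setup: create pg_search if missing."""
    engine = create_engine(connection_string)
    with engine.begin() as conn:
        # pg_search is the ParadeDB extension named in the class docstring above.
        conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_search"))
```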
@@ -337,7 +355,7 @@ def _initialize(self) -> None:
         """Override to add BM25 index creation."""
         if not self._is_initialized:
             super()._initialize()
-
+
             if self.use_bm25 and self.perform_setup:
                 try:
                     self._create_bm25_index()
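
`_create_bm25_index()` itself sits outside this hunk. For readers unfamiliar with pg_search, here is a plausible sketch of the DDL it would issue; the index, schema, table, and column names are placeholders, and only the general `CREATE INDEX ... USING bm25` form is taken from ParadeDB's documented syntax.

```python
from sqlalchemy import create_engine, text

def create_bm25_index_sketch(connection_string: str) -> None:
    """Assumed shape of the BM25 index DDL; the names below are illustrative only."""
    ddl = """
        CREATE INDEX IF NOT EXISTS data_demo_nodes_bm25_idx
        ON paradedb.data_demo_nodes
        USING bm25 (id, text)
        WITH (key_field = 'id')
    """
    engine = create_engine(connection_string)
    with engine.begin() as conn:
        conn.execute(text(ddl))
```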
@@ -355,10 +373,12 @@ def _build_sparse_query(
     ) -> Any:
         """Override to use BM25 if enabled, otherwise use parent's ts_vector."""
         if not self.use_bm25:
-            return super()._build_sparse_query(query_str, limit, metadata_filters, **kwargs)
-
+            return super()._build_sparse_query(
+                query_str, limit, metadata_filters, **kwargs
+            )
+
         from sqlalchemy import text
-
+
         if query_str is None:
             raise ValueError("query_str must be specified for a sparse vector query.")

@@ -373,14 +393,12 @@ def _build_sparse_query(
         if metadata_filters:
             _logger.warning("Metadata filters not fully implemented for BM25 raw SQL")

-        stmt = text(f"""
+        return text(f"""
             {base_query}
             ORDER BY rank DESC
             LIMIT :limit
         """).bindparams(query=query_str_clean, limit=limit)

-        return stmt
-
     def _sparse_query_with_rank(
         self,
         query_str: Optional[str] = None,
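
The hunk above shows how the sparse statement is assembled, but `base_query` and `query_str_clean` are defined outside it. The sketch below fills in a plausible shape using pg_search's `@@@` operator and `paradedb.score()`; the table and column names follow the usual PGVectorStore layout and are assumptions, not code from this PR.

```python
from sqlalchemy import text

# Assumed base query: rank rows with paradedb.score() and match them with the
# pg_search @@@ operator. Table/column names are placeholders.
base_query = """
    SELECT id, node_id, text, metadata_, paradedb.score(id) AS rank
    FROM paradedb.data_demo_nodes
    WHERE text @@@ :query
"""

stmt = text(f"""
    {base_query}
    ORDER BY rank DESC
    LIMIT :limit
""").bindparams(query="quick brown fox", limit=10)
```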
@@ -390,7 +408,7 @@ def _sparse_query_with_rank(
         """Override to handle BM25 results properly."""
         if not self.use_bm25:
             return super()._sparse_query_with_rank(query_str, limit, metadata_filters)
-
+
         stmt = self._build_sparse_query(query_str, limit, metadata_filters)
         with self._session() as session, session.begin():
             res = session.execute(stmt)
@@ -417,8 +435,10 @@ async def _async_sparse_query_with_rank(
     ) -> List[DBEmbeddingRow]:
         """Override to handle async BM25 results properly."""
         if not self.use_bm25:
-            return await super()._async_sparse_query_with_rank(query_str, limit, metadata_filters)
-
+            return await super()._async_sparse_query_with_rank(
+                query_str, limit, metadata_filters
+            )
+
         stmt = self._build_sparse_query(query_str, limit, metadata_filters)
         async with self._async_session() as session, session.begin():
             res = await session.execute(stmt)
@@ -435,4 +455,4 @@ async def _async_sparse_query_with_rank(
                 similarity=item.rank,
             )
             for item in res.all()
-        ]
+        ]
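
To round out the picture, a hedged usage sketch of how the BM25-backed sparse path is reached through the standard query interface; `vector_store` is the instance from the construction sketch earlier, the embedding values are dummies, and the mode names come from `llama_index.core.vector_stores.types` rather than from this PR.

```python
from llama_index.core.vector_stores.types import (
    VectorStoreQuery,
    VectorStoreQueryMode,
)

# HYBRID combines dense similarity with the sparse (here: BM25) ranking;
# SPARSE alone would exercise only the overridden _sparse_query_with_rank path.
query = VectorStoreQuery(
    query_str="postgres full-text search",
    query_embedding=[0.0] * 1536,  # dummy embedding sized to embed_dim
    similarity_top_k=5,
    sparse_top_k=5,
    mode=VectorStoreQueryMode.HYBRID,
)
result = vector_store.query(query)  # returns a VectorStoreQueryResult
```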