Skip to content

Commit 1e77573

Browse files
committed
fix: run linter and format
1 parent 3f91652 commit 1e77573

File tree

2 files changed

+76
-53
lines changed
  • llama-index-integrations/vector_stores/llama-index-vector-store-paradedb

2 files changed

+76
-53
lines changed

llama-index-integrations/vector_stores/llama-index-vector-store-paradedb/llama_index/vector_stores/paradedb/base.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import sqlalchemy
77
from llama_index.core.bridge.pydantic import BaseModel, Field
8-
from llama_index.core.vector_stores.types import VectorStoreQuery
98
from sqlalchemy.sql.selectable import Select
109

1110
from llama_index.vector_stores.postgres.base import (
@@ -36,7 +35,17 @@ def get_bm25_data_model(
3635
from pgvector.sqlalchemy import Vector, HALFVEC
3736
from sqlalchemy import Column
3837
from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, VARCHAR
39-
from sqlalchemy import cast, column, String, Integer, Numeric, Float, Boolean, Date, DateTime
38+
from sqlalchemy import (
39+
cast,
40+
column,
41+
String,
42+
Integer,
43+
Numeric,
44+
Float,
45+
Boolean,
46+
Date,
47+
DateTime,
48+
)
4049
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, UUID
4150
from sqlalchemy.schema import Index
4251

@@ -54,7 +63,7 @@ def get_bm25_data_model(
5463
}
5564

5665
indexed_metadata_keys = indexed_metadata_keys or set()
57-
66+
5867
for key, pg_type in indexed_metadata_keys:
5968
if pg_type not in pg_type_map:
6069
raise ValueError(
@@ -67,7 +76,9 @@ def get_bm25_data_model(
6776
indexname = f"{index_name}_idx"
6877

6978
metadata_dtype = JSONB if use_jsonb else JSON
70-
embedding_col = Column(HALFVEC(embed_dim)) if use_halfvec else Column(Vector(embed_dim))
79+
embedding_col = (
80+
Column(HALFVEC(embed_dim)) if use_halfvec else Column(Vector(embed_dim))
81+
)
7182

7283
metadata_indices = [
7384
Index(
@@ -107,7 +118,7 @@ class BM25AbstractData(base):
107118
class ParadeDBVectorStore(PGVectorStore, BaseModel):
108119
"""
109120
ParadeDB Vector Store with BM25 search support.
110-
121+
111122
Inherits from PGVectorStore and adds BM25 full-text search capabilities
112123
using ParadeDB's pg_search extension.
113124
@@ -130,16 +141,19 @@ class ParadeDBVectorStore(PGVectorStore, BaseModel):
130141
use_halfvec=True
131142
)
132143
```
144+
133145
"""
134146

135147
connection_string: Optional[Union[str, sqlalchemy.engine.URL]] = Field(default=None)
136-
async_connection_string: Optional[Union[str, sqlalchemy.engine.URL]] = Field(default=None)
148+
async_connection_string: Optional[Union[str, sqlalchemy.engine.URL]] = Field(
149+
default=None
150+
)
137151
table_name: Optional[str] = Field(default=None)
138152
schema_name: Optional[str] = Field(default="paradedb")
139153
hybrid_search: bool = Field(default=False)
140154
text_search_config: str = Field(default="english")
141155
embed_dim: int = Field(default=1536)
142-
cache_ok: bool = Field(default=False)
156+
cache_ok: bool = Field(default=False)
143157
perform_setup: bool = Field(default=True)
144158
debug: bool = Field(default=False)
145159
use_jsonb: bool = Field(default=False)
@@ -154,7 +168,7 @@ def __init__(
154168
table_name: Optional[str] = None,
155169
schema_name: Optional[str] = None,
156170
hybrid_search: bool = False,
157-
text_search_config: str = "english",
171+
text_search_config: str = "english",
158172
embed_dim: int = 1536,
159173
cache_ok: bool = False,
160174
perform_setup: bool = True,
@@ -176,7 +190,7 @@ def __init__(
176190
self,
177191
connection_string=connection_string,
178192
async_connection_string=async_connection_string,
179-
table_name=table_name,
193+
table_name=table_name,
180194
schema_name=schema_name or "paradedb",
181195
hybrid_search=hybrid_search,
182196
text_search_config=text_search_config,
@@ -187,14 +201,16 @@ def __init__(
187201
use_jsonb=use_jsonb,
188202
hnsw_kwargs=hnsw_kwargs,
189203
create_engine_kwargs=create_engine_kwargs,
190-
use_bm25=use_bm25
204+
use_bm25=use_bm25,
191205
)
192-
206+
193207
# Call parent constructor
194208
PGVectorStore.__init__(
195209
self,
196210
connection_string=str(connection_string) if connection_string else None,
197-
async_connection_string=str(async_connection_string) if async_connection_string else None,
211+
async_connection_string=str(async_connection_string)
212+
if async_connection_string
213+
else None,
198214
table_name=table_name,
199215
schema_name=self.schema_name,
200216
hybrid_search=hybrid_search,
@@ -213,10 +229,11 @@ def __init__(
213229
indexed_metadata_keys=indexed_metadata_keys,
214230
customize_query_fn=customize_query_fn,
215231
)
216-
232+
217233
# Override table model if using BM25
218234
if self.use_bm25:
219235
from sqlalchemy.orm import declarative_base
236+
220237
self._base = declarative_base()
221238
self._table_class = get_bm25_data_model(
222239
self._base,
@@ -270,6 +287,7 @@ def from_params(
270287
271288
Returns:
272289
ParadeDBVectorStore: Instance of ParadeDBVectorStore.
290+
273291
"""
274292
conn_str = (
275293
connection_string
@@ -301,7 +319,7 @@ def from_params(
301319
def _create_extension(self) -> None:
302320
"""Override to add pg_search extension for BM25."""
303321
super()._create_extension()
304-
322+
305323
if self.use_bm25:
306324
with self._session() as session, session.begin():
307325
try:
@@ -337,7 +355,7 @@ def _initialize(self) -> None:
337355
"""Override to add BM25 index creation."""
338356
if not self._is_initialized:
339357
super()._initialize()
340-
358+
341359
if self.use_bm25 and self.perform_setup:
342360
try:
343361
self._create_bm25_index()
@@ -355,10 +373,12 @@ def _build_sparse_query(
355373
) -> Any:
356374
"""Override to use BM25 if enabled, otherwise use parent's ts_vector."""
357375
if not self.use_bm25:
358-
return super()._build_sparse_query(query_str, limit, metadata_filters, **kwargs)
359-
376+
return super()._build_sparse_query(
377+
query_str, limit, metadata_filters, **kwargs
378+
)
379+
360380
from sqlalchemy import text
361-
381+
362382
if query_str is None:
363383
raise ValueError("query_str must be specified for a sparse vector query.")
364384

@@ -373,14 +393,12 @@ def _build_sparse_query(
373393
if metadata_filters:
374394
_logger.warning("Metadata filters not fully implemented for BM25 raw SQL")
375395

376-
stmt = text(f"""
396+
return text(f"""
377397
{base_query}
378398
ORDER BY rank DESC
379399
LIMIT :limit
380400
""").bindparams(query=query_str_clean, limit=limit)
381401

382-
return stmt
383-
384402
def _sparse_query_with_rank(
385403
self,
386404
query_str: Optional[str] = None,
@@ -390,7 +408,7 @@ def _sparse_query_with_rank(
390408
"""Override to handle BM25 results properly."""
391409
if not self.use_bm25:
392410
return super()._sparse_query_with_rank(query_str, limit, metadata_filters)
393-
411+
394412
stmt = self._build_sparse_query(query_str, limit, metadata_filters)
395413
with self._session() as session, session.begin():
396414
res = session.execute(stmt)
@@ -417,8 +435,10 @@ async def _async_sparse_query_with_rank(
417435
) -> List[DBEmbeddingRow]:
418436
"""Override to handle async BM25 results properly."""
419437
if not self.use_bm25:
420-
return await super()._async_sparse_query_with_rank(query_str, limit, metadata_filters)
421-
438+
return await super()._async_sparse_query_with_rank(
439+
query_str, limit, metadata_filters
440+
)
441+
422442
stmt = self._build_sparse_query(query_str, limit, metadata_filters)
423443
async with self._async_session() as session, session.begin():
424444
res = await session.execute(stmt)
@@ -435,4 +455,4 @@ async def _async_sparse_query_with_rank(
435455
similarity=item.rank,
436456
)
437457
for item in res.all()
438-
]
458+
]

llama-index-integrations/vector_stores/llama-index-vector-store-paradedb/tests/test_paradedb.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def _get_sample_vector(num: float) -> List[float]:
4949
@pytest.fixture(scope="session")
5050
def conn() -> Any:
5151
import psycopg2
52+
5253
return psycopg2.connect(**PARAMS) # type: ignore
5354

5455

@@ -434,24 +435,28 @@ async def test_bm25_extensions_created(db: None) -> None:
434435
hybrid_search=True,
435436
embed_dim=TEST_EMBED_DIM,
436437
)
437-
438+
438439
# Force initialization
439-
pg.add([
440-
TextNode(
441-
text="test",
442-
id_="test",
443-
embedding=_get_sample_vector(1.0),
444-
)
445-
])
446-
440+
pg.add(
441+
[
442+
TextNode(
443+
text="test",
444+
id_="test",
445+
embedding=_get_sample_vector(1.0),
446+
)
447+
]
448+
)
449+
447450
# Check that both extensions exist
448451
with psycopg2.connect(**PARAMS, database=TEST_DB) as conn:
449452
with conn.cursor() as c:
450-
c.execute("SELECT COUNT(*) FROM pg_extension WHERE extname IN ('vector', 'pg_search');")
453+
c.execute(
454+
"SELECT COUNT(*) FROM pg_extension WHERE extname IN ('vector', 'pg_search');"
455+
)
451456
ext_count = c.fetchone()[0]
452-
457+
453458
assert ext_count == 2, "Both 'vector' and 'pg_search' extensions should exist"
454-
459+
455460
await pg.close()
456461

457462

@@ -464,39 +469,37 @@ async def test_paradedb_inherits_pgvector_functionality(
464469
"""Test that ParadeDBVectorStore inherits all PGVectorStore functionality."""
465470
# Add nodes
466471
pg_bm25.add(hybrid_node_embeddings)
467-
472+
468473
# Test vector-only query (inherited from PGVectorStore)
469474
q = VectorStoreQuery(
470475
query_embedding=_get_sample_vector(0.1),
471476
similarity_top_k=2,
472477
mode=VectorStoreQueryMode.DEFAULT,
473478
)
474-
479+
475480
res = pg_bm25.query(q)
476481
assert res.nodes
477482
assert len(res.nodes) == 2
478-
483+
479484
# Test delete (inherited)
480485
pg_bm25.delete_nodes(["aaa"])
481-
486+
482487
res = pg_bm25.query(q)
483488
assert "aaa" not in res.ids
484-
489+
485490
# Test clear (inherited)
486491
await pg_bm25.aclear()
487-
492+
488493
res = pg_bm25.query(q)
489494
assert len(res.nodes) == 0
490495

491496

492497
@pytest.mark.skipif(postgres_not_available, reason="postgres db is not available")
493498
@pytest.mark.asyncio
494499
async def test_bm25_vs_tsvector_different_results(
495-
db: None,
496-
hybrid_node_embeddings: List[TextNode]
497-
) -> None:
500+
db: None, hybrid_node_embeddings: List[TextNode]
501+
) -> None:
498502
"""Test that BM25 and ts_vector can produce different ranking results."""
499-
500503
# Create both stores
501504
pg_tsvector = PGVectorStore.from_params(
502505
**PARAMS, # type: ignore
@@ -506,7 +509,7 @@ async def test_bm25_vs_tsvector_different_results(
506509
hybrid_search=True,
507510
embed_dim=TEST_EMBED_DIM,
508511
)
509-
512+
510513
pg_bm25 = ParadeDBVectorStore.from_params(
511514
**PARAMS, # type: ignore
512515
database=TEST_DB,
@@ -518,14 +521,14 @@ async def test_bm25_vs_tsvector_different_results(
518521
)
519522
pg_tsvector.add(hybrid_node_embeddings)
520523
pg_bm25.add(hybrid_node_embeddings)
521-
524+
522525
q = VectorStoreQuery(
523526
query_str="fox",
524527
sparse_top_k=2,
525528
mode=VectorStoreQueryMode.SPARSE,
526529
query_embedding=_get_sample_vector(5.0),
527530
)
528-
531+
529532
res_tsvector = pg_tsvector.query(q)
530533
res_bm25 = pg_bm25.query(q)
531534

@@ -538,11 +541,11 @@ async def test_bm25_vs_tsvector_different_results(
538541
# Both should return results
539542
assert len(res_tsvector.nodes) == 2
540543
assert len(res_bm25.nodes) == 2
541-
544+
542545
# BM25 uses BM25 ranking, ts_vector uses ts_rank
543546
# The implementation difference is verified
544547
assert pg_bm25.use_bm25 is True
545548
assert not hasattr(pg_tsvector, "use_bm25") or pg_tsvector.use_bm25 is False
546-
549+
547550
await pg_tsvector.close()
548-
await pg_bm25.close()
551+
await pg_bm25.close()

0 commit comments

Comments
 (0)