Skip to content

Commit 359dbaf

Browse files
committed
refactor(vector-storage): use index enums and add helper tests
- Use IndexOperation enum for embedding write index_status mapping
- Improve _validate_and_prepare_table and _process_batch docstrings
- Add tests for grouping embeddings, table validation and spill retry
- Extend IndexOperation tests to cover READY status

Made-with: Cursor
1 parent f774487 commit 359dbaf

File tree

4 files changed

+269
-14
lines changed

4 files changed

+269
-14
lines changed

src/xagent/core/tools/core/RAG_tools/core/schemas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,10 @@ class FTSIndexStatus(Enum):
184184

185185

186186
class IndexOperation(Enum):
187-
"""Index operation result types."""
187+
"""Index operation result types (e.g. for embedding write response)."""
188188

189189
CREATED = "created"
190+
READY = "ready"
190191
SKIPPED = "skipped"
191192
SKIPPED_THRESHOLD = "skipped_threshold"
192193
FAILED = "failed"

src/xagent/core/tools/core/RAG_tools/vector_storage/vector_manager.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
ChunkForEmbedding,
3333
EmbeddingReadResponse,
3434
EmbeddingWriteResponse,
35+
IndexOperation,
3536
)
3637
from ..LanceDB.model_tag_utils import to_model_tag
3738
from ..LanceDB.schema_manager import ensure_chunks_table, ensure_embeddings_table
@@ -489,7 +490,20 @@ def _validate_and_prepare_table(
489490
table_name: str,
490491
vector_dim: int,
491492
) -> Any:
492-
"""Ensure database table exists and has compatible schema."""
493+
"""Ensure database table exists and has compatible schema.
494+
495+
If the table exists, checks the vector field type and dimension; drops and
496+
recreates the table when dimension or type is incompatible.
497+
498+
Args:
499+
conn: LanceDB connection (e.g. from get_connection_from_env).
500+
model_tag: Model tag used for table naming (e.g. from to_model_tag).
501+
table_name: Full embeddings table name (e.g. embeddings_<model_tag>).
502+
vector_dim: Expected vector dimension for the table schema.
503+
504+
Returns:
505+
LanceDB table instance for the embeddings table (existing or newly created).
506+
"""
493507
conn_any = cast(Any, conn)
494508
try:
495509
existing_tables: List[str] = []
@@ -542,8 +556,18 @@ def _process_batch(
542556
) -> int:
543557
"""Process a single batch of embeddings.
544558
559+
Uses merge_insert for upsert; on recoverable errors falls back to add().
560+
Non-recoverable errors (schema/type/dimension) are re-raised without fallback.
561+
562+
Args:
563+
embeddings_table: LanceDB table to write to (from _validate_and_prepare_table).
564+
records_to_merge: List of record dicts with keys matching table schema.
565+
batch_idx: Zero-based batch index (for logging).
566+
total_batches: Total number of batches (for logging).
567+
model: Model name (for logging).
568+
545569
Returns:
546-
Number of upserted records.
570+
Number of upserted records (len(records_to_merge) on success).
547571
"""
548572
try:
549573
# Try merge_insert first (preferred method for upserts)
@@ -805,7 +829,7 @@ def _process_model_embeddings(
805829
logger.info("Processed model %s: upserted %d embeddings", model, upserted_count)
806830

807831
# Handle index creation and reindexing if requested
808-
index_status = "skipped"
832+
index_status: str = IndexOperation.SKIPPED.value
809833
if create_index:
810834
try:
811835
# Use index manager for index creation
@@ -826,7 +850,7 @@ def _process_model_embeddings(
826850

827851
except Exception as index_error: # noqa: BLE001
828852
logger.warning("Failed to create index for %s: %s", table_name, index_error)
829-
index_status = "failed"
853+
index_status = IndexOperation.FAILED.value
830854

831855
return upserted_count, index_status
832856

@@ -840,7 +864,9 @@ def write_vectors_to_db(
840864
"""Write embedding vectors to database with idempotency."""
841865
if not embeddings:
842866
return EmbeddingWriteResponse(
843-
upsert_count=0, deleted_stale_count=0, index_status="skipped"
867+
upsert_count=0,
868+
deleted_stale_count=0,
869+
index_status=IndexOperation.SKIPPED.value,
844870
)
845871

846872
try:
@@ -865,28 +891,28 @@ def write_vectors_to_db(
865891
total_upserted += upserted
866892
index_statuses.append(idx_status)
867893

868-
# Determine overall index status
894+
# Determine overall index status (map index_manager strings to IndexOperation)
869895
if "index_building" in index_statuses:
870-
overall_index_status = "created"
896+
overall_index_status = IndexOperation.CREATED
871897
elif "index_ready" in index_statuses:
872-
overall_index_status = "ready"
898+
overall_index_status = IndexOperation.READY
873899
elif "failed" in index_statuses or "index_corrupted" in index_statuses:
874-
overall_index_status = "failed"
900+
overall_index_status = IndexOperation.FAILED
875901
elif "below_threshold" in index_statuses:
876-
overall_index_status = "skipped_threshold"
902+
overall_index_status = IndexOperation.SKIPPED_THRESHOLD
877903
else:
878-
overall_index_status = "skipped"
904+
overall_index_status = IndexOperation.SKIPPED
879905

880906
logger.info(
881907
"Embedding write completed: %d upserted, index status: %s",
882908
total_upserted,
883-
overall_index_status,
909+
overall_index_status.value,
884910
)
885911

886912
return EmbeddingWriteResponse(
887913
upsert_count=total_upserted,
888914
deleted_stale_count=0, # merge_insert handles updates automatically
889-
index_status=overall_index_status,
915+
index_status=overall_index_status.value,
890916
)
891917

892918
except Exception as e:

tests/core/tools/core/RAG_tools/core/test_schemas.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,7 @@ class TestIndexOperation:
826826
def test_enum_values(self):
827827
"""Test that enum has expected values."""
828828
assert IndexOperation.CREATED.value == "created"
829+
assert IndexOperation.READY.value == "ready"
829830
assert IndexOperation.SKIPPED.value == "skipped"
830831
assert IndexOperation.SKIPPED_THRESHOLD.value == "skipped_threshold"
831832
assert IndexOperation.FAILED.value == "failed"
@@ -834,6 +835,7 @@ def test_enum_values(self):
834835
def test_enum_string_conversion(self):
835836
"""Test that enum converts to string correctly."""
836837
assert str(IndexOperation.CREATED) == "created"
838+
assert str(IndexOperation.READY) == "ready"
837839
assert str(IndexOperation.SKIPPED) == "skipped"
838840
assert str(IndexOperation.SKIPPED_THRESHOLD) == "skipped_threshold"
839841
assert str(IndexOperation.FAILED) == "failed"
@@ -843,6 +845,7 @@ def test_enum_value_access(self):
843845
"""Test that enum values can be accessed correctly."""
844846
# Test that enum instances are not equal to strings (type safety)
845847
assert IndexOperation.CREATED != "created"
848+
assert IndexOperation.READY != "ready"
846849
assert IndexOperation.SKIPPED != "skipped"
847850
assert IndexOperation.SKIPPED_THRESHOLD != "skipped_threshold"
848851
assert IndexOperation.FAILED != "failed"
@@ -852,12 +855,14 @@ def test_enum_membership(self):
852855
"""Test enum membership checks."""
853856
operations = {
854857
IndexOperation.CREATED,
858+
IndexOperation.READY,
855859
IndexOperation.SKIPPED,
856860
IndexOperation.SKIPPED_THRESHOLD,
857861
IndexOperation.FAILED,
858862
IndexOperation.UPDATED,
859863
}
860864
assert IndexOperation.CREATED in operations
865+
assert IndexOperation.READY in operations
861866
assert IndexOperation.SKIPPED in operations
862867
assert IndexOperation.SKIPPED_THRESHOLD in operations
863868
assert IndexOperation.FAILED in operations

0 commit comments

Comments
 (0)