Skip to content

Commit 75c5b51

Browse files
authored
Refactor/shard key helpers (#19765)
* Refactor shard key creation into separate methods for improved readability and maintainability * Add tests for shard key and payload index creation in QdrantVectorStore * Bump version to 0.8.3 in pyproject.toml * chore: Import`UnexpectedResponse` on module level instead of inside multiple functions
1 parent 05a6937 commit 75c5b51

File tree

3 files changed

+89
-47
lines changed

3 files changed

+89
-47
lines changed

llama-index-integrations/vector_stores/llama-index-vector-stores-qdrant/llama_index/vector_stores/qdrant/base.py

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
IsEmptyCondition,
5353
)
5454
from qdrant_client.qdrant_fastembed import IDF_EMBEDDING_MODELS
55+
from qdrant_client.http.exceptions import UnexpectedResponse
5556

5657
logger = logging.getLogger(__name__)
5758
import_err_msg = (
@@ -575,8 +576,6 @@ async def async_add(
575576
ValueError: If trying to using async methods without aclient
576577
577578
"""
578-
from qdrant_client.http.exceptions import UnexpectedResponse
579-
580579
self._ensure_async_client()
581580

582581
collection_initialized = await self._acollection_exists(self.collection_name)
@@ -799,9 +798,6 @@ def client(self) -> Any:
799798

800799
def _create_collection(self, collection_name: str, vector_size: int) -> None:
801800
"""Create a Qdrant collection."""
802-
from qdrant_client.http import models as rest
803-
from qdrant_client.http.exceptions import UnexpectedResponse
804-
805801
dense_config = self._dense_config or rest.VectorParams(
806802
size=vector_size,
807803
distance=rest.Distance.COSINE,
@@ -843,11 +839,7 @@ def _create_collection(self, collection_name: str, vector_size: int) -> None:
843839
)
844840

845841
if self._shard_keys:
846-
for shard_key in self._shard_keys:
847-
self._client.create_shard_key(
848-
collection_name=collection_name,
849-
shard_key=shard_key,
850-
)
842+
self._create_shard_keys()
851843

852844
# To improve search performance Qdrant recommends setting up
853845
# a payload index for fields used in filters.
@@ -870,30 +862,15 @@ def _create_collection(self, collection_name: str, vector_size: int) -> None:
870862
)
871863

872864
if self._shard_keys:
873-
for shard_key in self._shard_keys:
874-
try:
875-
self._client.create_shard_key(
876-
collection_name=collection_name,
877-
shard_key=shard_key,
878-
)
879-
except (RpcError, ValueError, UnexpectedResponse) as exc:
880-
if "already exists" not in str(exc):
881-
raise exc # noqa: TRY201
882-
logger.warning(
883-
"Shard key %s already exists, skipping creation.",
884-
shard_key,
885-
)
886-
continue
865+
self._create_shard_keys()
866+
887867
if self._payload_indexes:
888868
self._create_payload_indexes()
889869

890870
self._collection_initialized = True
891871

892872
async def _acreate_collection(self, collection_name: str, vector_size: int) -> None:
893873
"""Asynchronous method to create a Qdrant collection."""
894-
from qdrant_client.http import models as rest
895-
from qdrant_client.http.exceptions import UnexpectedResponse
896-
897874
dense_config = self._dense_config or rest.VectorParams(
898875
size=vector_size,
899876
distance=rest.Distance.COSINE,
@@ -932,11 +909,7 @@ async def _acreate_collection(self, collection_name: str, vector_size: int) -> N
932909
)
933910

934911
if self._shard_keys:
935-
for shard_key in self._shard_keys:
936-
await self._aclient.create_shard_key(
937-
collection_name=collection_name,
938-
shard_key=shard_key,
939-
)
912+
await self._acreate_shard_keys()
940913

941914
# To improve search performance Qdrant recommends setting up
942915
# a payload index for fields used in filters.
@@ -959,20 +932,8 @@ async def _acreate_collection(self, collection_name: str, vector_size: int) -> N
959932
)
960933

961934
if self._shard_keys:
962-
for shard_key in self._shard_keys:
963-
try:
964-
await self._client.create_shard_key(
965-
collection_name=collection_name,
966-
shard_key=shard_key,
967-
)
968-
except (RpcError, ValueError, UnexpectedResponse) as exc:
969-
if "already exists" not in str(exc):
970-
raise exc # noqa: TRY201
971-
logger.warning(
972-
"Shard key %s already exists, skipping creation.",
973-
shard_key,
974-
)
975-
continue
935+
await self._acreate_shard_keys()
936+
976937
if self._payload_indexes:
977938
await self._acreate_payload_indexes()
978939

@@ -986,6 +947,46 @@ async def _acollection_exists(self, collection_name: str) -> bool:
986947
"""Asynchronous method to check if a collection exists."""
987948
return await self._aclient.collection_exists(collection_name)
988949

950+
def _create_shard_keys(self) -> None:
951+
"""Create shard keys in Qdrant collection."""
952+
if not self._shard_keys:
953+
return
954+
955+
for shard_key in self._shard_keys:
956+
try:
957+
self._client.create_shard_key(
958+
collection_name=self.collection_name,
959+
shard_key=shard_key,
960+
)
961+
except (RpcError, ValueError, UnexpectedResponse) as exc:
962+
if "already exists" not in str(exc):
963+
raise exc # noqa: TRY201
964+
logger.warning(
965+
"Shard key %s already exists, skipping creation.",
966+
shard_key,
967+
)
968+
continue
969+
970+
async def _acreate_shard_keys(self) -> None:
971+
"""Asynchronous method to create shard keys in Qdrant collection."""
972+
if not self._shard_keys:
973+
return
974+
975+
for shard_key in self._shard_keys:
976+
try:
977+
await self._aclient.create_shard_key(
978+
collection_name=self.collection_name,
979+
shard_key=shard_key,
980+
)
981+
except (RpcError, ValueError, UnexpectedResponse) as exc:
982+
if "already exists" not in str(exc):
983+
raise exc # noqa: TRY201
984+
logger.warning(
985+
"Shard key %s already exists, skipping creation.",
986+
shard_key,
987+
)
988+
continue
989+
989990
def _create_payload_indexes(self) -> None:
990991
"""Create payload indexes in Qdrant collection."""
991992
if not self._payload_indexes:

llama-index-integrations/vector_stores/llama-index-vector-stores-qdrant/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dev = [
2828

2929
[project]
3030
name = "llama-index-vector-stores-qdrant"
31-
version = "0.8.2"
31+
version = "0.8.3"
3232
description = "llama-index vector_stores qdrant integration"
3333
authors = [{name = "Your Name", email = "[email protected]"}]
3434
requires-python = ">=3.9,<3.14"

llama-index-integrations/vector_stores/llama-index-vector-stores-qdrant/tests/test_vector_stores_qdrant.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,3 +653,44 @@ def test_payload_indexes_created_on_existing_collection_sync(
653653
missing = schema_keys - payload_keys
654654
extra = payload_keys - schema_keys
655655
assert not missing and not extra, f"Missing: {missing}, Extra: {extra}"
656+
657+
658+
@requires_qdrant_cluster
659+
@pytest.mark.asyncio
660+
async def test_arecreate_shard_keys(
661+
shard_vector_store: QdrantVectorStore,
662+
):
663+
await shard_vector_store._acreate_shard_keys()
664+
665+
666+
@requires_qdrant_cluster
667+
def test_recreate_shard_keys(
668+
shard_vector_store: QdrantVectorStore,
669+
):
670+
shard_vector_store._create_shard_keys()
671+
672+
673+
@pytest.mark.asyncio
674+
async def test_acreate_shard_keys_returns_early_when_no_shard_keys(
675+
vector_store: QdrantVectorStore,
676+
):
677+
await vector_store._acreate_shard_keys()
678+
679+
680+
def test_create_shard_keys_returns_early_when_no_shard_keys(
681+
vector_store: QdrantVectorStore,
682+
):
683+
vector_store._create_shard_keys()
684+
685+
686+
@pytest.mark.asyncio
687+
async def test_acreate_payload_indexes_returns_early_when_no_payload_indexes(
688+
vector_store: QdrantVectorStore,
689+
):
690+
await vector_store._acreate_payload_indexes()
691+
692+
693+
def test_create_payload_indexes_returns_early_when_no_payload_indexes(
694+
vector_store: QdrantVectorStore,
695+
):
696+
vector_store._create_payload_indexes()

0 commit comments

Comments
 (0)