Skip to content

Commit d1efa06

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f5f43a4 commit d1efa06

File tree

8 files changed

+107
-89
lines changed

8 files changed

+107
-89
lines changed

engine/clients/cassandra/README.md

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,9 @@ Run the following command to start the server (alternatively, run `docker compos
99
$ docker compose up
1010

1111
[+] Running 1/1
12-
✔ cassandra Pulled 1.4s
12+
✔ cassandra Pulled 1.4s
1313
[+] Running 1/1
14-
✔ Container cassandra-benchmark Recreated 0.1s
14+
✔ Container cassandra-benchmark Recreated 0.1s
1515
Attaching to cassandra-benchmark
1616
cassandra-benchmark | CompileCommand: dontinline org/apache/cassandra/db/Columns$Serializer.deserializeLargeSubset(Lorg/apache/cassandra/io/util/DataInputPlus;Lorg/apache/cassandra/db/Columns;I)Lorg/apache/cassandra/db/Columns; bool dontinline = true
1717
...
@@ -29,7 +29,7 @@ cassandra-benchmark | INFO [main] 2025-04-04 22:28:25,091 StorageService.java:
2929
### Start up the client benchmark
3030
Run the following command to start the client benchmark, using the `glove-25-angular` dataset as an example:
3131
```bash
32-
% python3 -m run --engines cassandra-single-node --datasets glove-25-angular
32+
% python3 -m run --engines cassandra-single-node --datasets glove-25-angular
3333
```
3434
and you'll see the following output:
3535
```bash

engine/clients/cassandra/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,4 @@
22
from engine.clients.cassandra.search import CassandraSearcher
33
from engine.clients.cassandra.upload import CassandraUploader
44

5-
__all__ = [
6-
"CassandraConfigurator",
7-
"CassandraSearcher",
8-
"CassandraUploader"
9-
]
5+
__all__ = ["CassandraConfigurator", "CassandraSearcher", "CassandraUploader"]

engine/clients/cassandra/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,4 +4,4 @@
44
CASSANDRA_TABLE = os.getenv("CASSANDRA_TABLE", "vectors")
55
ASTRA_API_ENDPOINT = os.getenv("ASTRA_API_ENDPOINT", None)
66
ASTRA_API_KEY = os.getenv("ASTRA_API_KEY", None)
7-
ASTRA_SCB_PATH = os.getenv("ASTRA_SCB_PATH", None)
7+
ASTRA_SCB_PATH = os.getenv("ASTRA_SCB_PATH", None)

engine/clients/cassandra/configure.py

Lines changed: 26 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,10 @@
1-
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
2-
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy, ExponentialReconnectionPolicy
31
from cassandra import ConsistencyLevel, ProtocolVersion
2+
from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
3+
from cassandra.policies import (
4+
DCAwareRoundRobinPolicy,
5+
ExponentialReconnectionPolicy,
6+
TokenAwarePolicy,
7+
)
48

59
from benchmark.dataset import Dataset
610
from engine.base_client.configure import BaseConfigurator
@@ -13,26 +17,28 @@ class CassandraConfigurator(BaseConfigurator):
1317
DISTANCE_MAPPING = {
1418
Distance.L2: "euclidean",
1519
Distance.COSINE: "cosine",
16-
Distance.DOT: "dot_product"
20+
Distance.DOT: "dot_product",
1721
}
1822

1923
def __init__(self, host, collection_params: dict, connection_params: dict):
2024
super().__init__(host, collection_params, connection_params)
21-
25+
2226
# Set up execution profiles for consistency and performance
2327
profile = ExecutionProfile(
2428
load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
2529
consistency_level=ConsistencyLevel.LOCAL_QUORUM,
26-
request_timeout=60
30+
request_timeout=60,
2731
)
28-
32+
2933
# Initialize Cassandra cluster connection
3034
self.cluster = Cluster(
3135
contact_points=[host],
3236
execution_profiles={EXEC_PROFILE_DEFAULT: profile},
3337
protocol_version=ProtocolVersion.V4,
34-
reconnection_policy=ExponentialReconnectionPolicy(base_delay=1, max_delay=60),
35-
**connection_params
38+
reconnection_policy=ExponentialReconnectionPolicy(
39+
base_delay=1, max_delay=60
40+
),
41+
**connection_params,
3642
)
3743
self.session = self.cluster.connect()
3844

@@ -44,17 +50,17 @@ def recreate(self, dataset: Dataset, collection_params):
4450
"""Create keyspace and table for vector search"""
4551
# Create keyspace if not exists
4652
self.session.execute(
47-
f"""CREATE KEYSPACE IF NOT EXISTS {CASSANDRA_KEYSPACE}
53+
f"""CREATE KEYSPACE IF NOT EXISTS {CASSANDRA_KEYSPACE}
4854
WITH REPLICATION = {{ 'class': 'SimpleStrategy', 'replication_factor': 1 }}"""
4955
)
50-
56+
5157
# Use the keyspace
5258
self.session.execute(f"USE {CASSANDRA_KEYSPACE}")
53-
59+
5460
# Get the distance metric
5561
distance_metric = self.DISTANCE_MAPPING.get(dataset.config.distance)
5662
vector_size = dataset.config.vector_size
57-
63+
5864
# Create vector table
5965
# Using a simple schema that supports vector similarity search
6066
self.session.execute(
@@ -64,14 +70,14 @@ def recreate(self, dataset: Dataset, collection_params):
6470
metadata map<text, text>
6571
)"""
6672
)
67-
73+
6874
# Create vector index using the appropriate distance metric
6975
self.session.execute(
70-
f"""CREATE CUSTOM INDEX IF NOT EXISTS vector_index ON {CASSANDRA_TABLE}(embedding)
71-
USING 'StorageAttachedIndex'
76+
f"""CREATE CUSTOM INDEX IF NOT EXISTS vector_index ON {CASSANDRA_TABLE}(embedding)
77+
USING 'StorageAttachedIndex'
7278
WITH OPTIONS = {{ 'similarity_function': '{distance_metric}' }}"""
7379
)
74-
80+
7581
# Add additional schema fields based on collection_params if needed
7682
for field_name, field_type in dataset.config.schema.items():
7783
if field_type in ["keyword", "text"]:
@@ -81,7 +87,7 @@ def recreate(self, dataset: Dataset, collection_params):
8187
# For numeric fields that need separate indexing
8288
# In a real implementation, we might alter the table to add these columns
8389
pass
84-
90+
8591
return collection_params
8692

8793
def execution_params(self, distance, vector_size) -> dict:
@@ -90,7 +96,7 @@ def execution_params(self, distance, vector_size) -> dict:
9096

9197
def delete_client(self):
9298
"""Close the Cassandra connection"""
93-
if hasattr(self, 'session') and self.session:
99+
if hasattr(self, "session") and self.session:
94100
self.session.shutdown()
95-
if hasattr(self, 'cluster') and self.cluster:
96-
self.cluster.shutdown()
101+
if hasattr(self, "cluster") and self.cluster:
102+
self.cluster.shutdown()

engine/clients/cassandra/parser.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -11,23 +11,23 @@ def build_condition(
1111
Build a CQL condition expression that combines AND and OR subfilters
1212
"""
1313
conditions = []
14-
14+
1515
# Add AND conditions
1616
if and_subfilters and len(and_subfilters) > 0:
1717
and_conds = " AND ".join([f"({cond})" for cond in and_subfilters if cond])
1818
if and_conds:
1919
conditions.append(f"({and_conds})")
20-
20+
2121
# Add OR conditions
2222
if or_subfilters and len(or_subfilters) > 0:
2323
or_conds = " OR ".join([f"({cond})" for cond in or_subfilters if cond])
2424
if or_conds:
2525
conditions.append(f"({or_conds})")
26-
26+
2727
# Combine all conditions
2828
if not conditions:
2929
return None
30-
30+
3131
return " AND ".join(conditions)
3232

3333
def build_exact_match_filter(self, field_name: str, value: FieldValue) -> Any:
@@ -52,31 +52,31 @@ def build_range_filter(
5252
Build a CQL range filter condition
5353
"""
5454
conditions = []
55-
55+
5656
if lt is not None:
5757
if isinstance(lt, str):
5858
conditions.append(f"metadata['{field_name}'] < '{lt}'")
5959
else:
6060
conditions.append(f"metadata['{field_name}'] < '{str(lt)}'")
61-
61+
6262
if gt is not None:
6363
if isinstance(gt, str):
6464
conditions.append(f"metadata['{field_name}'] > '{gt}'")
6565
else:
6666
conditions.append(f"metadata['{field_name}'] > '{str(gt)}'")
67-
67+
6868
if lte is not None:
6969
if isinstance(lte, str):
7070
conditions.append(f"metadata['{field_name}'] <= '{lte}'")
7171
else:
7272
conditions.append(f"metadata['{field_name}'] <= '{str(lte)}'")
73-
73+
7474
if gte is not None:
7575
if isinstance(gte, str):
7676
conditions.append(f"metadata['{field_name}'] >= '{gte}'")
7777
else:
7878
conditions.append(f"metadata['{field_name}'] >= '{str(gte)}'")
79-
79+
8080
return " AND ".join(conditions)
8181

8282
def build_geo_filter(
@@ -89,4 +89,4 @@ def build_geo_filter(
8989
"""
9090
# In a real implementation with a geo extension, we'd implement proper geo filtering
9191
# For this benchmark, we'll return a placeholder condition that doesn't filter
92-
return "1=1" # Always true condition as a placeholder
92+
return "1=1" # Always true condition as a placeholder

engine/clients/cassandra/search.py

Lines changed: 33 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,13 @@
11
import multiprocessing as mp
22
from typing import List, Tuple
33

4-
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
5-
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy, ExponentialReconnectionPolicy
64
from cassandra import ConsistencyLevel, ProtocolVersion
5+
from cassandra.cluster import EXEC_PROFILE_DEFAULT, Cluster, ExecutionProfile
6+
from cassandra.policies import (
7+
DCAwareRoundRobinPolicy,
8+
ExponentialReconnectionPolicy,
9+
TokenAwarePolicy,
10+
)
711

812
from dataset_reader.base_reader import Query
913
from engine.base_client.distances import Distance
@@ -24,20 +28,22 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
2428
profile = ExecutionProfile(
2529
load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
2630
consistency_level=ConsistencyLevel.LOCAL_ONE, # Use LOCAL_ONE for faster reads
27-
request_timeout=60
31+
request_timeout=60,
2832
)
29-
33+
3034
# Initialize Cassandra cluster connection
3135
cls.cluster = Cluster(
32-
contact_points=[host],
36+
contact_points=[host],
3337
execution_profiles={EXEC_PROFILE_DEFAULT: profile},
34-
reconnection_policy=ExponentialReconnectionPolicy(base_delay=1, max_delay=60),
38+
reconnection_policy=ExponentialReconnectionPolicy(
39+
base_delay=1, max_delay=60
40+
),
3541
protocol_version=ProtocolVersion.V4,
36-
**connection_params
42+
**connection_params,
3743
)
3844
cls.session = cls.cluster.connect(CASSANDRA_KEYSPACE)
3945
cls.search_params = search_params
40-
46+
4147
# Update prepared statements with current search parameters
4248
cls.update_prepared_statements(distance)
4349

@@ -50,7 +56,7 @@ def update_prepared_statements(cls, distance):
5056
"""Create prepared statements for vector searches"""
5157
# Prepare a vector similarity search query
5258
limit = cls.search_params.get("top", 10)
53-
59+
5460
if distance == Distance.COSINE:
5561
SIMILARITY_FUNC = "similarity_cosine"
5662
elif distance == Distance.L2:
@@ -61,48 +67,49 @@ def update_prepared_statements(cls, distance):
6167
raise ValueError(f"Unsupported distance metric: {distance}")
6268

6369
cls.ann_search_stmt = cls.session.prepare(
64-
f"""SELECT id, {SIMILARITY_FUNC}(embedding, ?) as distance
65-
FROM {CASSANDRA_TABLE}
70+
f"""SELECT id, {SIMILARITY_FUNC}(embedding, ?) as distance
71+
FROM {CASSANDRA_TABLE}
6672
ORDER BY embedding ANN OF ?
6773
LIMIT {limit}"""
6874
)
69-
75+
7076
# Prepare a statement for filtered vector search
71-
cls.filtered_search_query_template = (
72-
f"""SELECT id, {SIMILARITY_FUNC}(embedding, ?) as distance
73-
FROM {CASSANDRA_TABLE}
77+
cls.filtered_search_query_template = f"""SELECT id, {SIMILARITY_FUNC}(embedding, ?) as distance
78+
FROM {CASSANDRA_TABLE}
7479
WHERE {{conditions}}
7580
ORDER BY embedding ANN OF ?
7681
LIMIT {limit}"""
77-
)
7882

7983
@classmethod
8084
def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
8185
"""Execute a vector similarity search with optional filters"""
8286
# Convert query vector to a format Cassandra can use
83-
query_vector = query.vector.tolist() if hasattr(query.vector, 'tolist') else query.vector
84-
87+
query_vector = (
88+
query.vector.tolist() if hasattr(query.vector, "tolist") else query.vector
89+
)
90+
8591
# Generate filter conditions if metadata conditions exist
8692
filter_conditions = cls.parser.parse(query.meta_conditions)
87-
93+
8894
try:
8995
if filter_conditions:
9096
# Use the filtered search query
91-
query_with_conditions = cls.filtered_search_query_template.format(conditions=filter_conditions)
97+
query_with_conditions = cls.filtered_search_query_template.format(
98+
conditions=filter_conditions
99+
)
92100
results = cls.session.execute(
93101
cls.session.prepare(query_with_conditions),
94-
(query_vector, query_vector)
102+
(query_vector, query_vector),
95103
)
96104
else:
97105
# Use the basic ANN search query
98106
results = cls.session.execute(
99-
cls.ann_search_stmt,
100-
(query_vector, query_vector)
107+
cls.ann_search_stmt, (query_vector, query_vector)
101108
)
102-
109+
103110
# Extract and return results
104111
return [(row.id, row.distance) for row in results]
105-
112+
106113
except Exception as ex:
107114
print(f"Error during Cassandra vector search: {ex}")
108115
raise ex
@@ -113,4 +120,4 @@ def delete_client(cls):
113120
if cls.session:
114121
cls.session.shutdown()
115122
if cls.cluster:
116-
cls.cluster.shutdown()
123+
cls.cluster.shutdown()

0 commit comments

Comments
 (0)