Skip to content

Commit 10f4474

Browse files
cleans text and hybrid tests
1 parent 9348583 commit 10f4474

File tree

4 files changed

+28
-30
lines changed

4 files changed

+28
-30
lines changed

redisvl/index/index.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
4343
from redisvl.query import (
4444
AggregationQuery,
45+
BaseVectorQuery,
4546
BaseQuery,
4647
CountQuery,
4748
FilterQuery,

redisvl/query/query.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from redis.commands.search.query import Query as RedisQuery
55

66
from redisvl.query.filter import FilterExpression
7-
from redisvl.redis.utils import array_to_buffer, denorm_cosine_distance
7+
from redisvl.redis.utils import array_to_buffer
88
from redisvl.utils.token_escaper import TokenEscaper
9+
from redisvl.utils.utils import denorm_cosine_distance
910

1011

1112
class BaseQuery(RedisQuery):
@@ -811,7 +812,7 @@ def _build_query_string(self) -> str:
811812
else:
812813
filter_expression = ""
813814

814-
text = f"~(@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"
815+
text = f"(~@{self._text_field}:({self.tokenize_and_escape_query(self._text)}))"
815816
if filter_expression and filter_expression != "*":
816817
text += f"({filter_expression})"
817818
return text

tests/unit/test_aggregation_types.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313

1414

1515
# Test Cases
16-
17-
1816
def test_aggregate_hybrid_query():
1917
text_field_name = "description"
2018
vector_field_name = "embedding"
@@ -34,10 +32,6 @@ def test_aggregate_hybrid_query():
3432
assert hybrid_query._vector == sample_vector
3533
assert hybrid_query._vector_field == vector_field_name
3634
assert hybrid_query._scorer == "BM25STD"
37-
# assert (
38-
# hybrid_query.filter
39-
# == f"(@{text_field_name}:({hybrid_query.tokenize_and_escape_query(text_string)}))"
40-
# )
4135
assert hybrid_query._filter_expression == None
4236
assert hybrid_query._alpha == 0.7
4337
assert hybrid_query._num_results == 10
@@ -72,47 +66,50 @@ def test_aggregate_hybrid_query():
7266
assert hybrid_query._vector == sample_vector
7367
assert hybrid_query._vector_field == vector_field_name
7468
assert hybrid_query._scorer == scorer
75-
# assert (
76-
# hybrid_query.filter
77-
# == f"(@{text_field_name}:({hybrid_query.tokenize_and_escape_query(text_string)}))"
78-
# )
7969
assert hybrid_query._filter_expression == filter_expression
8070
assert hybrid_query._alpha == 0.5
8171
assert hybrid_query._num_results == 8
8272
assert hybrid_query._loadfields == return_fields
8373
assert hybrid_query._dialect == 2
8474
assert hybrid_query.stopwords == set()
8575

86-
return
87-
8876
# Test stopwords are configurable
89-
hybrid_query = HybridAggregationQuery(text_string, text_field_name, stopwords=None)
77+
hybrid_query = HybridAggregationQuery(
78+
sample_text, text_field_name, sample_vector, vector_field_name, stopwords=None
79+
)
9080
assert hybrid_query.stopwords == set([])
91-
# assert (
92-
# hyrid_query.filter
93-
# == f"(@{text_field_name}:({hyrid_query.tokenize_and_escape_query(text_string)}))"
94-
# )
9581

96-
hyrid_query = HybridAggregationQuery(
97-
text_string, text_field_name, stopwords=["the", "a", "of"]
82+
hybrid_query = HybridAggregationQuery(
83+
sample_text,
84+
text_field_name,
85+
sample_vector,
86+
vector_field_name,
87+
stopwords=["the", "a", "of"],
9888
)
9989
assert hybrid_query.stopwords == set(["the", "a", "of"])
100-
assert (
101-
hybrid_query.filter
102-
== f"(@{text_field_name}:({hybrid_query.tokenize_and_escape_query(text_string)}))"
103-
)
104-
10590
hybrid_query = HybridAggregationQuery(
106-
text_string, text_field_name, stopwords="german"
91+
sample_text,
92+
text_field_name,
93+
sample_vector,
94+
vector_field_name,
95+
stopwords="german",
10796
)
10897
assert hybrid_query.stopwords != set([])
10998

11099
with pytest.raises(ValueError):
111100
hybrid_query = HybridAggregationQuery(
112-
text_string, text_field_name, stopwords="gibberish"
101+
sample_text,
102+
text_field_name,
103+
sample_vector,
104+
vector_field_name,
105+
stopwords="gibberish",
113106
)
114107

115108
with pytest.raises(TypeError):
116109
hybrid_query = HybridAggregationQuery(
117-
text_string, text_field_name, stopwords=[1, 2, 3]
110+
sample_text,
111+
text_field_name,
112+
sample_vector,
113+
vector_field_name,
114+
stopwords=[1, 2, 3],
118115
)

tests/unit/test_query_types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ def test_text_query():
270270
with pytest.raises(TypeError):
271271
text_query = TextQuery(text_string, text_field_name, stopwords=[1, 2, 3])
272272

273-
274273
text_query = TextQuery(text_string, text_field_name, stopwords=["the", "a", "of"])
275274
assert text_query.stopwords == set(["the", "a", "of"])
276275
assert (

0 commit comments

Comments
 (0)