Skip to content

Commit 9500853

Browse files
justin-cechmanek authored and
tylerhutcherson committed
wip: adding Text and Hybrid queries
1 parent 0b3a5ce commit 9500853

File tree

1 file changed

+210
-2
lines changed

1 file changed

+210
-2
lines changed

redisvl/query/query.py

Lines changed: 210 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def __init__(
9191
num_results (Optional[int], optional): The number of results to return. Defaults to 10.
9292
dialect (int, optional): The query dialect. Defaults to 2.
9393
sort_by (Optional[str], optional): The field to order the results by. Defaults to None.
94-
in_order (bool, optional): Requires the terms in the field to have the same order as the terms in the query filter. Defaults to False.
94+
in_order (bool, optional): Requires the terms in the field to have the same order as the
95+
terms in the query filter. Defaults to False.
9596
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
9697
9798
Raises:
@@ -136,7 +137,8 @@ def __init__(
136137
"""A query for a simple count operation provided some filter expression.
137138
138139
Args:
139-
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to query with. Defaults to None.
140+
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
141+
query with. Defaults to None.
140142
params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
141143
142144
Raises:
@@ -214,6 +216,7 @@ def __init__(
214216
"float32".
215217
num_results (int, optional): The top k results to return from the
216218
vector search. Defaults to 10.
219+
217220
return_score (bool, optional): Whether to return the vector
218221
distance. Defaults to True.
219222
dialect (int, optional): The RediSearch query dialect.
@@ -647,3 +650,208 @@ def params(self) -> Dict[str, Any]:
647650
class RangeQuery(VectorRangeQuery):
    """Subclass of ``VectorRangeQuery`` that adds nothing of its own.

    Kept only for backwards compatibility.
    """
653+
654+
655+
class TextQuery(FilterQuery):
    """Full-text search over a single text field, with an optional filter."""

    def __init__(
        self,
        text: str,
        text_field: str,
        text_scorer: str = "TFIDF",
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        num_results: int = 10,
        return_score: bool = True,
        dialect: int = 2,
        sort_by: Optional[str] = None,
        in_order: bool = False,
    ):
        """A query for running a full-text search, along with an optional
        filter expression.

        Args:
            text (str): The text string to perform the text search with.
            text_field (str): The name of the document field to perform text search on.
            text_scorer (str, optional): The text scoring algorithm to use.
                Defaults to TFIDF. Options are {TFIDF, TFIDF.DOCNORM, BM25, DISMAX,
                DOCSCORE}.
                See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
            return_fields (List[str]): The declared fields to return with search
                results.
            filter_expression (Union[str, FilterExpression], optional): A filter to apply
                along with the text search. Defaults to None.
            num_results (int, optional): The top k results to return from the
                search. Defaults to 10.
            return_score (bool, optional): Whether to return the text score.
                Defaults to True.
            dialect (int, optional): The RediSearch query dialect.
                Defaults to 2.
            sort_by (Optional[str]): The field to order the results by. Defaults
                to None. Results will be ordered by text score.
            in_order (bool): Requires the terms in the field to have
                the same order as the terms in the query filter, regardless of
                the offsets between them. Defaults to False.

        Raises:
            TypeError: If filter_expression is not of type redisvl.query.FilterExpression
            ValueError: If no searchable terms remain in ``text`` once stopwords
                are removed.
        """
        # Everything _build_query_string() reads must exist before it runs.
        # set_filter() is defined outside this class; in case it rebuilds the
        # query string itself, populate all of the state first.
        self._text = text
        self._text_field = text_field
        self._text_scorer = text_scorer
        self._num_results = num_results
        self._stopwords = self._load_stopwords()

        self.set_filter(filter_expression)
        query_string = self._build_query_string()

        # NOTE(review): FilterQuery.__init__ is not visible from here. Confirm
        # it accepts a pre-built query string as its first positional argument
        # (and does not treat it as a filter expression to be rebuilt).
        super().__init__(query_string)

        # Handle query modifiers
        if return_fields:
            self.return_fields(*return_fields)

        self.paging(0, self._num_results).dialect(dialect)

        # scorer() / with_scores() are presumably inherited from redis-py's
        # Query, like paging()/dialect()/sort_by() above -- TODO confirm.
        self.scorer(text_scorer)

        if return_score:
            # Ask for the document's text score rather than a vector-distance
            # alias, which has no meaning for a text-only query.
            self.with_scores()

        # With no explicit sort field the engine's default ordering (by
        # score) is left untouched.
        if sort_by:
            self.sort_by(sort_by)

        if in_order:
            self.in_order()

    @staticmethod
    def _load_stopwords(language: str = "english") -> set:
        """Return the NLTK stopword set for ``language``.

        The corpus is downloaded only when it is not already installed, so
        constructing a query does not hit the network every time.
        """
        # Imported lazily so nltk is only required when a TextQuery is built.
        import nltk
        from nltk.corpus import stopwords

        try:
            nltk.data.find("corpora/stopwords")
        except LookupError:
            nltk.download("stopwords", quiet=True)
        return set(stopwords.words(language))

    def _tokenize_query(self, user_query: str) -> str:
        """Convert a raw user query to a redis full text query joined by ORs.

        Tokens are split on whitespace, stripped of surrounding commas,
        lower-cased, and dropped when empty or a stopword.

        NOTE(review): tokens are not escaped; query-syntax characters in the
        user text are passed through unchanged.
        """
        tokens = (raw.strip().strip(",").lower() for raw in user_query.split())
        return " | ".join(
            token for token in tokens if token and token not in self._stopwords
        )

    def _build_query_string(self) -> str:
        """Build the full query string for text search with optional filtering.

        Returns:
            str: The optional (``~``) text clause, prefixed by the filter
                expression when one is set.

        Raises:
            ValueError: If the text contains no searchable terms.
        """
        terms = self._tokenize_query(self._text)
        if not terms:
            raise ValueError(
                "text must contain at least one term that is not a stopword"
            )

        text_clause = f"(~{Text(self._text_field) % terms})"

        filter_expression = self._filter_expression
        if isinstance(filter_expression, FilterExpression):
            filter_expression = str(filter_expression)

        # presumably set_filter() stores a match-all "*" when no filter is
        # given -- verify; either way there is nothing to intersect with.
        if not filter_expression or filter_expression == "*":
            return text_clause

        # Adjacent clauses are intersected; the filter is parenthesised so a
        # raw string containing "|" keeps its own precedence.
        return f"({filter_expression}) {text_clause}"
749+
750+
751+
class HybridQuery(VectorQuery, TextQuery):
    """Work-in-progress combined text + vector query.

    Only the vector (KNN) part of the query string is built so far; the text,
    scorer and alpha settings are stored but not yet applied.
    """

    def __init__(
        self,
        text: str,
        text_field: str,
        vector: Union[List[float], bytes],
        vector_field_name: str,
        text_scorer: str = "TFIDF",
        alpha: float = 0.7,
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        dtype: str = "float32",
        num_results: int = 10,
        return_score: bool = True,
        dialect: int = 2,
        sort_by: Optional[str] = None,
        in_order: bool = False,
    ):
        """A query for running a hybrid full text and vector search, along with
        an optional filter expression.

        Args:
            text (str): The text string to run text search with.
            text_field (str): The name of the text field to search against.
            vector (Union[List[float], bytes]): The vector to perform the vector
                search with.
            vector_field_name (str): The name of the vector field to search
                against in the database.
            text_scorer (str, optional): The text scoring algorithm to use.
                Defaults to TFIDF.
            alpha (float, optional): The amount to weight the vector similarity
                score relative to the text similarity score. Defaults to 0.7
            return_fields (List[str]): The declared fields to return with search
                results.
            filter_expression (Union[str, FilterExpression], optional): A filter to apply
                along with the vector search. Defaults to None.
            dtype (str, optional): The dtype of the vector. Defaults to
                "float32".
            num_results (int, optional): The top k results to return from the
                vector search. Defaults to 10.
            return_score (bool, optional): Whether to return the vector
                distance. Defaults to True.
            dialect (int, optional): The RediSearch query dialect.
                Defaults to 2.
            sort_by (Optional[str]): The field to order the results by. Defaults
                to None. Results will be ordered by vector distance.
            in_order (bool): Requires the terms in the field to have
                the same order as the terms in the query filter, regardless of
                the offsets between them. Defaults to False.

        Raises:
            TypeError: If filter_expression is not of type redisvl.query.FilterExpression

        Note:
            Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
        """
        # Text-side settings. They are stored so they are not silently lost,
        # but _build_query_string() below does not use them yet.
        self._text = text
        self._text_field = text_field
        self._text_scorer = text_scorer
        self._alpha = alpha

        # Vector-side settings, read by _build_query_string() and params.
        self._vector = vector
        self._vector_field_name = vector_field_name
        self._dtype = dtype
        self._num_results = num_results

        self.set_filter(filter_expression)
        query_string = self._build_query_string()

        # TODO how to handle multiple parents? call parent.__init__() manually?
        # NOTE(review): with two parents, super() resolves to the next class
        # in the MRO (VectorQuery). Its signature is not visible from here;
        # confirm it accepts a pre-built query string before relying on this.
        super().__init__(query_string)

        # Handle query modifiers
        if return_fields:
            self.return_fields(*return_fields)

        self.paging(0, self._num_results).dialect(dialect)

        if return_score:
            self.return_fields(self.DISTANCE_ID)

        if sort_by:
            self.sort_by(sort_by)
        else:
            self.sort_by(self.DISTANCE_ID)

        if in_order:
            self.in_order()

    def _build_query_string(self) -> str:
        """Build the full query string for hybrid search with optional filtering.

        Currently this is the KNN clause prefixed by the filter expression.
        """
        filter_expression = self._filter_expression
        # TODO include hybrid: the text clause is not part of the query yet.
        if isinstance(filter_expression, FilterExpression):
            filter_expression = str(filter_expression)
        return f"{filter_expression}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]"

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the query.

        Returns:
            Dict[str, Any]: The parameters for the query.
        """
        # Raw bytes are passed through; anything else is packed using the
        # configured dtype.
        if isinstance(self._vector, bytes):
            vector = self._vector
        else:
            vector = array_to_buffer(self._vector, dtype=self._dtype)

        return {self.VECTOR_PARAM: vector}
857+

0 commit comments

Comments
 (0)