11from enum import Enum
22from typing import Any , Dict , List , Optional , Union
33
4+ from redis .commands .search .aggregation import AggregateRequest , Desc
45from redis .commands .search .query import Query as RedisQuery
56
67from redisvl .query .filter import FilterExpression
@@ -137,7 +138,7 @@ def __init__(
137138 """A query for a simple count operation provided some filter expression.
138139
139140 Args:
140- filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
141+ filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
141142 query with. Defaults to None.
142143 params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
143144
@@ -654,31 +655,32 @@ class RangeQuery(VectorRangeQuery):
654655
655656class TextQuery (FilterQuery ):
656657 def __init__ (
657- self ,
658+ self ,
658659 text : str ,
659660 text_field : str ,
660- text_scorer : str = "TFIDF" ,
661- return_fields : Optional [List [str ]] = None ,
661+ text_scorer : str = "BM25" ,
662662 filter_expression : Optional [Union [str , FilterExpression ]] = None ,
663+ return_fields : Optional [List [str ]] = None ,
663664 num_results : int = 10 ,
664665 return_score : bool = True ,
665666 dialect : int = 2 ,
666667 sort_by : Optional [str ] = None ,
667668 in_order : bool = False ,
669+ params : Optional [Dict [str , Any ]] = None ,
668670 ):
669671 """A query for running a full text and vector search, along with an optional
670672 filter expression.
671673
672674 Args:
673- text (str): The text string to perform the text search with.
675+ text (str): The text string to perform the text search with.
674676 text_field (str): The name of the document field to perform text search on.
675677 text_scorer (str, optional): The text scoring algorithm to use.
676- Defaults to TFIDF . Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
678+ Defaults to BM25 . Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
677679 See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
680+ filter_expression (Union[str, FilterExpression], optional): A filter to apply
681+ along with the text search. Defaults to None.
678682 return_fields (List[str]): The declared fields to return with search
679683 results.
680- filter_expression (Union[str, FilterExpression], optional): A filter to apply
681- along with the vector search. Defaults to None.
682684 num_results (int, optional): The top k results to return from the
683685 search. Defaults to 10.
684686 return_score (bool, optional): Whether to return the text score.
@@ -690,174 +692,82 @@ def __init__(
690692 in_order (bool): Requires the terms in the field to have
691693 the same order as the terms in the query filter, regardless of
692694 the offsets between them. Defaults to False.
693-
694- Raises:
695- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
695+ params (Optional[Dict[str, Any]], optional): The parameters for the query.
696+ Defaults to None.
696697 """
698+ import nltk
699+ from nltk .corpus import stopwords
700+
701+ nltk .download ("stopwords" )
702+ self ._stopwords = set (stopwords .words ("english" ))
703+
704+ self ._text = text
697705 self ._text_field = text_field
698- self ._num_results = num_results
706+ self ._text_scorer = text_scorer
707+
699708 self .set_filter (filter_expression )
700- query_string = self ._build_query_string ()
701- from nltk .corpus import stopwords
702- import nltk
709+ self ._num_results = num_results
703710
704- nltk .download ('stopwords' )
705- self ._stopwords = set (stopwords .words ('english' ))
711+ query_string = self ._build_query_string ()
706712
707- super ().__init__ (query_string )
713+ super ().__init__ (
714+ query_string ,
715+ return_fields = return_fields ,
716+ num_results = num_results ,
717+ dialect = dialect ,
718+ sort_by = sort_by ,
719+ in_order = in_order ,
720+ params = params ,
721+ )
708722
709723 # Handle query modifiers
710- if return_fields :
711- self .return_fields (* return_fields )
712-
724+ self .scorer (self ._text_scorer )
713725 self .paging (0 , self ._num_results ).dialect (dialect )
714726
715727 if return_score :
716- self .return_fields (self .DISTANCE_ID ) #TODO
717-
718- if sort_by :
719- self .sort_by (sort_by )
720- else :
721- self .sort_by (self .DISTANCE_ID ) #TODO
728+ self .with_scores ()
722729
723- if in_order :
724- self .in_order ()
725-
726-
727- def _tokenize_query (self , user_query : str ) -> str :
730+ def tokenize_and_escape_query (self , user_query : str ) -> str :
728731 """Convert a raw user query to a redis full text query joined by ORs"""
732+ from redisvl .utils .token_escaper import TokenEscaper
729733
730- words = word_tokenize (user_query )
731-
732- tokens = [token .strip ().strip ("," ).lower () for token in user_query .split ()]
733- return " | " .join ([token for token in tokens if token not in self ._stopwords ])
734+ escaper = TokenEscaper ()
734735
736+ tokens = [
737+ escaper .escape (
738+ token .strip ().strip ("," ).replace ("“" , "" ).replace ("”" , "" ).lower ()
739+ )
740+ for token in user_query .split ()
741+ ]
742+ return " | " .join (
743+ [token for token in tokens if token and token not in self ._stopwords ]
744+ )
735745
736746 def _build_query_string (self ) -> str :
737747 """Build the full query string for text search with optional filtering."""
738748 filter_expression = self ._filter_expression
739- # TODO include text only
740749 if isinstance (filter_expression , FilterExpression ):
741750 filter_expression = str (filter_expression )
751+ else :
752+ filter_expression = ""
742753
743- text = f"(~{ Text (self ._text_field ) % self ._tokenize_query (user_query )} )"
744-
745- text_and_filter = text & self ._filter_expression
746-
747- #TODO is this method even needed? use
748- return text_and_filter
754+ text = f"(~@{ self ._text_field } :({ self .tokenize_and_escape_query (self ._text )} ))"
755+ if filter_expression and filter_expression != "*" :
756+ text += f"({ filter_expression } )"
757+ return text
749758
750- # from redisvl.utils.token_escaper import TokenEscaper
751- # escaper = TokenEscaper()
752- # def tokenize_and_escape_query(user_query: str) -> str:
753- # """Convert a raw user query to a redis full text query joined by ORs"""
754- # tokens = [escaper.escape(token.strip().strip(",").replace("“", "").replace("”", "").lower()) for token in user_query.split()]
755- # return " | ".join([token for token in tokens if token and token not in stopwords_en])
756759
757- class HybridQuery (VectorQuery , TextQuery ):
758- def __init__ ():
759- self ,
760- text : str ,
761- text_field : str ,
762- vector : Union [List [float ], bytes ],
763- vector_field_name : str ,
764- text_scorer : str = "TFIDF" ,
765- alpha : float = 0.7 ,
766- return_fields : Optional [List [str ]] = None ,
767- filter_expression : Optional [Union [str , FilterExpression ]] = None ,
768- dtype : str = "float32" ,
769- num_results : int = 10 ,
770- return_score : bool = True ,
771- dialect : int = 2 ,
772- sort_by : Optional [str ] = None ,
773- in_order : bool = False ,
760+ class HybridQuery (AggregateRequest ):
761+ def __init__ (
762+ self , text_query : TextQuery , vector_query : VectorQuery , alpha : float = 0.7
774763 ):
775- """A query for running a hybrid full text and vector search, along with
776- an optional filter expression.
764+ """An aggregate query for running a hybrid full text and vector search.
777765
778766 Args:
779- text (str): The text string to run text search with.
780- text_field (str): The name of the text field to search against.
781- vector (List[float]): The vector to perform the vector search with.
782- vector_field_name (str): The name of the vector field to search
783- against in the database.
784- text_scorer (str, optional): The text scoring algorithm to use.
785- Defaults to TFIDF.
767+ text_query (TextQuery): The text query to run text search with.
768+ vector_query (VectorQuery): The vector query to run vector search with.
786769 alpha (float, optional): The amount to weight the vector similarity
787770 score relative to the text similarity score. Defaults to 0.7
788- return_fields (List[str]): The declared fields to return with search
789- results.
790- filter_expression (Union[str, FilterExpression], optional): A filter to apply
791- along with the vector search. Defaults to None.
792- dtype (str, optional): The dtype of the vector. Defaults to
793- "float32".
794- num_results (int, optional): The top k results to return from the
795- vector search. Defaults to 10.
796- return_score (bool, optional): Whether to return the vector
797- distance. Defaults to True.
798- dialect (int, optional): The RediSearch query dialect.
799- Defaults to 2.
800- sort_by (Optional[str]): The field to order the results by. Defaults
801- to None. Results will be ordered by vector distance.
802- in_order (bool): Requires the terms in the field to have
803- the same order as the terms in the query filter, regardless of
804- the offsets between them. Defaults to False.
805-
806- Raises:
807- TypeError: If filter_expression is not of type redisvl.query.FilterExpression
808-
809- Note:
810- Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
811- """
812- self ._text = text
813- self ._text_field_name = tex_field_name
814- self ._vector = vector
815- self ._vector_field_name = vector_field_name
816- self ._dtype = dtype
817- self ._num_results = num_results
818- self .set_filter (filter_expression )
819- query_string = self ._build_query_string ()
820-
821- # TODO how to handle multiple parents? call parent.__init__() manually?
822- super ().__init__ (query_string )
823-
824- # Handle query modifiers
825- if return_fields :
826- self .return_fields (* return_fields )
827771
828- self .paging (0 , self ._num_results ).dialect (dialect )
829-
830- if return_score :
831- self .return_fields (self .DISTANCE_ID )
832-
833- if sort_by :
834- self .sort_by (sort_by )
835- else :
836- self .sort_by (self .DISTANCE_ID )
837-
838- if in_order :
839- self .in_order ()
840-
841-
842- def _build_query_string (self ) -> str :
843- """Build the full query string for hybrid search with optional filtering."""
844- filter_expression = self ._filter_expression
845- # TODO include hybrid
846- if isinstance (filter_expression , FilterExpression ):
847- filter_expression = str (filter_expression )
848- return f"{ filter_expression } =>[KNN { self ._num_results } @{ self ._vector_field_name } ${ self .VECTOR_PARAM } AS { self .DISTANCE_ID } ]"
849-
850- @property
851- def params (self ) -> Dict [str , Any ]:
852- """Return the parameters for the query.
853-
854- Returns:
855- Dict[str, Any]: The parameters for the query.
856772 """
857- if isinstance (self ._vector , bytes ):
858- vector = self ._vector
859- else :
860- vector = array_to_buffer (self ._vector , dtype = self ._dtype )
861-
862- return {self .VECTOR_PARAM : vector }
863-
773+ pass
0 commit comments