@@ -91,7 +91,8 @@ def __init__(
9191 num_results (Optional[int], optional): The number of results to return. Defaults to 10.
9292 dialect (int, optional): The query dialect. Defaults to 2.
9393 sort_by (Optional[str], optional): The field to order the results by. Defaults to None.
94- in_order (bool, optional): Requires the terms in the field to have the same order as the terms in the query filter. Defaults to False.
94+ in_order (bool, optional): Requires the terms in the field to have the same order as the
95+ terms in the query filter. Defaults to False.
9596 params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
9697
9798 Raises:
@@ -136,7 +137,8 @@ def __init__(
136137 """A query for a simple count operation provided some filter expression.
137138
138139 Args:
139- filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to query with. Defaults to None.
140+ filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to
141+ query with. Defaults to None.
140142 params (Optional[Dict[str, Any]], optional): The parameters for the query. Defaults to None.
141143
142144 Raises:
@@ -214,6 +216,7 @@ def __init__(
214216 "float32".
215217 num_results (int, optional): The top k results to return from the
216218 vector search. Defaults to 10.
219+
217220 return_score (bool, optional): Whether to return the vector
218221 distance. Defaults to True.
219222 dialect (int, optional): The RediSearch query dialect.
@@ -647,3 +650,208 @@ def params(self) -> Dict[str, Any]:
647650class RangeQuery (VectorRangeQuery ):
648651 # keep for backwards compatibility
649652 pass
653+
654+
655+ class TextQuery (FilterQuery ):
656+ def __init__ (
657+ self ,
658+ text : str ,
659+ text_field : str ,
660+ text_scorer : str = "TFIDF" ,
661+ return_fields : Optional [List [str ]] = None ,
662+ filter_expression : Optional [Union [str , FilterExpression ]] = None ,
663+ num_results : int = 10 ,
664+ return_score : bool = True ,
665+ dialect : int = 2 ,
666+ sort_by : Optional [str ] = None ,
667+ in_order : bool = False ,
668+ ):
669+ """A query for running a full text and vector search, along with an optional
670+ filter expression.
671+
672+ Args:
673+ text (str): The text string to perform the text search with.
674+ text_field (str): The name of the document field to perform text search on.
675+ text_scorer (str, optional): The text scoring algorithm to use.
676+ Defaults to TFIDF. Options are {TFIDF, BM25, DOCNORM, DISMAX, DOCSCORE}.
677+ See https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/scoring/
678+ return_fields (List[str]): The declared fields to return with search
679+ results.
680+ filter_expression (Union[str, FilterExpression], optional): A filter to apply
681+ along with the vector search. Defaults to None.
682+ num_results (int, optional): The top k results to return from the
683+ search. Defaults to 10.
684+ return_score (bool, optional): Whether to return the text score.
685+ Defaults to True.
686+ dialect (int, optional): The RediSearch query dialect.
687+ Defaults to 2.
688+ sort_by (Optional[str]): The field to order the results by. Defaults
689+ to None. Results will be ordered by text score.
690+ in_order (bool): Requires the terms in the field to have
691+ the same order as the terms in the query filter, regardless of
692+ the offsets between them. Defaults to False.
693+
694+ Raises:
695+ TypeError: If filter_expression is not of type redisvl.query.FilterExpression
696+ """
697+ self ._text_field = text_field
698+ self ._num_results = num_results
699+ self .set_filter (filter_expression )
700+ query_string = self ._build_query_string ()
701+ from nltk .corpus import stopwords
702+ import nltk
703+
704+ nltk .download ('stopwords' )
705+ self ._stopwords = set (stopwords .words ('english' ))
706+
707+
708+ super ().__init__ (query_string )
709+
710+ # Handle query modifiers
711+ if return_fields :
712+ self .return_fields (* return_fields )
713+
714+ self .paging (0 , self ._num_results ).dialect (dialect )
715+
716+ if return_score :
717+ self .return_fields (self .DISTANCE_ID ) #TODO
718+
719+ if sort_by :
720+ self .sort_by (sort_by )
721+ else :
722+ self .sort_by (self .DISTANCE_ID ) #TODO
723+
724+ if in_order :
725+ self .in_order ()
726+
727+
728+ def _tokenize_query (self , user_query : str ) -> str :
729+ """Convert a raw user query to a redis full text query joined by ORs"""
730+ words = word_tokenize (user_query )
731+
732+ tokens = [token .strip ().strip ("," ).lower () for token in user_query .split ()]
733+ return " | " .join ([token for token in tokens if token not in self ._stopwords ])
734+
735+
736+ def _build_query_string (self ) -> str :
737+ """Build the full query string for text search with optional filtering."""
738+ filter_expression = self ._filter_expression
739+ # TODO include text only
740+ if isinstance (filter_expression , FilterExpression ):
741+ filter_expression = str (filter_expression )
742+
743+ text = f"(~{ Text (self ._text_field ) % self ._tokenize_query (user_query )} )"
744+
745+ text_and_filter = text & self ._filter_expression
746+
747+ #TODO is this method even needed? use
748+ return text_and_filter
749+
750+
751+ class HybridQuery (VectorQuery , TextQuery ):
752+ def __init__ ():
753+ self ,
754+ text : str ,
755+ text_field : str ,
756+ vector : Union [List [float ], bytes ],
757+ vector_field_name : str ,
758+ text_scorer : str = "TFIDF" ,
759+ alpha : float = 0.7 ,
760+ return_fields : Optional [List [str ]] = None ,
761+ filter_expression : Optional [Union [str , FilterExpression ]] = None ,
762+ dtype : str = "float32" ,
763+ num_results : int = 10 ,
764+ return_score : bool = True ,
765+ dialect : int = 2 ,
766+ sort_by : Optional [str ] = None ,
767+ in_order : bool = False ,
768+ ):
769+ """A query for running a hybrid full text and vector search, along with
770+ an optional filter expression.
771+
772+ Args:
773+ text (str): The text string to run text search with.
774+ text_field (str): The name of the text field to search against.
775+ vector (List[float]): The vector to perform the vector search with.
776+ vector_field_name (str): The name of the vector field to search
777+ against in the database.
778+ text_scorer (str, optional): The text scoring algorithm to use.
779+ Defaults to TFIDF.
780+ alpha (float, optional): The amount to weight the vector similarity
781+ score relative to the text similarity score. Defaults to 0.7
782+ return_fields (List[str]): The declared fields to return with search
783+ results.
784+ filter_expression (Union[str, FilterExpression], optional): A filter to apply
785+ along with the vector search. Defaults to None.
786+ dtype (str, optional): The dtype of the vector. Defaults to
787+ "float32".
788+ num_results (int, optional): The top k results to return from the
789+ vector search. Defaults to 10.
790+ return_score (bool, optional): Whether to return the vector
791+ distance. Defaults to True.
792+ dialect (int, optional): The RediSearch query dialect.
793+ Defaults to 2.
794+ sort_by (Optional[str]): The field to order the results by. Defaults
795+ to None. Results will be ordered by vector distance.
796+ in_order (bool): Requires the terms in the field to have
797+ the same order as the terms in the query filter, regardless of
798+ the offsets between them. Defaults to False.
799+
800+ Raises:
801+ TypeError: If filter_expression is not of type redisvl.query.FilterExpression
802+
803+ Note:
804+ Learn more about vector queries in Redis: https://redis.io/docs/interact/search-and-query/search/vectors/#knn-search
805+ """
806+ self ._text = text
807+ self ._text_field_name = tex_field_name
808+ self ._vector = vector
809+ self ._vector_field_name = vector_field_name
810+ self ._dtype = dtype
811+ self ._num_results = num_results
812+ self .set_filter (filter_expression )
813+ query_string = self ._build_query_string ()
814+
815+ # TODO how to handle multiple parents? call parent.__init__() manually?
816+ super ().__init__ (query_string )
817+
818+ # Handle query modifiers
819+ if return_fields :
820+ self .return_fields (* return_fields )
821+
822+ self .paging (0 , self ._num_results ).dialect (dialect )
823+
824+ if return_score :
825+ self .return_fields (self .DISTANCE_ID )
826+
827+ if sort_by :
828+ self .sort_by (sort_by )
829+ else :
830+ self .sort_by (self .DISTANCE_ID )
831+
832+ if in_order :
833+ self .in_order ()
834+
835+
836+ def _build_query_string (self ) -> str :
837+ """Build the full query string for hybrid search with optional filtering."""
838+ filter_expression = self ._filter_expression
839+ # TODO include hybrid
840+ if isinstance (filter_expression , FilterExpression ):
841+ filter_expression = str (filter_expression )
842+ return f"{ filter_expression } =>[KNN { self ._num_results } @{ self ._vector_field_name } ${ self .VECTOR_PARAM } AS { self .DISTANCE_ID } ]"
843+
844+ @property
845+ def params (self ) -> Dict [str , Any ]:
846+ """Return the parameters for the query.
847+
848+ Returns:
849+ Dict[str, Any]: The parameters for the query.
850+ """
851+ if isinstance (self ._vector , bytes ):
852+ vector = self ._vector
853+ else :
854+ vector = array_to_buffer (self ._vector , dtype = self ._dtype )
855+
856+ return {self .VECTOR_PARAM : vector }
857+
0 commit comments