11from pathlib import Path
2- from typing import Any , Dict , List , Optional , Type
2+ from typing import Any , Dict , List , Optional , Type , Union
33
44import redis .commands .search .reducers as reducers
55import yaml
88from redis .commands .search .aggregation import AggregateRequest , AggregateResult , Reducer
99from redis .exceptions import ResponseError
1010
11+ from redisvl .exceptions import RedisModuleVersionError
1112from redisvl .extensions .constants import ROUTE_VECTOR_FIELD_NAME
1213from redisvl .extensions .router .schema import (
1314 DistanceAggregationMethod ,
1718 SemanticRouterIndexSchema ,
1819)
1920from redisvl .index import SearchIndex
20- from redisvl .query import VectorRangeQuery
21+ from redisvl .query import FilterQuery , VectorRangeQuery
22+ from redisvl .query .filter import Tag
23+ from redisvl .redis .connection import RedisConnectionFactory
2124from redisvl .redis .utils import convert_bytes , make_dict
2225from redisvl .utils .log import get_logger
23- from redisvl .utils .utils import deprecated_argument , hashify , model_to_dict
26+ from redisvl .utils .utils import (
27+ deprecated_argument ,
28+ hashify ,
29+ model_to_dict ,
30+ scan_by_pattern ,
31+ )
2432from redisvl .utils .vectorize .base import BaseVectorizer
2533from redisvl .utils .vectorize .text .huggingface import HFTextVectorizer
2634
@@ -98,9 +106,41 @@ def __init__(
98106 routes = routes ,
99107 vectorizer = vectorizer ,
100108 routing_config = routing_config ,
109+ redis_url = redis_url ,
110+ redis_client = redis_client ,
101111 )
112+
102113 self ._initialize_index (redis_client , redis_url , overwrite , ** connection_kwargs )
103114
115+ self ._index .client .json ().set (f"{ self .name } :route_config" , f"." , self .to_dict ()) # type: ignore
116+
117+ @classmethod
118+ def from_existing (
119+ cls ,
120+ name : str ,
121+ redis_client : Optional [Redis ] = None ,
122+ redis_url : str = "redis://localhost:6379" ,
123+ ** kwargs ,
124+ ) -> "SemanticRouter" :
125+ """Return SemanticRouter instance from existing index."""
126+ try :
127+ if redis_url :
128+ redis_client = RedisConnectionFactory .get_redis_connection (
129+ redis_url = redis_url ,
130+ ** kwargs ,
131+ )
132+ elif redis_client :
133+ RedisConnectionFactory .validate_sync_redis (redis_client )
134+ except RedisModuleVersionError as e :
135+ raise RedisModuleVersionError (
136+ f"Loading from existing index failed. { str (e )} "
137+ )
138+
139+ router_dict = redis_client .json ().get (f"{ name } :route_config" ) # type: ignore
140+ return cls .from_dict (
141+ router_dict , redis_url = redis_url , redis_client = redis_client
142+ )
143+
104144 @deprecated_argument ("dtype" )
105145 def _initialize_index (
106146 self ,
@@ -111,9 +151,11 @@ def _initialize_index(
111151 ** connection_kwargs ,
112152 ):
113153 """Initialize the search index and handle Redis connection."""
154+
114155 schema = SemanticRouterIndexSchema .from_params (
115156 self .name , self .vectorizer .dims , self .vectorizer .dtype # type: ignore
116157 )
158+
117159 self ._index = SearchIndex (
118160 schema = schema ,
119161 redis_client = redis_client ,
@@ -174,10 +216,10 @@ def update_route_thresholds(self, route_thresholds: Dict[str, Optional[float]]):
174216 if route .name in route_thresholds :
175217 route .distance_threshold = route_thresholds [route .name ] # type: ignore
176218
177- def _route_ref_key (self , route_name : str , reference : str ) -> str :
219+ @staticmethod
220+ def _route_ref_key (index : SearchIndex , route_name : str , reference_hash : str ) -> str :
178221 """Generate the route reference key."""
179- reference_hash = hashify (reference )
180- return f"{ self ._index .prefix } :{ route_name } :{ reference_hash } "
222+ return f"{ index .prefix } :{ route_name } :{ reference_hash } "
181223
182224 def _add_routes (self , routes : List [Route ]):
183225 """Add routes to the router and index.
@@ -195,14 +237,18 @@ def _add_routes(self, routes: List[Route]):
195237 )
196238 # set route references
197239 for i , reference in enumerate (route .references ):
240+ reference_hash = hashify (reference )
198241 route_references .append (
199242 {
243+ "reference_id" : reference_hash ,
200244 "route_name" : route .name ,
201245 "reference" : reference ,
202246 "vector" : reference_vectors [i ],
203247 }
204248 )
205- keys .append (self ._route_ref_key (route .name , reference ))
249+ keys .append (
250+ self ._route_ref_key (self ._index , route .name , reference_hash )
251+ )
206252
207253 # set route if does not yet exist client side
208254 if not self .get (route .name ):
@@ -438,7 +484,7 @@ def remove_route(self, route_name: str) -> None:
438484 else :
439485 self ._index .drop_keys (
440486 [
441- self ._route_ref_key (route .name , reference )
487+ self ._route_ref_key (self . _index , route .name , hashify ( reference ) )
442488 for reference in route .references
443489 ]
444490 )
@@ -596,3 +642,155 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None:
596642 with open (fp , "w" ) as f :
597643 yaml_data = self .to_dict ()
598644 yaml .dump (yaml_data , f , sort_keys = False )
645+
646+ # reference methods
647+ def add_route_references (
648+ self ,
649+ route_name : str ,
650+ references : Union [str , List [str ]],
651+ ) -> List [str ]:
652+ """Add a reference(s) to an existing route.
653+
654+ Args:
655+ router_name (str): The name of the router.
656+ references (Union[str, List[str]]): The reference or list of references to add.
657+
658+ Returns:
659+ List[str]: The list of added references keys.
660+ """
661+
662+ if isinstance (references , str ):
663+ references = [references ]
664+
665+ route_references : List [Dict [str , Any ]] = []
666+ keys : List [str ] = []
667+
668+ # embed route references as a single batch
669+ reference_vectors = self .vectorizer .embed_many (references , as_buffer = True )
670+
671+ # set route references
672+ for i , reference in enumerate (references ):
673+ reference_hash = hashify (reference )
674+
675+ route_references .append (
676+ {
677+ "reference_id" : reference_hash ,
678+ "route_name" : route_name ,
679+ "reference" : reference ,
680+ "vector" : reference_vectors [i ],
681+ }
682+ )
683+ keys .append (self ._route_ref_key (self ._index , route_name , reference_hash ))
684+
685+ keys = self ._index .load (route_references , keys = keys )
686+
687+ route = self .get (route_name )
688+ if not route :
689+ raise ValueError (f"Route { route_name } not found in the SemanticRouter" )
690+ route .references .extend (references )
691+ self ._update_router_state ()
692+ return keys
693+
694+ @staticmethod
695+ def _make_filter_queries (ids : List [str ]) -> List [FilterQuery ]:
696+ """Create a filter query for the given ids."""
697+
698+ queries = []
699+
700+ for id in ids :
701+ fe = Tag ("reference_id" ) == id
702+ fq = FilterQuery (
703+ return_fields = ["reference_id" , "route_name" , "reference" ],
704+ filter_expression = fe ,
705+ )
706+ queries .append (fq )
707+
708+ return queries
709+
710+ def get_route_references (
711+ self ,
712+ route_name : str = "" ,
713+ reference_ids : List [str ] = [],
714+ keys : List [str ] = [],
715+ ) -> List [Dict [str , Any ]]:
716+ """Get references for an existing route route.
717+
718+ Args:
719+ router_name (str): The name of the router.
720+ references (Union[str, List[str]]): The reference or list of references to add.
721+
722+ Returns:
723+ List[Dict[str, Any]]]: Reference objects stored
724+ """
725+
726+ if reference_ids :
727+ queries = self ._make_filter_queries (reference_ids )
728+ elif route_name :
729+ if not keys :
730+ keys = scan_by_pattern (
731+ self ._index .client , f"{ self ._index .prefix } :{ route_name } :*" # type: ignore
732+ )
733+
734+ queries = self ._make_filter_queries (
735+ [key .split (":" )[- 1 ] for key in convert_bytes (keys )]
736+ )
737+ else :
738+ raise ValueError (
739+ "Must provide a route name, reference ids, or keys to get references"
740+ )
741+
742+ res = self ._index .batch_query (queries )
743+
744+ return [r [0 ] for r in res if len (r ) > 0 ]
745+
746+ def delete_route_references (
747+ self ,
748+ route_name : str = "" ,
749+ reference_ids : List [str ] = [],
750+ keys : List [str ] = [],
751+ ) -> int :
752+ """Get references for an existing semantic router route.
753+
754+ Args:
755+ router_name Optional(str): The name of the router.
756+ reference_ids Optional(List[str]]): The reference or list of references to delete.
757+ keys Optional(List[str]]): List of fully qualified keys (prefix:router:reference_id) to delete.
758+
759+ Returns:
760+ int: Number of objects deleted
761+ """
762+
763+ if reference_ids and not keys :
764+ queries = self ._make_filter_queries (reference_ids )
765+ res = self ._index .batch_query (queries )
766+ keys = [r [0 ]["id" ] for r in res if len (r ) > 0 ]
767+ elif not keys :
768+ keys = scan_by_pattern (
769+ self ._index .client , f"{ self ._index .prefix } :{ route_name } :*" # type: ignore
770+ )
771+
772+ if not keys :
773+ raise ValueError (f"No references found for route { route_name } " )
774+
775+ to_be_deleted = []
776+ for key in keys :
777+ route_name = key .split (":" )[- 2 ]
778+ to_be_deleted .append (
779+ (route_name , convert_bytes (self ._index .client .hgetall (key ))) # type: ignore
780+ )
781+
782+ deleted = self ._index .drop_keys (keys )
783+
784+ for route_name , delete in to_be_deleted :
785+ route = self .get (route_name )
786+ if not route :
787+ raise ValueError (f"Route { route_name } not found in the SemanticRouter" )
788+ route .references .remove (delete ["reference" ])
789+
790+ self ._update_router_state ()
791+
792+ return deleted
793+
794+ def _update_router_state (self ) -> None :
795+ """Update the router configuration in Redis."""
796+ self ._index .client .json ().set (f"{ self .name } :route_config" , f"." , self .to_dict ()) # type: ignore
0 commit comments