Skip to content

Commit 26b960d

Browse files
committed
Merge branch '0.6.0' into feat/RAAE-769-add-langcache-wrapper
2 parents e545db0 + 184f521 commit 26b960d

File tree

11 files changed

+560
-117
lines changed

11 files changed

+560
-117
lines changed

docs/user_guide/08_semantic_router.ipynb

Lines changed: 214 additions & 92 deletions
Large diffs are not rendered by default.

docs/user_guide/router.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ routes:
88
metadata:
99
category: tech
1010
priority: 1
11-
distance_threshold: 1.0
11+
distance_threshold: 0.71
1212
- name: sports
1313
references:
1414
- who won the game last night?
@@ -19,7 +19,7 @@ routes:
1919
metadata:
2020
category: sports
2121
priority: 2
22-
distance_threshold: 0.5
22+
distance_threshold: 0.72
2323
- name: entertainment
2424
references:
2525
- what are the top movies right now?

redisvl/extensions/router/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def from_params(cls, name: str, vector_dims: int, dtype: str):
100100
return cls(
101101
index={"name": name, "prefix": name}, # type: ignore
102102
fields=[ # type: ignore
103+
{"name": "reference_id", "type": "tag"},
103104
{"name": "route_name", "type": "tag"},
104105
{"name": "reference", "type": "text"},
105106
{

redisvl/extensions/router/semantic.py

Lines changed: 206 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Dict, List, Optional, Type
2+
from typing import Any, Dict, List, Optional, Type, Union
33

44
import redis.commands.search.reducers as reducers
55
import yaml
@@ -8,6 +8,7 @@
88
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
99
from redis.exceptions import ResponseError
1010

11+
from redisvl.exceptions import RedisModuleVersionError
1112
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
1213
from redisvl.extensions.router.schema import (
1314
DistanceAggregationMethod,
@@ -17,10 +18,17 @@
1718
SemanticRouterIndexSchema,
1819
)
1920
from redisvl.index import SearchIndex
20-
from redisvl.query import VectorRangeQuery
21+
from redisvl.query import FilterQuery, VectorRangeQuery
22+
from redisvl.query.filter import Tag
23+
from redisvl.redis.connection import RedisConnectionFactory
2124
from redisvl.redis.utils import convert_bytes, make_dict
2225
from redisvl.utils.log import get_logger
23-
from redisvl.utils.utils import deprecated_argument, hashify, model_to_dict
26+
from redisvl.utils.utils import (
27+
deprecated_argument,
28+
hashify,
29+
model_to_dict,
30+
scan_by_pattern,
31+
)
2432
from redisvl.utils.vectorize.base import BaseVectorizer
2533
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
2634

@@ -98,9 +106,41 @@ def __init__(
98106
routes=routes,
99107
vectorizer=vectorizer,
100108
routing_config=routing_config,
109+
redis_url=redis_url,
110+
redis_client=redis_client,
101111
)
112+
102113
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)
103114

115+
self._index.client.json().set(f"{self.name}:route_config", f".", self.to_dict()) # type: ignore
116+
117+
@classmethod
118+
def from_existing(
119+
cls,
120+
name: str,
121+
redis_client: Optional[Redis] = None,
122+
redis_url: str = "redis://localhost:6379",
123+
**kwargs,
124+
) -> "SemanticRouter":
125+
"""Return SemanticRouter instance from existing index."""
126+
try:
127+
if redis_url:
128+
redis_client = RedisConnectionFactory.get_redis_connection(
129+
redis_url=redis_url,
130+
**kwargs,
131+
)
132+
elif redis_client:
133+
RedisConnectionFactory.validate_sync_redis(redis_client)
134+
except RedisModuleVersionError as e:
135+
raise RedisModuleVersionError(
136+
f"Loading from existing index failed. {str(e)}"
137+
)
138+
139+
router_dict = redis_client.json().get(f"{name}:route_config") # type: ignore
140+
return cls.from_dict(
141+
router_dict, redis_url=redis_url, redis_client=redis_client
142+
)
143+
104144
@deprecated_argument("dtype")
105145
def _initialize_index(
106146
self,
@@ -111,9 +151,11 @@ def _initialize_index(
111151
**connection_kwargs,
112152
):
113153
"""Initialize the search index and handle Redis connection."""
154+
114155
schema = SemanticRouterIndexSchema.from_params(
115156
self.name, self.vectorizer.dims, self.vectorizer.dtype # type: ignore
116157
)
158+
117159
self._index = SearchIndex(
118160
schema=schema,
119161
redis_client=redis_client,
@@ -174,10 +216,10 @@ def update_route_thresholds(self, route_thresholds: Dict[str, Optional[float]]):
174216
if route.name in route_thresholds:
175217
route.distance_threshold = route_thresholds[route.name] # type: ignore
176218

177-
def _route_ref_key(self, route_name: str, reference: str) -> str:
219+
@staticmethod
220+
def _route_ref_key(index: SearchIndex, route_name: str, reference_hash: str) -> str:
178221
"""Generate the route reference key."""
179-
reference_hash = hashify(reference)
180-
return f"{self._index.prefix}:{route_name}:{reference_hash}"
222+
return f"{index.prefix}:{route_name}:{reference_hash}"
181223

182224
def _add_routes(self, routes: List[Route]):
183225
"""Add routes to the router and index.
@@ -195,14 +237,18 @@ def _add_routes(self, routes: List[Route]):
195237
)
196238
# set route references
197239
for i, reference in enumerate(route.references):
240+
reference_hash = hashify(reference)
198241
route_references.append(
199242
{
243+
"reference_id": reference_hash,
200244
"route_name": route.name,
201245
"reference": reference,
202246
"vector": reference_vectors[i],
203247
}
204248
)
205-
keys.append(self._route_ref_key(route.name, reference))
249+
keys.append(
250+
self._route_ref_key(self._index, route.name, reference_hash)
251+
)
206252

207253
# set route if does not yet exist client side
208254
if not self.get(route.name):
@@ -438,7 +484,7 @@ def remove_route(self, route_name: str) -> None:
438484
else:
439485
self._index.drop_keys(
440486
[
441-
self._route_ref_key(route.name, reference)
487+
self._route_ref_key(self._index, route.name, hashify(reference))
442488
for reference in route.references
443489
]
444490
)
@@ -596,3 +642,155 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None:
596642
with open(fp, "w") as f:
597643
yaml_data = self.to_dict()
598644
yaml.dump(yaml_data, f, sort_keys=False)
645+
646+
# reference methods
647+
def add_route_references(
648+
self,
649+
route_name: str,
650+
references: Union[str, List[str]],
651+
) -> List[str]:
652+
"""Add a reference(s) to an existing route.
653+
654+
Args:
655+
router_name (str): The name of the router.
656+
references (Union[str, List[str]]): The reference or list of references to add.
657+
658+
Returns:
659+
List[str]: The list of added references keys.
660+
"""
661+
662+
if isinstance(references, str):
663+
references = [references]
664+
665+
route_references: List[Dict[str, Any]] = []
666+
keys: List[str] = []
667+
668+
# embed route references as a single batch
669+
reference_vectors = self.vectorizer.embed_many(references, as_buffer=True)
670+
671+
# set route references
672+
for i, reference in enumerate(references):
673+
reference_hash = hashify(reference)
674+
675+
route_references.append(
676+
{
677+
"reference_id": reference_hash,
678+
"route_name": route_name,
679+
"reference": reference,
680+
"vector": reference_vectors[i],
681+
}
682+
)
683+
keys.append(self._route_ref_key(self._index, route_name, reference_hash))
684+
685+
keys = self._index.load(route_references, keys=keys)
686+
687+
route = self.get(route_name)
688+
if not route:
689+
raise ValueError(f"Route {route_name} not found in the SemanticRouter")
690+
route.references.extend(references)
691+
self._update_router_state()
692+
return keys
693+
694+
@staticmethod
695+
def _make_filter_queries(ids: List[str]) -> List[FilterQuery]:
696+
"""Create a filter query for the given ids."""
697+
698+
queries = []
699+
700+
for id in ids:
701+
fe = Tag("reference_id") == id
702+
fq = FilterQuery(
703+
return_fields=["reference_id", "route_name", "reference"],
704+
filter_expression=fe,
705+
)
706+
queries.append(fq)
707+
708+
return queries
709+
710+
def get_route_references(
711+
self,
712+
route_name: str = "",
713+
reference_ids: List[str] = [],
714+
keys: List[str] = [],
715+
) -> List[Dict[str, Any]]:
716+
"""Get references for an existing route route.
717+
718+
Args:
719+
router_name (str): The name of the router.
720+
references (Union[str, List[str]]): The reference or list of references to add.
721+
722+
Returns:
723+
List[Dict[str, Any]]]: Reference objects stored
724+
"""
725+
726+
if reference_ids:
727+
queries = self._make_filter_queries(reference_ids)
728+
elif route_name:
729+
if not keys:
730+
keys = scan_by_pattern(
731+
self._index.client, f"{self._index.prefix}:{route_name}:*" # type: ignore
732+
)
733+
734+
queries = self._make_filter_queries(
735+
[key.split(":")[-1] for key in convert_bytes(keys)]
736+
)
737+
else:
738+
raise ValueError(
739+
"Must provide a route name, reference ids, or keys to get references"
740+
)
741+
742+
res = self._index.batch_query(queries)
743+
744+
return [r[0] for r in res if len(r) > 0]
745+
746+
def delete_route_references(
747+
self,
748+
route_name: str = "",
749+
reference_ids: List[str] = [],
750+
keys: List[str] = [],
751+
) -> int:
752+
"""Get references for an existing semantic router route.
753+
754+
Args:
755+
router_name Optional(str): The name of the router.
756+
reference_ids Optional(List[str]]): The reference or list of references to delete.
757+
keys Optional(List[str]]): List of fully qualified keys (prefix:router:reference_id) to delete.
758+
759+
Returns:
760+
int: Number of objects deleted
761+
"""
762+
763+
if reference_ids and not keys:
764+
queries = self._make_filter_queries(reference_ids)
765+
res = self._index.batch_query(queries)
766+
keys = [r[0]["id"] for r in res if len(r) > 0]
767+
elif not keys:
768+
keys = scan_by_pattern(
769+
self._index.client, f"{self._index.prefix}:{route_name}:*" # type: ignore
770+
)
771+
772+
if not keys:
773+
raise ValueError(f"No references found for route {route_name}")
774+
775+
to_be_deleted = []
776+
for key in keys:
777+
route_name = key.split(":")[-2]
778+
to_be_deleted.append(
779+
(route_name, convert_bytes(self._index.client.hgetall(key))) # type: ignore
780+
)
781+
782+
deleted = self._index.drop_keys(keys)
783+
784+
for route_name, delete in to_be_deleted:
785+
route = self.get(route_name)
786+
if not route:
787+
raise ValueError(f"Route {route_name} not found in the SemanticRouter")
788+
route.references.remove(delete["reference"])
789+
790+
self._update_router_state()
791+
792+
return deleted
793+
794+
def _update_router_state(self) -> None:
795+
"""Update the router configuration in Redis."""
796+
self._index.client.json().set(f"{self.name}:route_config", f".", self.to_dict()) # type: ignore

redisvl/index/index.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Iterable,
1515
List,
1616
Optional,
17+
Sequence,
1718
Tuple,
1819
Union,
1920
)
@@ -833,7 +834,7 @@ def search(self, *args, **kwargs) -> "Result":
833834
raise RedisSearchError(f"Error while searching: {str(e)}") from e
834835

835836
def batch_query(
836-
self, queries: List[BaseQuery], batch_size: int = 10
837+
self, queries: Sequence[BaseQuery], batch_size: int = 10
837838
) -> List[List[Dict[str, Any]]]:
838839
"""Execute a batch of queries and process results."""
839840
results = self.batch_search(

redisvl/utils/optimize/router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -
1818
run_dict[td.id] = {}
1919
route_match = router(td.query)
2020
if route_match and route_match.name == td.query_match:
21-
run_dict[td.id][td.query_match] = 1
21+
run_dict[td.id][td.query_match] = np.int64(1)
2222
else:
23-
run_dict[td.id][NULL_RESPONSE_KEY] = 1
23+
run_dict[td.id][NULL_RESPONSE_KEY] = np.int64(1)
2424

2525
return Run(run_dict)
2626

redisvl/utils/optimize/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List
22

3+
import numpy as np
34
from ranx import Qrels
45

56
from redisvl.utils.optimize.schema import LabeledData
@@ -13,10 +14,10 @@ def _format_qrels(test_data: List[LabeledData]) -> Qrels:
1314

1415
for td in test_data:
1516
if td.query_match:
16-
qrels_dict[td.id] = {td.query_match: 1}
17+
qrels_dict[td.id] = {td.query_match: np.int64(1)}
1718
else:
1819
# This is for capturing true negatives from test set
19-
qrels_dict[td.id] = {NULL_RESPONSE_KEY: 1}
20+
qrels_dict[td.id] = {NULL_RESPONSE_KEY: np.int64(1)}
2021

2122
return Qrels(qrels_dict)
2223

0 commit comments

Comments
 (0)