Skip to content

Commit 18ff100

Browse files
committed
updates for notebook
1 parent 1a29c1e commit 18ff100

File tree

14 files changed

+90
-62
lines changed

14 files changed

+90
-62
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ jobs:
105105
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
106106
run: |
107107
docker run -d --name redis -p 6379:6379 redis/redis-stack-server:latest
108-
make test-notebooks
108+
# NOTE: inside [[ ]], ">" compares strings lexicographically, so "3.10" > "3.9"
# is false and newer Pythons would always take the else branch. Use a version
# sort instead: the condition holds when the matrix version is >= 3.10.
if [[ "$(printf '%s\n' "3.10" "${{ matrix.python-version }}" | sort -V | head -n1)" == "3.10" ]]; then
109+
make test-notebooks
110+
else
111+
poetry run test-notebooks --ignore ./docs/user_guide/09_threshold_optimization.ipynb
112+
fi
109113
110114
docs:
111115
runs-on: ubuntu-latest

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ embeddings = co.embed_many(
225225
### Threshold Optimization
226226
[Optimize distance thresholds for cache and router](https://docs.redisvl.com/en/stable/user_guide/09_threshold_optimization.html) with the utility `ThresholdOptimizer` classes.
227227

228+
**Note:** Threshold optimization is only available for `python > 3.9`.
229+
228230

229231

230232
## 💫 Extensions

docs/user_guide/09_threshold_optimization.ipynb

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
"source": [
77
"# Threshold Optimization\n",
88
"\n",
9-
"After setting up `SemanticRouter` or `SemanticCache` it best to tune the `distance_threshold` to get the best performance out of your system. RedisVL provides helper classes to make this light weight optimization easy.\n",
9+
"After setting up `SemanticRouter` or `SemanticCache`, it's best to tune the `distance_threshold` to get the most performance out of your system. RedisVL provides helper classes to make this lightweight optimization easy.\n",
10+
"\n",
11+
"> **Note:** Threshold optimization requires `python > 3.9`.\n",
1012
"\n",
1113
"# CacheThresholdOptimizer\n",
1214
"\n",
@@ -53,8 +55,8 @@
5355
" 'prompt': 'what is the capital of france?',\n",
5456
" 'response': 'paris',\n",
5557
" 'vector_distance': 0.421104669571,\n",
56-
" 'inserted_at': 1741033054.9,\n",
57-
" 'updated_at': 1741033054.9,\n",
58+
" 'inserted_at': 1741039231.99,\n",
59+
" 'updated_at': 1741039231.99,\n",
5860
" 'key': 'sem_cache:c990cc06e5e77570e5f03360426d2b7f947cbb5a67daa8af8164bfe0b3e24fe3'}]"
5961
]
6062
},
@@ -112,7 +114,7 @@
112114
}
113115
],
114116
"source": [
115-
"from redisvl.utils.threshold_optimizer.cache import CacheThresholdOptimizer\n",
117+
"from redisvl.utils.optimize import CacheThresholdOptimizer\n",
116118
"\n",
117119
"test_data = [\n",
118120
" {\n",
@@ -181,8 +183,8 @@
181183
" 'prompt': 'what is the capital of france?',\n",
182184
" 'response': 'paris',\n",
183185
" 'vector_distance': 0.0835866332054,\n",
184-
" 'inserted_at': 1741033054.9,\n",
185-
" 'updated_at': 1741033054.9,\n",
186+
" 'inserted_at': 1741039231.99,\n",
187+
" 'updated_at': 1741039231.99,\n",
186188
" 'key': 'sem_cache:c990cc06e5e77570e5f03360426d2b7f947cbb5a67daa8af8164bfe0b3e24fe3'}]"
187189
]
188190
},
@@ -331,13 +333,13 @@
331333
"text": [
332334
"Route thresholds before: {'greeting': 0.5, 'farewell': 0.5} \n",
333335
"\n",
334-
"Eval metric F1: start 0.438, end 0.562 \n",
335-
"Ending thresholds: {'greeting': 0.09239303792843903, 'farewell': 0.6535353535353534}\n"
336+
"Eval metric F1: start 0.438, end 0.719 \n",
337+
"Ending thresholds: {'greeting': 1.0858585858585856, 'farewell': 0.5545454545454545}\n"
336338
]
337339
}
338340
],
339341
"source": [
340-
"from redisvl.utils.threshold_optimizer.router import RouterThresholdOptimizer\n",
342+
"from redisvl.utils.optimize import RouterThresholdOptimizer\n",
341343
"\n",
342344
"print(f\"Route thresholds before: {router.route_thresholds} \\n\")\n",
343345
"optimizer = RouterThresholdOptimizer(router, test_data)\n",

docs/user_guide/router.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ routes:
77
- what's trending in tech?
88
metadata:
99
category: tech
10-
priority: '1'
10+
priority: 1
1111
distance_threshold: 1.0
1212
- name: sports
1313
references:
@@ -18,7 +18,7 @@ routes:
1818
- basketball and football
1919
metadata:
2020
category: sports
21-
priority: '2'
21+
priority: 2
2222
distance_threshold: 0.5
2323
- name: entertainment
2424
references:
@@ -27,12 +27,11 @@ routes:
2727
- what's new in the entertainment industry?
2828
metadata:
2929
category: entertainment
30-
priority: '3'
30+
priority: 3
3131
distance_threshold: 0.7
3232
vectorizer:
3333
type: hf
3434
model: sentence-transformers/all-mpnet-base-v2
3535
routing_config:
36-
distance_threshold: 0.5
3736
max_k: 3
3837
aggregation_method: min

redisvl/utils/optimize/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
2+
from redisvl.utils.optimize.cache import CacheThresholdOptimizer
3+
from redisvl.utils.optimize.router import RouterThresholdOptimizer
4+
from redisvl.utils.optimize.schema import TestData
5+
6+
__all__ = [
7+
"CacheThresholdOptimizer",
8+
"RouterThresholdOptimizer",
9+
"EvalMetric",
10+
"BaseThresholdOptimizer",
11+
"TestData",
12+
]
Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,16 @@
22
from enum import Enum
33
from typing import Any, Callable, Dict, List, TypeVar
44

5-
from redisvl.utils.threshold_optimizer.utils import _validate_test_dict
5+
from redisvl.utils.optimize.utils import _validate_test_dict
66

77

8-
class EvalMetric(Enum):
8+
class EvalMetric(str, Enum):
99
"""Evaluation metrics for threshold optimization."""
1010

1111
F1 = "f1"
1212
PRECISION = "precision"
1313
RECALL = "recall"
1414

15-
def __str__(self) -> str:
16-
return self.value
17-
18-
@classmethod
19-
def from_string(cls, metric: str) -> "EvalMetric":
20-
"""Convert string to EvalMetric enum."""
21-
try:
22-
return cls(metric.lower())
23-
except ValueError:
24-
raise ValueError(
25-
f"Invalid metric: {metric}. Valid options are: {', '.join(m.value for m in cls)}"
26-
)
27-
2815

2916
T = TypeVar("T") # Type variable for the optimizable object (Cache or Router)
3017

redisvl/utils/threshold_optimizer/cache.py renamed to redisvl/utils/optimize/cache.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,26 @@
55

66
from redisvl.extensions.llmcache.semantic import SemanticCache
77
from redisvl.query import RangeQuery
8-
from redisvl.utils.threshold_optimizer.base import BaseThresholdOptimizer, EvalMetric
9-
from redisvl.utils.threshold_optimizer.schema import TestData
10-
from redisvl.utils.threshold_optimizer.utils import NULL_RESPONSE_KEY, _format_qrels
8+
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
9+
from redisvl.utils.optimize.schema import TestData
10+
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
1111

1212

1313
def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run:
1414
"""Format observed data for evaluation with ranx"""
1515
run_dict: Dict[str, Dict[str, int]] = {}
1616

1717
for td in test_data:
18-
run_dict[td.q_id] = {}
18+
run_dict[td.id] = {}
1919
for res in td.response:
2020
if float(res["vector_distance"]) < threshold:
2121
# value of 1 is irrelevant checks only on match for f1
22-
run_dict[td.q_id][res["id"]] = 1
22+
run_dict[td.id][res["id"]] = 1
2323

24-
if not run_dict[td.q_id]:
24+
if not run_dict[td.id]:
2525
# ranx is a little odd in that if there are no matches it errors
2626
# if however there are no keys that match you get the correct score
27-
run_dict[td.q_id][NULL_RESPONSE_KEY] = 1
27+
run_dict[td.id][NULL_RESPONSE_KEY] = 1
2828

2929
return Run(run_dict)
3030

redisvl/utils/threshold_optimizer/router.py renamed to redisvl/utils/optimize/router.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@
55
from ranx import Qrels, Run, evaluate
66

77
from redisvl.extensions.router.semantic import SemanticRouter
8-
from redisvl.utils.threshold_optimizer.base import BaseThresholdOptimizer, EvalMetric
9-
from redisvl.utils.threshold_optimizer.schema import TestData
10-
from redisvl.utils.threshold_optimizer.utils import NULL_RESPONSE_KEY, _format_qrels
8+
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
9+
from redisvl.utils.optimize.schema import TestData
10+
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
1111

1212

1313
def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> Run:
1414
"""Format router results into format for ranx Run"""
1515
run_dict: Dict[Any, Any] = {}
1616

1717
for td in test_data:
18-
run_dict[td.q_id] = {}
18+
run_dict[td.id] = {}
1919
route_match = router(td.query)
2020
if route_match and route_match.name == td.query_match:
21-
run_dict[td.q_id][td.query_match] = 1
21+
run_dict[td.id][td.query_match] = 1
2222
else:
23-
run_dict[td.q_id][NULL_RESPONSE_KEY] = 1
23+
run_dict[td.id][NULL_RESPONSE_KEY] = 1
2424

2525
return Run(run_dict)
2626

redisvl/utils/threshold_optimizer/schema.py renamed to redisvl/utils/optimize/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class TestData(BaseModel):
8-
q_id: str = Field(default_factory=lambda: str(ULID()))
8+
id: str = Field(default_factory=lambda: str(ULID()))
99
query: str
1010
query_match: Optional[str]
1111
response: List[dict] = []

redisvl/utils/threshold_optimizer/utils.py renamed to redisvl/utils/optimize/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ranx import Qrels
44

5-
from redisvl.utils.threshold_optimizer.schema import TestData
5+
from redisvl.utils.optimize.schema import TestData
66

77
NULL_RESPONSE_KEY = "no_match"
88

@@ -13,10 +13,10 @@ def _format_qrels(test_data: List[TestData]) -> Qrels:
1313

1414
for td in test_data:
1515
if td.query_match:
16-
qrels_dict[td.q_id] = {td.query_match: 1}
16+
qrels_dict[td.id] = {td.query_match: 1}
1717
else:
1818
# This is for capturing true negatives from test set
19-
qrels_dict[td.q_id] = {NULL_RESPONSE_KEY: 1}
19+
qrels_dict[td.id] = {NULL_RESPONSE_KEY: 1}
2020

2121
return Qrels(qrels_dict)
2222

0 commit comments

Comments
 (0)