Skip to content

Commit d0c86ed

Browse files
committed
update doc strings to work
1 parent 14b6749 commit d0c86ed

File tree

2 files changed

+82
-85
lines changed

2 files changed

+82
-85
lines changed

redisvl/utils/optimize/cache.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,41 @@ def _grid_search_opt_cache(
7171

7272

7373
class CacheThresholdOptimizer(BaseThresholdOptimizer):
74-
"""Class for optimizing thresholds for a SemanticCache."""
74+
"""
75+
Class for optimizing thresholds for a SemanticCache.
76+
77+
.. code-block:: python
78+
79+
from redisvl.extensions.llmcache import SemanticCache
80+
from redisvl.utils.optimize import CacheThresholdOptimizer
81+
82+
sem_cache = SemanticCache(
83+
name="sem_cache", # underlying search index name
84+
redis_url="redis://localhost:6379", # redis connection url string
85+
distance_threshold=0.5 # semantic cache distance threshold
86+
)
87+
88+
paris_key = sem_cache.store(prompt="what is the capital of france?", response="paris")
89+
rabat_key = sem_cache.store(prompt="what is the capital of morocco?", response="rabat")
90+
91+
test_data = [
92+
{
93+
"query": "What's the capital of Britain?",
94+
"query_match": ""
95+
},
96+
{
97+
"query": "What's the capital of France??",
98+
"query_match": paris_key
99+
},
100+
{
101+
"query": "What's the capital city of Morocco?",
102+
"query_match": rabat_key
103+
},
104+
]
105+
106+
optimizer = CacheThresholdOptimizer(sem_cache, test_data)
107+
optimizer.optimize()
108+
"""
75109

76110
def __init__(
77111
self,
@@ -90,46 +124,11 @@ def __init__(
90124
eval_metric (str): Evaluation metric for threshold optimization.
91125
Defaults to "f1" score.
92126
93-
.. code-block:: python
94-
from redisvl.extensions.llmcache import SemanticCache
95-
from redisvl.utils.optimize import CacheThresholdOptimizer
96-
97-
sem_cache = SemanticCache(
98-
name="sem_cache", # underlying search index name
99-
redis_url="redis://localhost:6379", # redis connection url string
100-
distance_threshold=0.5 # semantic cache distance threshold
101-
)
102-
103-
paris_key = sem_cache.store(prompt="what is the capital of france?", response="paris")
104-
rabat_key = sem_cache.store(prompt="what is the capital of morocco?", response="rabat")
105-
106-
test_data = [
107-
{
108-
"query": "What's the capital of Britain?",
109-
"query_match": ""
110-
},
111-
{
112-
"query": "What's the capital of France??",
113-
"query_match": paris_key
114-
},
115-
{
116-
"query": "What's the capital city of Morocco?",
117-
"query_match": rabat_key
118-
},
119-
]
120-
121-
optimizer = CacheThresholdOptimizer(sem_cache, test_data)
122-
optimizer.optimize()
127+
Raises:
128+
ValueError: If the test_dict not in LabeledData format.
123129
"""
124130
super().__init__(cache, test_dict, opt_fn, eval_metric)
125131

126132
def optimize(self, **kwargs: Any):
127-
"""Optimize thresholds using the provided optimization function for cache case.
128-
129-
.. code-block:: python
130-
from redisvl.utils.optimize import CacheThresholdOptimizer
131-
132-
optimizer = CacheThresholdOptimizer(semantic_cache, test_data)
133-
optimizer.optimize(*kwargs)
134-
"""
133+
"""Optimize thresholds using the provided optimization function for cache case."""
135134
self.opt_fn(self.optimizable, self.test_data, self.eval_metric, **kwargs)

redisvl/utils/optimize/router.py

Lines changed: 44 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,47 @@ def _random_search_opt_router(
9090

9191

9292
class RouterThresholdOptimizer(BaseThresholdOptimizer):
93-
"""Class for optimizing thresholds for a SemanticRouter."""
93+
"""
94+
Class for optimizing thresholds for a SemanticRouter.
95+
96+
.. code-block:: python
97+
98+
from redisvl.extensions.router import Route, SemanticRouter
99+
from redisvl.utils.vectorize import HFTextVectorizer
100+
from redisvl.utils.optimize import RouterThresholdOptimizer
101+
102+
routes = [
103+
Route(
104+
name="greeting",
105+
references=["hello", "hi"],
106+
metadata={"type": "greeting"},
107+
distance_threshold=0.5,
108+
),
109+
Route(
110+
name="farewell",
111+
references=["bye", "goodbye"],
112+
metadata={"type": "farewell"},
113+
distance_threshold=0.5,
114+
),
115+
]
116+
117+
router = SemanticRouter(
118+
name="greeting-router",
119+
vectorizer=HFTextVectorizer(),
120+
routes=routes,
121+
redis_url="redis://localhost:6379",
122+
overwrite=True # Blow away any other routing index with this name
123+
)
124+
125+
test_data = [
126+
{"query": "hello", "query_match": "greeting"},
127+
{"query": "goodbye", "query_match": "farewell"},
128+
...
129+
]
130+
131+
optimizer = RouterThresholdOptimizer(router, test_data)
132+
optimizer.optimize()
133+
"""
94134

95135
def __init__(
96136
self,
@@ -108,54 +148,12 @@ def __init__(
108148
grid search.
109149
eval_metric (str): Evaluation metric for threshold optimization.
110150
Defaults to "f1" score.
111-
112-
.. code-block:: python
113-
from redisvl.extensions.router import Route, SemanticRouter
114-
from redisvl.utils.vectorize import HFTextVectorizer
115-
from redisvl.utils.optimize import RouterThresholdOptimizer
116-
117-
routes = [
118-
Route(
119-
name="greeting",
120-
references=["hello", "hi"],
121-
metadata={"type": "greeting"},
122-
distance_threshold=0.5,
123-
),
124-
Route(
125-
name="farewell",
126-
references=["bye", "goodbye"],
127-
metadata={"type": "farewell"},
128-
distance_threshold=0.5,
129-
),
130-
]
131-
132-
router = SemanticRouter(
133-
name="greeting-router",
134-
vectorizer=HFTextVectorizer(),
135-
routes=routes,
136-
redis_url="redis://localhost:6379",
137-
overwrite=True # Blow away any other routing index with this name
138-
)
139-
140-
test_data = [
141-
{"query": "hello", "query_match": "greeting"},
142-
{"query": "goodbye", "query_match": "farewell"},
143-
...
144-
]
145-
146-
optimizer = RouterThresholdOptimizer(router, test_data)
147-
optimizer.optimize()
151+
Raises:
152+
ValueError: If the test_dict not in LabeledData format.
148153
"""
149154
super().__init__(router, test_dict, opt_fn, eval_metric)
150155

151156
def optimize(self, **kwargs: Any):
152-
"""Optimize thresholds using the provided optimization function for router case.
153-
154-
.. code-block:: python
155-
from redisvl.utils.optimize import RouterThresholdOptimizer
156-
157-
optimizer = RouterThresholdOptimizer(router, test_data)
158-
optimizer.optimize(search_step=0.05, max_iterations=50)
159-
"""
157+
"""Optimize kicks of the optimization process for router"""
160158
qrels = _format_qrels(self.test_data)
161159
self.opt_fn(self.optimizable, self.test_data, qrels, self.eval_metric, **kwargs)

0 commit comments

Comments
 (0)