@@ -90,7 +90,47 @@ def _random_search_opt_router(
9090
9191
9292class RouterThresholdOptimizer (BaseThresholdOptimizer ):
93- """Class for optimizing thresholds for a SemanticRouter."""
93+ """
94+ Class for optimizing thresholds for a SemanticRouter.
95+
96+ .. code-block:: python
97+
98+ from redisvl.extensions.router import Route, SemanticRouter
99+ from redisvl.utils.vectorize import HFTextVectorizer
100+ from redisvl.utils.optimize import RouterThresholdOptimizer
101+
102+ routes = [
103+ Route(
104+ name="greeting",
105+ references=["hello", "hi"],
106+ metadata={"type": "greeting"},
107+ distance_threshold=0.5,
108+ ),
109+ Route(
110+ name="farewell",
111+ references=["bye", "goodbye"],
112+ metadata={"type": "farewell"},
113+ distance_threshold=0.5,
114+ ),
115+ ]
116+
117+ router = SemanticRouter(
118+ name="greeting-router",
119+ vectorizer=HFTextVectorizer(),
120+ routes=routes,
121+ redis_url="redis://localhost:6379",
122+ overwrite=True # Blow away any other routing index with this name
123+ )
124+
125+ test_data = [
126+ {"query": "hello", "query_match": "greeting"},
127+ {"query": "goodbye", "query_match": "farewell"},
128+ ...
129+ ]
130+
131+ optimizer = RouterThresholdOptimizer(router, test_data)
132+ optimizer.optimize()
133+ """
94134
95135 def __init__ (
96136 self ,
@@ -108,54 +148,12 @@ def __init__(
108148 grid search.
109149 eval_metric (str): Evaluation metric for threshold optimization.
110150 Defaults to "f1" score.
111-
112- .. code-block:: python
113- from redisvl.extensions.router import Route, SemanticRouter
114- from redisvl.utils.vectorize import HFTextVectorizer
115- from redisvl.utils.optimize import RouterThresholdOptimizer
116-
117- routes = [
118- Route(
119- name="greeting",
120- references=["hello", "hi"],
121- metadata={"type": "greeting"},
122- distance_threshold=0.5,
123- ),
124- Route(
125- name="farewell",
126- references=["bye", "goodbye"],
127- metadata={"type": "farewell"},
128- distance_threshold=0.5,
129- ),
130- ]
131-
132- router = SemanticRouter(
133- name="greeting-router",
134- vectorizer=HFTextVectorizer(),
135- routes=routes,
136- redis_url="redis://localhost:6379",
137- overwrite=True # Blow away any other routing index with this name
138- )
139-
140- test_data = [
141- {"query": "hello", "query_match": "greeting"},
142- {"query": "goodbye", "query_match": "farewell"},
143- ...
144- ]
145-
146- optimizer = RouterThresholdOptimizer(router, test_data)
147- optimizer.optimize()
151+ Raises:
152+ ValueError: If the test_dict not in LabeledData format.
148153 """
149154 super ().__init__ (router , test_dict , opt_fn , eval_metric )
150155
151156 def optimize (self , ** kwargs : Any ):
152- """Optimize thresholds using the provided optimization function for router case.
153-
154- .. code-block:: python
155- from redisvl.utils.optimize import RouterThresholdOptimizer
156-
157- optimizer = RouterThresholdOptimizer(router, test_data)
158- optimizer.optimize(search_step=0.05, max_iterations=50)
159- """
157+ """Optimize kicks of the optimization process for router"""
160158 qrels = _format_qrels (self .test_data )
161159 self .opt_fn (self .optimizable , self .test_data , qrels , self .eval_metric , ** kwargs )
0 commit comments