Skip to content

Commit fe4d7ef

Browse files
committed
change type for baseline ranx
1 parent 759ec1b commit fe4d7ef

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

redisvl/utils/optimize/cache.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@
1212

1313
def _generate_run_cache(test_data: List[LabeledData], threshold: float) -> Run:
1414
"""Format observed data for evaluation with ranx"""
15-
run_dict: Dict[str, Dict[str, int]] = {}
15+
run_dict: Dict[str, Dict[str, float]] = {}
1616

1717
for td in test_data:
1818
run_dict[td.id] = {}
1919
for res in td.response:
2020
if float(res["vector_distance"]) < threshold:
2121
# value of 1 is irrelevant checks only on match for f1
22-
run_dict[td.id][res["id"]] = 1
22+
run_dict[td.id][res["id"]] = 1.0
2323

2424
if not run_dict[td.id]:
2525
# ranx is a little odd in that if there are no matches it errors
2626
# if however there are no keys that match you get the correct score
27-
run_dict[td.id][NULL_RESPONSE_KEY] = 1
27+
run_dict[td.id][NULL_RESPONSE_KEY] = 1.0
2828

2929
return Run(run_dict)
3030

redisvl/utils/optimize/router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -
1818
run_dict[td.id] = {}
1919
route_match = router(td.query)
2020
if route_match and route_match.name == td.query_match:
21-
run_dict[td.id][td.query_match] = np.int64(1)
21+
run_dict[td.id][td.query_match] = 1.0
2222
else:
23-
run_dict[td.id][NULL_RESPONSE_KEY] = np.int64(1)
23+
run_dict[td.id][NULL_RESPONSE_KEY] = 1.0
2424

2525
return Run(run_dict)
2626

redisvl/utils/optimize/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import List
22

3-
import numpy as np
43
from ranx import Qrels
54

65
from redisvl.utils.optimize.schema import LabeledData
@@ -14,10 +13,10 @@ def _format_qrels(test_data: List[LabeledData]) -> Qrels:
1413

1514
for td in test_data:
1615
if td.query_match:
17-
qrels_dict[td.id] = {td.query_match: np.int64(1)}
16+
qrels_dict[td.id] = {td.query_match: 1.0}
1817
else:
1918
# This is for capturing true negatives from test set
20-
qrels_dict[td.id] = {NULL_RESPONSE_KEY: np.int64(1)}
19+
qrels_dict[td.id] = {NULL_RESPONSE_KEY: 1.0}
2120

2221
return Qrels(qrels_dict)
2322

schemas/semantic_router.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: test-router-01JSHK4MJ79HH51PS6WEK6M9MF
1+
name: test-router
22
routes:
33
- name: greeting
44
references:

tests/integration/test_semantic_router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def test_from_dict(semantic_router):
200200

201201
def test_to_yaml(semantic_router):
202202
yaml_file = str(get_base_path().joinpath("../../schemas/semantic_router.yaml"))
203+
semantic_router.name = "test-router"
203204
semantic_router.to_yaml(yaml_file, overwrite=True)
204205
assert pathlib.Path(yaml_file).exists()
205206

0 commit comments

Comments
 (0)