Skip to content

Commit 229b3fa

Browse files
committed
add unit tests
1 parent 4102d04 commit 229b3fa

File tree

6 files changed

+92
-10
lines changed

6 files changed

+92
-10
lines changed

redisvl/extensions/threshold_optimizer/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import Enum
33
from typing import Any, Callable, Dict, List, TypeVar
44

5-
from redisvl.extensions.threshold_optimizer.utils import validate_test_dict
5+
from redisvl.extensions.threshold_optimizer.utils import _validate_test_dict
66

77

88
class EvalMetric(Enum):
@@ -47,7 +47,7 @@ def __init__(
4747
eval_fn: Function to evaluate performance
4848
opt_fn: Function to perform optimization
4949
"""
50-
self.test_data = validate_test_dict(test_dict)
50+
self.test_data = _validate_test_dict(test_dict)
5151
self.optimizable = optimizable
5252
self.eval_metric = EvalMetric(eval_metric)
5353
self.opt_fn = opt_fn

redisvl/extensions/threshold_optimizer/cache.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
EvalMetric,
1010
)
1111
from redisvl.extensions.threshold_optimizer.schema import TestData
12-
from redisvl.extensions.threshold_optimizer.utils import NULL_RESPONSE_KEY, format_qrels
12+
from redisvl.extensions.threshold_optimizer.utils import (
13+
NULL_RESPONSE_KEY,
14+
_format_qrels,
15+
)
1316
from redisvl.query import RangeQuery
1417

1518

@@ -63,7 +66,7 @@ def _grid_search_opt_cache(
6366
res = cache.index.query(query)
6467
td.response = res
6568

66-
qrels = format_qrels(test_data)
69+
qrels = _format_qrels(test_data)
6770

6871
for threshold in thresholds:
6972
score = _eval_cache(test_data, threshold, qrels, eval_metric.value)

redisvl/extensions/threshold_optimizer/router.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
EvalMetric,
1111
)
1212
from redisvl.extensions.threshold_optimizer.schema import TestData
13-
from redisvl.extensions.threshold_optimizer.utils import NULL_RESPONSE_KEY, format_qrels
13+
from redisvl.extensions.threshold_optimizer.utils import (
14+
NULL_RESPONSE_KEY,
15+
_format_qrels,
16+
)
1417

1518

1619
def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> Run:
@@ -98,5 +101,5 @@ def __init__(
98101

99102
def optimize(self, **kwargs: Any):
100103
"""Optimize thresholds using the provided optimization function for router case."""
101-
qrels = format_qrels(self.test_data)
104+
qrels = _format_qrels(self.test_data)
102105
self.opt_fn(self.optimizable, self.test_data, qrels, self.eval_metric, **kwargs)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
from pydantic import BaseModel, Field
44
from ulid import ULID
@@ -7,5 +7,5 @@
77
class TestData(BaseModel):
88
q_id: str = Field(default_factory=lambda: str(ULID()))
99
query: str
10-
query_match: str | None
10+
query_match: Optional[str]
1111
response: List[dict] = []

redisvl/extensions/threshold_optimizer/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
NULL_RESPONSE_KEY = "no_match"
88

99

10-
def format_qrels(test_data: List[TestData]) -> Qrels:
10+
def _format_qrels(test_data: List[TestData]) -> Qrels:
11+
"""Utility function for creating qrels for evaluation with ranx"""
1112
qrels_dict = {}
1213

1314
for td in test_data:
@@ -20,5 +21,6 @@ def format_qrels(test_data: List[TestData]) -> Qrels:
2021
return Qrels(qrels_dict)
2122

2223

23-
def validate_test_dict(test_dict: List[dict]) -> List[TestData]:
24+
def _validate_test_dict(test_dict: List[dict]) -> List[TestData]:
25+
"""Convert/validate test_dict for use in optimizer"""
2426
return [TestData(**d) for d in test_dict]
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pytest
2+
from ranx import evaluate
3+
4+
from redisvl.extensions.threshold_optimizer.cache import _generate_run_cache
5+
from redisvl.extensions.threshold_optimizer.schema import TestData
6+
from redisvl.extensions.threshold_optimizer.utils import (
7+
NULL_RESPONSE_KEY,
8+
_format_qrels,
9+
)
10+
11+
# Note: these tests are not intended to test ranx but to test that our data formatting for the package is correct
12+
13+
14+
def test_known_precision_case():
15+
"""
16+
Test case with known precision value.
17+
18+
Setup:
19+
- 2 queries
20+
- Query 1 expects doc1, gets doc1 and doc2 (precision 0.5)
21+
- Query 2 expects doc3, gets doc3 (precision 1.0)
22+
Expected overall precision: 0.75
23+
"""
24+
# Setup test data
25+
test_data = [
26+
TestData(
27+
query="test query 1",
28+
query_match="doc1",
29+
response=[
30+
{"id": "doc1", "vector_distance": 0.2},
31+
{"id": "doc2", "vector_distance": 0.3},
32+
],
33+
),
34+
TestData(
35+
query="test query 2",
36+
query_match="doc3",
37+
response=[
38+
{"id": "doc3", "vector_distance": 0.2},
39+
{"id": "doc4", "vector_distance": 0.8},
40+
],
41+
),
42+
]
43+
44+
# Create qrels (ground truth)
45+
qrels = _format_qrels(test_data)
46+
47+
threshold = 0.4
48+
run = _generate_run_cache(test_data, threshold)
49+
50+
# Calculate precision using ranx
51+
precision = evaluate(qrels, run, "precision")
52+
assert precision == 0.75 # (0.5 + 1.0) / 2
53+
54+
55+
def test_known_precision_with_no_matches():
56+
"""Test case where the only query expects no match and has an empty response."""
57+
test_data = [
58+
TestData(
59+
query="test query 2",
60+
query_match="", # Expecting no match
61+
response=[],
62+
),
63+
]
64+
65+
# Create qrels
66+
qrels = _format_qrels(test_data)
67+
68+
# Generate run; the single query has an empty response, so there are no docs to filter by threshold
69+
threshold = 0.3
70+
run = _generate_run_cache(test_data, threshold)
71+
72+
# Calculate precision
73+
precision = evaluate(qrels, run, "precision")
74+
assert precision == 1.0 # only one query in test_data, so the mean is that query's precision

0 commit comments

Comments
 (0)