Skip to content

Commit 15cfeba

Browse files
authored
Merge branch '0.5.0' into feat/RAAE-599/distance-normalization
2 parents 9c88680 + 0b3a5ce commit 15cfeba

File tree

13 files changed

+791
-31
lines changed

13 files changed

+791
-31
lines changed

redisvl/query/query.py

Lines changed: 270 additions & 5 deletions
Large diffs are not rendered by default.

redisvl/utils/log.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import coloredlogs
55

6-
# constants for logging
76
coloredlogs.DEFAULT_DATE_FORMAT = "%H:%M:%S"
87
coloredlogs.DEFAULT_LOG_FORMAT = "%(asctime)s %(name)s %(levelname)s %(message)s"
98

@@ -15,5 +14,16 @@ def get_logger(name, log_level="info", fmt=None):
1514
name = "RedisVL" if log_level == "debug" else name
1615

1716
logger = logging.getLogger(name)
18-
coloredlogs.install(level=log_level, logger=logger, fmt=fmt, stream=sys.stdout)
17+
18+
# Only configure this specific logger, not the root logger
19+
# Check if the logger already has handlers to respect existing configuration
20+
if not logger.handlers:
21+
coloredlogs.install(
22+
level=log_level,
23+
logger=logger, # Pass the specific logger
24+
fmt=fmt,
25+
stream=sys.stdout,
26+
isatty=True, # Force colored output (overrides terminal color auto-detection)
27+
reconfigure=False, # Don't reconfigure existing loggers
28+
)
1929
return logger

redisvl/utils/optimize/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
22
from redisvl.utils.optimize.cache import CacheThresholdOptimizer
33
from redisvl.utils.optimize.router import RouterThresholdOptimizer
4-
from redisvl.utils.optimize.schema import TestData
4+
from redisvl.utils.optimize.schema import LabeledData
55

66
__all__ = [
77
"CacheThresholdOptimizer",
88
"RouterThresholdOptimizer",
99
"EvalMetric",
1010
"BaseThresholdOptimizer",
11-
"TestData",
11+
"LabeledData",
1212
]

redisvl/utils/optimize/cache.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from redisvl.extensions.llmcache.semantic import SemanticCache
77
from redisvl.query import RangeQuery
88
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
9-
from redisvl.utils.optimize.schema import TestData
9+
from redisvl.utils.optimize.schema import LabeledData
1010
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
1111

1212

13-
def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run:
13+
def _generate_run_cache(test_data: List[LabeledData], threshold: float) -> Run:
1414
"""Format observed data for evaluation with ranx"""
1515
run_dict: Dict[str, Dict[str, int]] = {}
1616

@@ -30,7 +30,7 @@ def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run:
3030

3131

3232
def _eval_cache(
33-
test_data: List[TestData], threshold: float, qrels: Qrels, metric: str
33+
test_data: List[LabeledData], threshold: float, qrels: Qrels, metric: str
3434
) -> float:
3535
"""Formats run data and evaluates supported metric"""
3636
run = _generate_run_cache(test_data, threshold)
@@ -46,7 +46,7 @@ def _get_best_threshold(metrics: dict) -> float:
4646

4747

4848
def _grid_search_opt_cache(
49-
cache: SemanticCache, test_data: List[TestData], eval_metric: EvalMetric
49+
cache: SemanticCache, test_data: List[LabeledData], eval_metric: EvalMetric
5050
):
5151
"""Evaluates all thresholds in linspace for cache to determine optimal"""
5252
thresholds = np.linspace(0.01, 0.8, 60)

redisvl/utils/optimize/router.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
from redisvl.extensions.router.semantic import SemanticRouter
88
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
9-
from redisvl.utils.optimize.schema import TestData
9+
from redisvl.utils.optimize.schema import LabeledData
1010
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
1111

1212

13-
def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> Run:
13+
def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -> Run:
1414
"""Format router results into format for ranx Run"""
1515
run_dict: Dict[Any, Any] = {}
1616

@@ -26,7 +26,7 @@ def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> R
2626

2727

2828
def _eval_router(
29-
router: SemanticRouter, test_data: List[TestData], qrels: Qrels, eval_metric: str
29+
router: SemanticRouter, test_data: List[LabeledData], qrels: Qrels, eval_metric: str
3030
) -> float:
3131
"""Evaluate acceptable metric given run and qrels data"""
3232
run = _generate_run_router(test_data, router)
@@ -55,7 +55,7 @@ def _router_random_search(
5555

5656
def _random_search_opt_router(
5757
router: SemanticRouter,
58-
test_data: List[TestData],
58+
test_data: List[LabeledData],
5959
qrels: Qrels,
6060
eval_metric: EvalMetric,
6161
**kwargs: Any,
@@ -67,12 +67,15 @@ def _random_search_opt_router(
6767
best_thresholds = router.route_thresholds
6868

6969
max_iterations = kwargs.get("max_iterations", 20)
70+
search_step = kwargs.get("search_step", 0.10)
7071

7172
for _ in range(max_iterations):
7273
route_names = router.route_names
7374
route_thresholds = router.route_thresholds
7475
thresholds = _router_random_search(
75-
route_names=route_names, route_thresholds=route_thresholds
76+
route_names=route_names,
77+
route_thresholds=route_thresholds,
78+
search_step=search_step,
7679
)
7780
router.update_route_thresholds(thresholds)
7881
score = _eval_router(router, test_data, qrels, eval_metric.value)

redisvl/utils/optimize/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ulid import ULID
55

66

7-
class TestData(BaseModel):
7+
class LabeledData(BaseModel):
88
id: str = Field(default_factory=lambda: str(ULID()))
99
query: str
1010
query_match: Optional[str]

redisvl/utils/optimize/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from ranx import Qrels
44

5-
from redisvl.utils.optimize.schema import TestData
5+
from redisvl.utils.optimize.schema import LabeledData
66

77
NULL_RESPONSE_KEY = "no_match"
88

99

10-
def _format_qrels(test_data: List[TestData]) -> Qrels:
10+
def _format_qrels(test_data: List[LabeledData]) -> Qrels:
1111
"""Utility function for creating qrels for evaluation with ranx"""
1212
qrels_dict = {}
1313

@@ -21,6 +21,6 @@ def _format_qrels(test_data: List[TestData]) -> Qrels:
2121
return Qrels(qrels_dict)
2222

2323

24-
def _validate_test_dict(test_dict: List[dict]) -> List[TestData]:
24+
def _validate_test_dict(test_dict: List[dict]) -> List[LabeledData]:
2525
"""Convert/validate test_dict for use in optimizer"""
26-
return [TestData(**d) for d in test_dict]
26+
return [LabeledData(**d) for d in test_dict]

tests/integration/test_query.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Text,
1515
Timestamp,
1616
)
17+
from redisvl.query.query import VectorRangeQuery
1718
from redisvl.redis.utils import array_to_buffer
1819

1920
# TODO expand to multiple schema types and sync + async
@@ -662,3 +663,129 @@ def test_range_query_normalize_bad_input(index):
662663
return_fields=["user", "credit_score", "age", "job", "location"],
663664
distance_threshold=1.2,
664665
)
666+
667+
def test_hybrid_policy_batches_mode(index, vector_query):
668+
"""Test vector query with BATCHES hybrid policy."""
669+
# Create a filter
670+
t = Tag("credit_score") == "high"
671+
672+
# Set hybrid policy to BATCHES
673+
vector_query.set_hybrid_policy("BATCHES")
674+
vector_query.set_batch_size(2)
675+
676+
# Set the filter
677+
vector_query.set_filter(t)
678+
679+
# Check query string
680+
assert "HYBRID_POLICY BATCHES BATCH_SIZE 2" in str(vector_query)
681+
682+
# Execute query
683+
results = index.query(vector_query)
684+
685+
# Check results - should have filtered to "high" credit scores
686+
assert len(results) > 0
687+
for result in results:
688+
assert result["credit_score"] == "high"
689+
690+
691+
def test_hybrid_policy_adhoc_bf_mode(index, vector_query):
692+
"""Test vector query with ADHOC_BF hybrid policy."""
693+
# Create a filter
694+
t = Tag("credit_score") == "high"
695+
696+
# Set hybrid policy to ADHOC_BF
697+
vector_query.set_hybrid_policy("ADHOC_BF")
698+
699+
# Set the filter
700+
vector_query.set_filter(t)
701+
702+
# Check query string
703+
assert "HYBRID_POLICY ADHOC_BF" in str(vector_query)
704+
705+
# Execute query
706+
results = index.query(vector_query)
707+
708+
# Check results - should have filtered to "high" credit scores
709+
assert len(results) > 0
710+
for result in results:
711+
assert result["credit_score"] == "high"
712+
713+
714+
def test_range_query_with_epsilon(index):
715+
"""Integration test: build a range query with epsilon and verify its query string; only a plain range query (no epsilon) is executed against Redis."""
716+
# Create a range query with epsilon
717+
epsilon_query = VectorRangeQuery(
718+
vector=[0.1, 0.1, 0.5],
719+
vector_field_name="user_embedding",
720+
return_fields=["user", "credit_score", "age", "job"],
721+
distance_threshold=0.3,
722+
epsilon=0.5, # Larger than default to get potentially more results
723+
)
724+
725+
# Verify query string contains epsilon attribute
726+
query_string = str(epsilon_query)
727+
assert "$EPSILON: 0.5" in query_string
728+
729+
# Verify epsilon property is set
730+
assert epsilon_query.epsilon == 0.5
731+
732+
# Test setting epsilon
733+
epsilon_query.set_epsilon(0.1)
734+
assert epsilon_query.epsilon == 0.1
735+
assert "$EPSILON: 0.1" in str(epsilon_query)
736+
737+
# Execute a basic range query (no epsilon) to confirm range queries still run against Redis
738+
basic_query = VectorRangeQuery(
739+
vector=[0.1, 0.1, 0.5],
740+
vector_field_name="user_embedding",
741+
return_fields=["user", "credit_score", "age", "job"],
742+
distance_threshold=0.2,
743+
)
744+
745+
results = index.query(basic_query)
746+
747+
# Check results
748+
for result in results:
749+
assert float(result["vector_distance"]) <= 0.2
750+
751+
752+
def test_range_query_with_filter_and_hybrid_policy(index):
753+
"""Integration test: Test construction of a range query with filter and hybrid policy."""
754+
# Create a filter for high credit score
755+
credit_filter = Tag("credit_score") == "high"
756+
757+
# Create a range query with filter and hybrid policy
758+
query = VectorRangeQuery(
759+
vector=[0.1, 0.1, 0.5],
760+
vector_field_name="user_embedding",
761+
return_fields=["user", "credit_score", "age", "job"],
762+
filter_expression=credit_filter,
763+
distance_threshold=0.5,
764+
hybrid_policy="BATCHES",
765+
batch_size=2,
766+
)
767+
768+
# Check query string and parameters
769+
query_string = str(query)
770+
assert "@credit_score:{high}" in query_string
771+
assert "HYBRID_POLICY" not in query_string
772+
assert query.hybrid_policy == "BATCHES"
773+
assert query.batch_size == 2
774+
assert query.params["HYBRID_POLICY"] == "BATCHES"
775+
assert query.params["BATCH_SIZE"] == 2
776+
777+
# Execute basic query with filter but without hybrid policy
778+
basic_filter_query = VectorRangeQuery(
779+
vector=[0.1, 0.1, 0.5],
780+
vector_field_name="user_embedding",
781+
return_fields=["user", "credit_score", "age", "job"],
782+
filter_expression=credit_filter,
783+
distance_threshold=0.5,
784+
)
785+
786+
results = index.query(basic_filter_query)
787+
788+
# Check results
789+
for result in results:
790+
assert result["credit_score"] == "high"
791+
assert float(result["vector_distance"]) <= 0.5

tests/integration/test_threshold_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_routes_different_distance_thresholds_optimizer_default(
111111

112112
# now run optimizer
113113
router_optimizer = RouterThresholdOptimizer(router, test_data_optimization)
114-
router_optimizer.optimize(max_iterations=10)
114+
router_optimizer.optimize(max_iterations=10, search_step=0.5)
115115

116116
# test that it updated thresholds beyond the null case
117117
for route in routes:
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import logging
2+
import sys
3+
4+
# Set up custom logging
5+
handler = logging.StreamHandler(sys.stdout)
6+
handler.setFormatter(
7+
logging.Formatter(
8+
"%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)s] %(message)s"
9+
)
10+
)
11+
12+
# Configure root logger
13+
root_logger = logging.getLogger()
14+
root_logger.handlers = [handler]
15+
root_logger.setLevel(logging.INFO)
16+
17+
# Log before import
18+
app_logger = logging.getLogger("app")
19+
app_logger.info("PRE_IMPORT_FORMAT")
20+
21+
# Import RedisVL
22+
from redisvl.query.filter import Text # noqa: E402, F401
23+
24+
# Log after import
25+
app_logger.info("POST_IMPORT_FORMAT")

0 commit comments

Comments
 (0)