|
14 | 14 | Text, |
15 | 15 | Timestamp, |
16 | 16 | ) |
| 17 | +from redisvl.query.query import VectorRangeQuery |
17 | 18 | from redisvl.redis.utils import array_to_buffer |
18 | 19 |
|
19 | 20 | # TODO expand to multiple schema types and sync + async |
@@ -662,3 +663,129 @@ def test_range_query_normalize_bad_input(index): |
662 | 663 | return_fields=["user", "credit_score", "age", "job", "location"], |
663 | 664 | distance_threshold=1.2, |
664 | 665 | ) |
| 666 | + |
| 667 | +def test_hybrid_policy_batches_mode(index, vector_query): |
| 668 | + """Test vector query with BATCHES hybrid policy.""" |
| 669 | + # Create a filter |
| 670 | + t = Tag("credit_score") == "high" |
| 671 | + |
| 672 | + # Set hybrid policy to BATCHES |
| 673 | + vector_query.set_hybrid_policy("BATCHES") |
| 674 | + vector_query.set_batch_size(2) |
| 675 | + |
| 676 | + # Set the filter |
| 677 | + vector_query.set_filter(t) |
| 678 | + |
| 679 | + # Check query string |
| 680 | + assert "HYBRID_POLICY BATCHES BATCH_SIZE 2" in str(vector_query) |
| 681 | + |
| 682 | + # Execute query |
| 683 | + results = index.query(vector_query) |
| 684 | + |
| 685 | + # Check results - should have filtered to "high" credit scores |
| 686 | + assert len(results) > 0 |
| 687 | + for result in results: |
| 688 | + assert result["credit_score"] == "high" |
| 689 | + |
| 690 | + |
| 691 | +def test_hybrid_policy_adhoc_bf_mode(index, vector_query): |
| 692 | + """Test vector query with ADHOC_BF hybrid policy.""" |
| 693 | + # Create a filter |
| 694 | + t = Tag("credit_score") == "high" |
| 695 | + |
| 696 | + # Set hybrid policy to ADHOC_BF |
| 697 | + vector_query.set_hybrid_policy("ADHOC_BF") |
| 698 | + |
| 699 | + # Set the filter |
| 700 | + vector_query.set_filter(t) |
| 701 | + |
| 702 | + # Check query string |
| 703 | + assert "HYBRID_POLICY ADHOC_BF" in str(vector_query) |
| 704 | + |
| 705 | + # Execute query |
| 706 | + results = index.query(vector_query) |
| 707 | + |
| 708 | + # Check results - should have filtered to "high" credit scores |
| 709 | + assert len(results) > 0 |
| 710 | + for result in results: |
| 711 | + assert result["credit_score"] == "high" |
| 712 | + |
| 713 | + |
| 714 | +def test_range_query_with_epsilon(index): |
| 715 | + """Integration test: Execute range query with epsilon parameter against Redis.""" |
| 716 | + # Create a range query with epsilon |
| 717 | + epsilon_query = VectorRangeQuery( |
| 718 | + vector=[0.1, 0.1, 0.5], |
| 719 | + vector_field_name="user_embedding", |
| 720 | + return_fields=["user", "credit_score", "age", "job"], |
| 721 | + distance_threshold=0.3, |
| 722 | + epsilon=0.5, # Larger than default to get potentially more results |
| 723 | + ) |
| 724 | + |
| 725 | + # Verify query string contains epsilon attribute |
| 726 | + query_string = str(epsilon_query) |
| 727 | + assert "$EPSILON: 0.5" in query_string |
| 728 | + |
| 729 | + # Verify epsilon property is set |
| 730 | + assert epsilon_query.epsilon == 0.5 |
| 731 | + |
| 732 | + # Test setting epsilon |
| 733 | + epsilon_query.set_epsilon(0.1) |
| 734 | + assert epsilon_query.epsilon == 0.1 |
| 735 | + assert "$EPSILON: 0.1" in str(epsilon_query) |
| 736 | + |
| 737 | + # Execute basic query without epsilon to ensure functionality |
| 738 | + basic_query = VectorRangeQuery( |
| 739 | + vector=[0.1, 0.1, 0.5], |
| 740 | + vector_field_name="user_embedding", |
| 741 | + return_fields=["user", "credit_score", "age", "job"], |
| 742 | + distance_threshold=0.2, |
| 743 | + ) |
| 744 | + |
| 745 | + results = index.query(basic_query) |
| 746 | + |
| 747 | + # Check results |
| 748 | + for result in results: |
| 749 | + assert float(result["vector_distance"]) <= 0.2 |
| 750 | + |
| 751 | + |
| 752 | +def test_range_query_with_filter_and_hybrid_policy(index): |
| 753 | + """Integration test: Test construction of a range query with filter and hybrid policy.""" |
| 754 | + # Create a filter for high credit score |
| 755 | + credit_filter = Tag("credit_score") == "high" |
| 756 | + |
| 757 | + # Create a range query with filter and hybrid policy |
| 758 | + query = VectorRangeQuery( |
| 759 | + vector=[0.1, 0.1, 0.5], |
| 760 | + vector_field_name="user_embedding", |
| 761 | + return_fields=["user", "credit_score", "age", "job"], |
| 762 | + filter_expression=credit_filter, |
| 763 | + distance_threshold=0.5, |
| 764 | + hybrid_policy="BATCHES", |
| 765 | + batch_size=2, |
| 766 | + ) |
| 767 | + |
| 768 | + # Check query string and parameters |
| 769 | + query_string = str(query) |
| 770 | + assert "@credit_score:{high}" in query_string |
| 771 | + assert "HYBRID_POLICY" not in query_string |
| 772 | + assert query.hybrid_policy == "BATCHES" |
| 773 | + assert query.batch_size == 2 |
| 774 | + assert query.params["HYBRID_POLICY"] == "BATCHES" |
| 775 | + assert query.params["BATCH_SIZE"] == 2 |
| 776 | + |
| 777 | + # Execute basic query with filter but without hybrid policy |
| 778 | + basic_filter_query = VectorRangeQuery( |
| 779 | + vector=[0.1, 0.1, 0.5], |
| 780 | + vector_field_name="user_embedding", |
| 781 | + return_fields=["user", "credit_score", "age", "job"], |
| 782 | + filter_expression=credit_filter, |
| 783 | + distance_threshold=0.5, |
| 784 | + ) |
| 785 | + |
| 786 | + results = index.query(basic_filter_query) |
| 787 | + |
| 788 | + # Check results |
| 789 | + for result in results: |
| 790 | + assert result["credit_score"] == "high" |
| 791 | + assert float(result["vector_distance"]) <= 0.5 |
0 commit comments