diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index fa7ad69a..a270951f 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -335,7 +335,6 @@ def test_text_query_with_string_filter(): assert "AND" not in query_string_wildcard -@pytest.mark.skip("Test is flaking") def test_text_query_word_weights(): # verify word weights get added into the raw Redis query syntax query = TextQuery( @@ -344,10 +343,44 @@ def test_text_query_word_weights(): text_weights={"alpha": 2, "delta": 0.555, "gamma": 0.95}, ) - assert ( - str(query) - == "@description:(query | string | alpha=>{$weight:2} | bravo | delta=>{$weight:0.555} | tango | alpha=>{$weight:2}) SCORER BM25STD WITHSCORES DIALECT 2 LIMIT 0 10" - ) + # Check query components with structural guarantees, + # not exact token ordering (which is non-deterministic). + query_str = str(query) + + # Description clause is properly delimited + assert "@description:(" in query_str + desc_start = query_str.index("@description:(") + desc_close = query_str.index(")", desc_start) + desc_clause = query_str[desc_start + len("@description:(") : desc_close] + + # Weighted terms appear inside the description clause + assert "delta=>{$weight:0.555}" in desc_clause + + # alpha appears twice and both occurrences are weighted + alpha_weighted = "alpha=>{$weight:2}" + assert desc_clause.count(alpha_weighted) == 2 + # Ensure no unweighted 'alpha' tokens slipped through + idx = 0 + while True: + idx = desc_clause.find("alpha", idx) + if idx == -1: + break + assert desc_clause.startswith(alpha_weighted, idx) + idx += len("alpha") + + # Unweighted terms are present + for term in ["query", "string", "bravo", "tango"]: + assert term in desc_clause + + # Post-query modifiers follow after the closing paren in expected order + suffix = query_str[desc_close + 1 :] + suffix_stripped = suffix.lstrip() + assert suffix_stripped.startswith("SCORER BM25STD") + scorer_idx = suffix_stripped.index("SCORER BM25STD") + withscores_idx = suffix_stripped.index("WITHSCORES") + dialect_idx = suffix_stripped.index("DIALECT 2") + limit_idx = suffix_stripped.index("LIMIT 0 10") + assert scorer_idx <= withscores_idx <= dialect_idx <= limit_idx # raise an error if weights are not positive floats with pytest.raises(ValueError):