Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
max-parallel: 15
fail-fast: false
matrix:
redis-version: [ '${{ needs.redis_version.outputs.CURRENT }}', '7.2.6', '6.2.16']
redis-version: ['8.0-M02', '${{ needs.redis_version.outputs.CURRENT }}', '7.2.6', '6.2.16']
python-version: ['3.8', '3.12']
parser-backend: ['plain']
event-loop: ['asyncio']
Expand Down
9 changes: 6 additions & 3 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from redis.client import NEVER_DECODE, Pipeline
from redis.utils import deprecated_function

from ..helpers import get_protocol_version, parse_to_dict
from ..helpers import get_protocol_version
from ._util import to_string
from .aggregation import AggregateRequest, AggregateResult, Cursor
from .document import Document
from .field import Field
from .indexDefinition import IndexDefinition
from .profileInformation import ProfileInformation
from .query import Query
from .result import Result
from .suggestion import SuggestionParser
Expand Down Expand Up @@ -66,8 +67,10 @@ class SearchCommands:
"""Search commands."""

def _parse_results(self, cmd, res, **kwargs):
if get_protocol_version(self.client) in ["3", 3]:
if get_protocol_version(self.client) in ["3", 3] and cmd != "FT.PROFILE":
return res
elif get_protocol_version(self.client) in ["3", 3] and cmd == "FT.PROFILE":
return ProfileInformation(res)
else:
return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs)

Expand Down Expand Up @@ -101,7 +104,7 @@ def _parse_profile(self, res, **kwargs):
with_scores=query._with_scores,
)

return result, parse_to_dict(res[1])
return result, ProfileInformation(res[1])

def _parse_spellcheck(self, res, **kwargs):
corrections = {}
Expand Down
14 changes: 14 additions & 0 deletions redis/commands/search/profileInformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any


class ProfileInformation:
"""
Wrapper around FT.PROFILE response
"""

def __init__(self, info: Any) -> None:
self._info: Any = info

@property
def info(self) -> Any:
return self._info
16 changes: 16 additions & 0 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
is_resp2_connection,
skip_if_redis_enterprise,
skip_if_resp_version,
skip_if_server_version_gte,
skip_if_server_version_lt,
skip_ifmodversion_lt,
)

Expand Down Expand Up @@ -1111,6 +1113,7 @@ async def test_get(decoded_r: redis.Redis):
@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_ifmodversion_lt("2.2.0", "search")
@skip_if_server_version_gte("7.9.0")
async def test_config(decoded_r: redis.Redis):
assert await decoded_r.ft().config_set("TIMEOUT", "100")
with pytest.raises(redis.ResponseError):
Expand All @@ -1121,6 +1124,19 @@ async def test_config(decoded_r: redis.Redis):
assert "100" == res["TIMEOUT"]


@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_server_version_lt("7.9.0")
async def test_config_with_removed_ftconfig(decoded_r: redis.Redis):
assert await decoded_r.config_set("timeout", "100")
with pytest.raises(redis.ResponseError):
await decoded_r.config_set("timeout", "null")
res = await decoded_r.config_get("*")
assert "100" == res["timeout"]
res = await decoded_r.config_get("timeout")
assert "100" == res["timeout"]


@pytest.mark.redismod
@pytest.mark.onlynoncluster
async def test_aggregations_groupby(decoded_r: redis.Redis):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,7 @@ def try_delete_libs(self, r, *lib_names):

@pytest.mark.onlynoncluster
@skip_if_server_version_lt("7.1.140")
@skip_if_server_version_gte("7.9.0")
def test_tfunction_load_delete(self, stack_r):
self.try_delete_libs(stack_r, "lib1")
lib_code = self.generate_lib_code("lib1")
Expand All @@ -1831,6 +1832,7 @@ def test_tfunction_load_delete(self, stack_r):

@pytest.mark.onlynoncluster
@skip_if_server_version_lt("7.1.140")
@skip_if_server_version_gte("7.9.0")
def test_tfunction_list(self, stack_r):
self.try_delete_libs(stack_r, "lib1", "lib2", "lib3")
assert stack_r.tfunction_load(self.generate_lib_code("lib1"))
Expand Down Expand Up @@ -1861,6 +1863,7 @@ def test_tfunction_list(self, stack_r):

@pytest.mark.onlynoncluster
@skip_if_server_version_lt("7.1.140")
@skip_if_server_version_gte("7.9.0")
def test_tfcall(self, stack_r):
self.try_delete_libs(stack_r, "lib1")
assert stack_r.tfunction_load(self.generate_lib_code("lib1"))
Expand Down Expand Up @@ -4329,6 +4332,7 @@ def test_xgroup_create_mkstream(self, r):
assert r.xinfo_groups(stream) == expected

@skip_if_server_version_lt("7.0.0")
@skip_if_server_version_gte("7.9.0")
def test_xgroup_create_entriesread(self, r: redis.Redis):
stream = "stream"
group = "group"
Expand All @@ -4350,6 +4354,28 @@ def test_xgroup_create_entriesread(self, r: redis.Redis):
]
assert r.xinfo_groups(stream) == expected

@skip_if_server_version_lt("7.9.0")
def test_xgroup_create_entriesread_with_fixed_lag_field(self, r: redis.Redis):
stream = "stream"
group = "group"
r.xadd(stream, {"foo": "bar"})

# no group is setup yet, no info to obtain
assert r.xinfo_groups(stream) == []

assert r.xgroup_create(stream, group, 0, entries_read=7)
expected = [
{
"name": group.encode(),
"consumers": 0,
"pending": 0,
"last-delivered-id": b"0-0",
"entries-read": 7,
"lag": 1,
}
]
assert r.xinfo_groups(stream) == expected

@skip_if_server_version_lt("5.0.0")
def test_xgroup_delconsumer(self, r):
stream = "stream"
Expand Down
146 changes: 128 additions & 18 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
is_resp2_connection,
skip_if_redis_enterprise,
skip_if_resp_version,
skip_if_server_version_gte,
skip_if_server_version_lt,
skip_ifmodversion_lt,
)
Expand Down Expand Up @@ -1007,6 +1008,7 @@ def test_get(client):
@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_ifmodversion_lt("2.2.0", "search")
@skip_if_server_version_gte("7.9.0")
def test_config(client):
assert client.ft().config_set("TIMEOUT", "100")
with pytest.raises(redis.ResponseError):
Expand All @@ -1017,6 +1019,19 @@ def test_config(client):
assert "100" == res["TIMEOUT"]


@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_server_version_lt("7.9.0")
def test_config_with_removed_ftconfig(client):
assert client.config_set("timeout", "100")
with pytest.raises(redis.ResponseError):
client.config_set("timeout", "null")
res = client.config_get("*")
assert "100" == res["timeout"]
res = client.config_get("timeout")
assert "100" == res["timeout"]


@pytest.mark.redismod
@pytest.mark.onlynoncluster
def test_aggregations_groupby(client):
Expand Down Expand Up @@ -1571,6 +1586,7 @@ def test_index_definition(client):
@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
@skip_if_server_version_gte("7.9.0")
def test_expire(client):
client.ft().create_index((TextField("txt", sortable=True),), temporary=4)
ttl = client.execute_command("ft.debug", "TTL", "idx")
Expand Down Expand Up @@ -2025,6 +2041,8 @@ def test_json_with_jsonpath(client):
@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
@skip_if_server_version_gte("7.9.0")
@skip_if_server_version_lt("6.3.0")
def test_profile(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")
Expand All @@ -2034,10 +2052,10 @@ def test_profile(client):
q = Query("hello|world").no_content()
if is_resp2_connection(client):
res, det = client.ft().profile(q)
assert det["Iterators profile"]["Counter"] == 2.0
assert len(det["Iterators profile"]["Child iterators"]) == 2
assert det["Iterators profile"]["Type"] == "UNION"
assert det["Parsing time"] < 0.5
det = det.info
assert det[4][1][7] == 2.0
assert det[4][1][1] == "UNION"
assert float(det[1][1]) < 0.5
assert len(res.docs) == 2 # check also the search result

# check using AggregateRequest
Expand All @@ -2047,12 +2065,13 @@ def test_profile(client):
.apply(prefix="startswith(@t, 'hel')")
)
res, det = client.ft().profile(req)
assert det["Iterators profile"]["Counter"] == 2
assert det["Iterators profile"]["Type"] == "WILDCARD"
assert isinstance(det["Parsing time"], float)
det = det.info
assert det[4][1][5] == 2
assert det[4][1][1] == "WILDCARD"
assert len(res.rows) == 2 # check also the search result
else:
res = client.ft().profile(q)
res = res.info
assert res["profile"]["Iterators profile"][0]["Counter"] == 2.0
assert res["profile"]["Iterators profile"][0]["Type"] == "UNION"
assert res["profile"]["Parsing time"] < 0.5
Expand All @@ -2065,6 +2084,7 @@ def test_profile(client):
.apply(prefix="startswith(@t, 'hel')")
)
res = client.ft().profile(req)
res = res.info
assert res["profile"]["Iterators profile"][0]["Counter"] == 2
assert res["profile"]["Iterators profile"][0]["Type"] == "WILDCARD"
assert isinstance(res["profile"]["Parsing time"], float)
Expand All @@ -2073,6 +2093,96 @@ def test_profile(client):

@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
@skip_if_server_version_lt("7.9.0")
def test_profile_with_coordinator(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")
client.ft().client.hset("2", "t", "world")

# check using Query
q = Query("hello|world").no_content()
if is_resp2_connection(client):
res, det = client.ft().profile(q)
det = det.info
assert det[0] == "Shards"
assert det[2] == "Coordinator"
assert det[1][0][9][7] == 2.0
assert det[1][0][9][1] == "UNION"
assert float(det[1][0][3]) < 0.5
assert len(res.docs) == 2 # check also the search result

# check using AggregateRequest
req = (
aggregations.AggregateRequest("*")
.load("t")
.apply(prefix="startswith(@t, 'hel')")
)
res, det = client.ft().profile(req)
det = det.info
assert det[0] == "Shards"
assert det[2] == "Coordinator"
assert det[1][0][9][5] == 2
assert det[1][0][9][1] == "WILDCARD"
assert len(res.rows) == 2 # check also the search result
else:
res = client.ft().profile(q)
res = res.info
assert res["Profile"]["Shards"][0]["Iterators profile"]["Counter"] == 2.0
assert res["Profile"]["Shards"][0]["Iterators profile"]["Type"] == "UNION"
assert res["Profile"]["Shards"][0]["Parsing time"] < 0.5
assert len(res["Results"]["results"]) == 2 # check also the search result

# check using AggregateRequest
req = (
aggregations.AggregateRequest("*")
.load("t")
.apply(prefix="startswith(@t, 'hel')")
)
res = client.ft().profile(req)
res = res.info
assert res["Profile"]["Shards"][0]["Iterators profile"]["Counter"] == 2
assert res["Profile"]["Shards"][0]["Iterators profile"]["Type"] == "WILDCARD"
assert isinstance(res["Profile"]["Shards"][0]["Parsing time"], float)
assert len(res["Results"]["results"]) == 2 # check also the search result


@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_redis_enterprise()
@skip_if_server_version_gte("6.3.0")
def test_profile_with_no_warnings(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")
client.ft().client.hset("2", "t", "world")

# check using Query
q = Query("hello|world").no_content()
res, det = client.ft().profile(q)
det = det.info
print(det)
assert det[3][1][7] == 2.0
assert det[3][1][1] == "UNION"
assert float(det[1][1]) < 0.5
assert len(res.docs) == 2 # check also the search result

# check using AggregateRequest
req = (
aggregations.AggregateRequest("*")
.load("t")
.apply(prefix="startswith(@t, 'hel')")
)
res, det = client.ft().profile(req)
det = det.info
assert det[3][1][5] == 2
assert det[3][1][1] == "WILDCARD"
assert len(res.rows) == 2 # check also the search result


@pytest.mark.redismod
@pytest.mark.onlynoncluster
@skip_if_server_version_gte("7.9.0")
@skip_if_server_version_lt("6.3.0")
def test_profile_limited(client):
client.ft().create_index((TextField("t"),))
client.ft().client.hset("1", "t", "hello")
Expand All @@ -2083,18 +2193,14 @@ def test_profile_limited(client):
q = Query("%hell% hel*")
if is_resp2_connection(client):
res, det = client.ft().profile(q, limited=True)
assert (
det["Iterators profile"]["Child iterators"][0]["Child iterators"]
== "The number of iterators in the union is 3"
)
assert (
det["Iterators profile"]["Child iterators"][1]["Child iterators"]
== "The number of iterators in the union is 4"
)
assert det["Iterators profile"]["Type"] == "INTERSECT"
det = det.info
assert det[4][1][7][9] == "The number of iterators in the union is 3"
assert det[4][1][8][9] == "The number of iterators in the union is 4"
assert det[4][1][1] == "INTERSECT"
assert len(res.docs) == 3 # check also the search result
else:
res = client.ft().profile(q, limited=True)
res = res.info
iterators_profile = res["profile"]["Iterators profile"]
assert (
iterators_profile[0]["Child iterators"][0]["Child iterators"]
Expand All @@ -2110,6 +2216,8 @@ def test_profile_limited(client):

@pytest.mark.redismod
@skip_ifmodversion_lt("2.4.3", "search")
@skip_if_server_version_gte("7.9.0")
@skip_if_server_version_lt("6.3.0")
def test_profile_query_params(client):
client.ft().create_index(
(
Expand All @@ -2125,13 +2233,15 @@ def test_profile_query_params(client):
q = Query(query).return_field("__v_score").sort_by("__v_score", True)
if is_resp2_connection(client):
res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"})
assert det["Iterators profile"]["Counter"] == 2.0
assert det["Iterators profile"]["Type"] == "VECTOR"
det = det.info
assert det[4][1][5] == 2.0
assert det[4][1][1] == "VECTOR"
assert res.total == 2
assert "a" == res.docs[0].id
assert "0" == res.docs[0].__getattribute__("__v_score")
else:
res = client.ft().profile(q, query_params={"vec": "aaaaaaaa"})
res = res.info
assert res["profile"]["Iterators profile"][0]["Counter"] == 2
assert res["profile"]["Iterators profile"][0]["Type"] == "VECTOR"
assert res["total_results"] == 2
Expand Down
Loading