diff --git a/redis/commands/core.py b/redis/commands/core.py index df76eafed0..271f640dec 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -3,6 +3,7 @@ import datetime import hashlib import warnings +from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -44,6 +45,10 @@ TimeoutSecT, ZScoreBoundT, ) +from redis.utils import ( + deprecated_function, + extract_expire_flags, +) from .helpers import list_or_args @@ -1837,10 +1842,10 @@ def getdel(self, name: KeyT) -> ResponseT: def getex( self, name: KeyT, - ex: Union[ExpiryT, None] = None, - px: Union[ExpiryT, None] = None, - exat: Union[AbsExpiryT, None] = None, - pxat: Union[AbsExpiryT, None] = None, + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, persist: bool = False, ) -> ResponseT: """ @@ -1863,7 +1868,6 @@ def getex( For more information see https://redis.io/commands/getex """ - opset = {ex, px, exat, pxat} if len(opset) > 2 or len(opset) > 1 and persist: raise DataError( @@ -1871,33 +1875,12 @@ def getex( "and ``persist`` are mutually exclusive." ) - pieces: list[EncodableT] = [] - # similar to set command - if ex is not None: - pieces.append("EX") - if isinstance(ex, datetime.timedelta): - ex = int(ex.total_seconds()) - pieces.append(ex) - if px is not None: - pieces.append("PX") - if isinstance(px, datetime.timedelta): - px = int(px.total_seconds() * 1000) - pieces.append(px) - # similar to pexpireat command - if exat is not None: - pieces.append("EXAT") - if isinstance(exat, datetime.datetime): - exat = int(exat.timestamp()) - pieces.append(exat) - if pxat is not None: - pieces.append("PXAT") - if isinstance(pxat, datetime.datetime): - pxat = int(pxat.timestamp() * 1000) - pieces.append(pxat) + exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) + if persist: - pieces.append("PERSIST") + exp_options.append("PERSIST") - return self.execute_command("GETEX", name, *pieces) + return self.execute_command("GETEX", name, *exp_options) def __getitem__(self, name: KeyT): """ @@ -2255,14 +2238,14 @@ def set( self, name: KeyT, value: EncodableT, - ex: Union[ExpiryT, None] = None, - px: Union[ExpiryT, None] = None, + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, nx: bool = False, xx: bool = False, keepttl: bool = False, get: bool = False, - exat: Union[AbsExpiryT, None] = None, - pxat: Union[AbsExpiryT, None] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, ) -> ResponseT: """ Set the value at key ``name`` to ``value`` @@ -2292,36 +2275,21 @@ def set( For more information see https://redis.io/commands/set """ + opset = {ex, px, exat, pxat} + if len(opset) > 2 or len(opset) > 1 and keepttl: + raise DataError( + "``ex``, ``px``, ``exat``, ``pxat``, " + "and ``keepttl`` are mutually exclusive." + ) + + if nx and xx: + raise DataError("``nx`` and ``xx`` are mutually exclusive.") + pieces: list[EncodableT] = [name, value] options = {} - if ex is not None: - pieces.append("EX") - if isinstance(ex, datetime.timedelta): - pieces.append(int(ex.total_seconds())) - elif isinstance(ex, int): - pieces.append(ex) - elif isinstance(ex, str) and ex.isdigit(): - pieces.append(int(ex)) - else: - raise DataError("ex must be datetime.timedelta or int") - if px is not None: - pieces.append("PX") - if isinstance(px, datetime.timedelta): - pieces.append(int(px.total_seconds() * 1000)) - elif isinstance(px, int): - pieces.append(px) - else: - raise DataError("px must be datetime.timedelta or int") - if exat is not None: - pieces.append("EXAT") - if isinstance(exat, datetime.datetime): - exat = int(exat.timestamp()) - pieces.append(exat) - if pxat is not None: - pieces.append("PXAT") - if isinstance(pxat, datetime.datetime): - pxat = int(pxat.timestamp() * 1000) - pieces.append(pxat) + + pieces.extend(extract_expire_flags(ex, px, exat, pxat)) + if keepttl: pieces.append("KEEPTTL") @@ -4940,6 +4908,16 @@ def pfmerge(self, dest: KeyT, *sources: KeyT) -> ResponseT: AsyncHyperlogCommands = HyperlogCommands +class HashDataPersistOptions(Enum): + # set the value for each provided key to each + # provided value only if all do not already exist. + FNX = "FNX" + + # set the value for each provided key to each + # provided value only if all already exist. + FXX = "FXX" + + class HashCommands(CommandsProtocol): """ Redis commands for Hash data type. @@ -4980,6 +4958,80 @@ def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: """ return self.execute_command("HGETALL", name, keys=[name]) + def hgetdel( + self, name: str, *keys: str + ) -> Union[ + Awaitable[Optional[List[Union[str, bytes]]]], Optional[List[Union[str, bytes]]] + ]: + """ + Return the value of ``key`` within the hash ``name`` and + delete the field in the hash. + This command is similar to HGET, except for the fact that it also deletes + the key on success from the hash with the provided ```name```. + + Available since Redis 8.0 + For more information see https://redis.io/commands/hgetdel + """ + if len(keys) == 0: + raise DataError("'hgetdel' should have at least one key provided") + + return self.execute_command("HGETDEL", name, "FIELDS", len(keys), *keys) + + def hgetex( + self, + name: KeyT, + *keys: str, + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, + persist: bool = False, + ) -> Union[ + Awaitable[Optional[List[Union[str, bytes]]]], Optional[List[Union[str, bytes]]] + ]: + """ + Return the values of ``key`` and ``keys`` within the hash ``name`` + and optionally set their expiration. + + ``ex`` sets an expire flag on ``kyes`` for ``ex`` seconds. + + ``px`` sets an expire flag on ``keys`` for ``px`` milliseconds. + + ``exat`` sets an expire flag on ``keys`` for ``ex`` seconds, + specified in unix time. + + ``pxat`` sets an expire flag on ``keys`` for ``ex`` milliseconds, + specified in unix time. + + ``persist`` remove the time to live associated with the ``keys``. + + Available since Redis 8.0 + For more information see https://redis.io/commands/hgetex + """ + if not keys: + raise DataError("'hgetex' should have at least one key provided") + + opset = {ex, px, exat, pxat} + if len(opset) > 2 or len(opset) > 1 and persist: + raise DataError( + "``ex``, ``px``, ``exat``, ``pxat``, " + "and ``persist`` are mutually exclusive." + ) + + exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) + + if persist: + exp_options.append("PERSIST") + + return self.execute_command( + "HGETEX", + name, + *exp_options, + "FIELDS", + len(keys), + *keys, + ) + def hincrby( self, name: str, key: str, amount: int = 1 ) -> Union[Awaitable[int], int]: @@ -5034,8 +5086,10 @@ def hset( For more information see https://redis.io/commands/hset """ + if key is None and not mapping and not items: raise DataError("'hset' with no key value pairs") + pieces = [] if items: pieces.extend(items) @@ -5047,6 +5101,89 @@ def hset( return self.execute_command("HSET", name, *pieces) + def hsetex( + self, + name: str, + key: Optional[str] = None, + value: Optional[str] = None, + mapping: Optional[dict] = None, + items: Optional[list] = None, + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, + data_persist_option: Optional[HashDataPersistOptions] = None, + keepttl: bool = False, + ) -> Union[Awaitable[int], int]: + """ + Set ``key`` to ``value`` within hash ``name`` + + ``mapping`` accepts a dict of key/value pairs that will be + added to hash ``name``. + + ``items`` accepts a list of key/value pairs that will be + added to hash ``name``. + + ``ex`` sets an expire flag on ``keys`` for ``ex`` seconds. + + ``px`` sets an expire flag on ``keys`` for ``px`` milliseconds. + + ``exat`` sets an expire flag on ``keys`` for ``ex`` seconds, + specified in unix time. + + ``pxat`` sets an expire flag on ``keys`` for ``ex`` milliseconds, + specified in unix time. + + ``data_persist_option`` can be set to ``FNX`` or ``FXX`` to control the + behavior of the command. + ``FNX`` will set the value for each provided key to each + provided value only if all do not already exist. + ``FXX`` will set the value for each provided key to each + provided value only if all already exist. + + ``keepttl`` if True, retain the time to live associated with the keys. + + Returns the number of fields that were added. + + Available since Redis 8.0 + For more information see https://redis.io/commands/hsetex + """ + if key is None and not mapping and not items: + raise DataError("'hsetex' with no key value pairs") + + if items and len(items) % 2 != 0: + raise DataError( + "'hsetex' with odd number of items. " + "'items' must contain a list of key/value pairs." + ) + + opset = {ex, px, exat, pxat} + if len(opset) > 2 or len(opset) > 1 and keepttl: + raise DataError( + "``ex``, ``px``, ``exat``, ``pxat``, " + "and ``keepttl`` are mutually exclusive." + ) + + exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) + if data_persist_option: + exp_options.append(data_persist_option.value) + + if keepttl: + exp_options.append("KEEPTTL") + + pieces = [] + if items: + pieces.extend(items) + if key is not None: + pieces.extend((key, value)) + if mapping: + for pair in mapping.items(): + pieces.extend(pair) + + return self.execute_command( + "HSETEX", name, *exp_options, "FIELDS", int(len(pieces) / 2), *pieces + ) + def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool]: """ Set ``key`` to ``value`` within hash ``name`` if ``key`` does not @@ -5056,6 +5193,11 @@ def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool """ return self.execute_command("HSETNX", name, key, value) + @deprecated_function( + version="4.0.0", + reason="Use 'hset' instead.", + name="hmset", + ) def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: """ Set key to value within hash ``name`` for each corresponding @@ -5063,12 +5205,6 @@ def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: For more information see https://redis.io/commands/hmset """ - warnings.warn( - f"{self.__class__.__name__}.hmset() is deprecated. " - f"Use {self.__class__.__name__}.hset() instead.", - DeprecationWarning, - stacklevel=2, - ) if not mapping: raise DataError("'hmset' with 'mapping' of length 0") items = [] diff --git a/redis/utils.py b/redis/utils.py index 66465636a1..9d9b4a9580 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,7 +1,11 @@ +import datetime import logging from contextlib import contextmanager from functools import wraps -from typing import Any, Dict, Mapping, Union +from typing import Any, Dict, List, Mapping, Optional, Union + +from redis.exceptions import DataError +from redis.typing import AbsExpiryT, EncodableT, ExpiryT try: import hiredis # noqa @@ -257,3 +261,40 @@ def ensure_string(key): return key else: raise TypeError("Key must be either a string or bytes") + + +def extract_expire_flags( + ex: Optional[ExpiryT] = None, + px: Optional[ExpiryT] = None, + exat: Optional[AbsExpiryT] = None, + pxat: Optional[AbsExpiryT] = None, +) -> List[EncodableT]: + exp_options: list[EncodableT] = [] + if ex is not None: + exp_options.append("EX") + if isinstance(ex, datetime.timedelta): + exp_options.append(int(ex.total_seconds())) + elif isinstance(ex, int): + exp_options.append(ex) + elif isinstance(ex, str) and ex.isdigit(): + exp_options.append(int(ex)) + else: + raise DataError("ex must be datetime.timedelta or int") + elif px is not None: + exp_options.append("PX") + if isinstance(px, datetime.timedelta): + exp_options.append(int(px.total_seconds() * 1000)) + elif isinstance(px, int): + exp_options.append(px) + else: + raise DataError("px must be datetime.timedelta or int") + elif exat is not None: + if isinstance(exat, datetime.datetime): + exat = int(exat.timestamp()) + exp_options.extend(["EXAT", exat]) + elif pxat is not None: + if isinstance(pxat, datetime.datetime): + pxat = int(pxat.timestamp() * 1000) + exp_options.extend(["PXAT", pxat]) + + return exp_options diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 08bd5810f4..bfb6855a0f 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -31,6 +31,7 @@ skip_if_server_version_lt, skip_unless_arch_bits, ) +from tests.test_asyncio.test_utils import redis_server_time if sys.version_info >= (3, 11, 3): from asyncio import timeout as async_timeout @@ -77,12 +78,6 @@ async def slowlog(r: redis.Redis): await r.config_set("slowlog-max-len", old_max_legnth_value) -async def redis_server_time(client: redis.Redis): - seconds, milliseconds = await client.time() - timestamp = float(f"{seconds}.{milliseconds}") - return datetime.datetime.fromtimestamp(timestamp) - - async def get_stream_message(client: redis.Redis, stream: str, message_id: str): """Fetch a stream message and format it as a (message_id, fields) pair""" response = await client.xrange(stream, min=message_id, max=message_id) @@ -2328,12 +2323,8 @@ async def test_hmget(self, r: redis.Redis): assert await r.hmget("a", "a", "b", "c") == [b"1", b"2", b"3"] async def test_hmset(self, r: redis.Redis): - warning_message = ( - r"^Redis(?:Cluster)*\.hmset\(\) is deprecated\. " - r"Use Redis(?:Cluster)*\.hset\(\) instead\.$" - ) h = {b"a": b"1", b"b": b"2", b"c": b"3"} - with pytest.warns(DeprecationWarning, match=warning_message): + with pytest.warns(DeprecationWarning): assert await r.hmset("a", h) assert await r.hgetall("a") == h diff --git a/tests/test_asyncio/test_hash.py b/tests/test_asyncio/test_hash.py index 15e426673b..4fbc02c5fe 100644 --- a/tests/test_asyncio/test_hash.py +++ b/tests/test_asyncio/test_hash.py @@ -2,7 +2,12 @@ import math from datetime import datetime, timedelta +import pytest + +from redis import exceptions +from redis.commands.core import HashDataPersistOptions from tests.conftest import skip_if_server_version_lt +from tests.test_asyncio.test_utils import redis_server_time @skip_if_server_version_lt("7.3.240") @@ -299,3 +304,274 @@ async def test_pttl_multiple_fields_mixed_conditions(r): result = await r.hpttl("test:hash", "field1", "field2", "field3") assert 30 * 60000 - 10000 < result[0] <= 30 * 60000 assert result[1:] == [-1, -2] + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetdel(r): + await r.delete("test:hash") + await r.hset("test:hash", "foo", "bar", mapping={"1": 1, "2": 2}) + assert await r.hgetdel("test:hash", "foo", "1") == [b"bar", b"1"] + assert await r.hget("test:hash", "foo") is None + assert await r.hget("test:hash", "1") is None + assert await r.hget("test:hash", "2") == b"2" + assert await r.hgetdel("test:hash", "foo", "1") == [None, None] + assert await r.hget("test:hash", "2") == b"2" + + with pytest.raises(exceptions.DataError): + await r.hgetdel("test:hash") + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetex_no_expiration(r): + await r.delete("test:hash") + await r.hset( + "b", "foo", "bar", mapping={"1": 1, "2": 2, "3": "three", "4": b"four"} + ) + + assert await r.hgetex("b", "foo", "1", "4") == [b"bar", b"1", b"four"] + assert await r.hgetex("b", "foo") == [b"bar"] + assert await r.httl("b", "foo", "1", "4") == [-1, -1, -1] + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetex_expiration_configs(r): + await r.delete("test:hash") + await r.hset( + "test:hash", "foo", "bar", mapping={"1": 1, "3": "three", "4": b"four"} + ) + + test_keys = ["foo", "1", "4"] + # test get with multiple fields with expiration set through 'ex' + assert await r.hgetex("test:hash", *test_keys, ex=10) == [ + b"bar", + b"1", + b"four", + ] + ttls = await r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 10 + + # test get with multiple fields removing expiration settings with 'persist' + assert await r.hgetex("test:hash", *test_keys, persist=True) == [ + b"bar", + b"1", + b"four", + ] + assert await r.httl("test:hash", *test_keys) == [-1, -1, -1] + + # test get with multiple fields with expiration set through 'px' + assert await r.hgetex("test:hash", *test_keys, px=6000) == [ + b"bar", + b"1", + b"four", + ] + ttls = await r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 6 + + # test get single field with expiration set through 'pxat' + expire_at = await redis_server_time(r) + timedelta(minutes=1) + assert await r.hgetex("test:hash", "foo", pxat=expire_at) == [b"bar"] + assert (await r.httl("test:hash", "foo"))[0] <= 61 + + # test get single field with expiration set through 'exat' + expire_at = await redis_server_time(r) + timedelta(seconds=10) + assert await r.hgetex("test:hash", "foo", exat=expire_at) == [b"bar"] + assert (await r.httl("test:hash", "foo"))[0] <= 10 + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetex_validate_expired_fields_removed(r): + await r.delete("test:hash") + await r.hset( + "test:hash", "foo", "bar", mapping={"1": 1, "3": "three", "4": b"four"} + ) + + # test get multiple fields with expiration set + # validate that expired fields are removed + assert await r.hgetex("test:hash", "foo", "1", "3", ex=1) == [ + b"bar", + b"1", + b"three", + ] + await asyncio.sleep(1.1) + assert await r.hgetex("test:hash", "foo", "1", "3") == [None, None, None] + assert await r.httl("test:hash", "foo", "1", "3") == [-2, -2, -2] + assert await r.hgetex("test:hash", "4") == [b"four"] + + +@skip_if_server_version_lt("7.9.0") +async def test_hgetex_invalid_inputs(r): + with pytest.raises(exceptions.DataError): + await r.hgetex("b", "foo", ex=10, persist=True) + + with pytest.raises(exceptions.DataError): + await r.hgetex("b", "foo", ex=10.0, persist=True) + + with pytest.raises(exceptions.DataError): + await r.hgetex("b", "foo", ex=10, px=6000) + + with pytest.raises(exceptions.DataError): + await r.hgetex("b", ex=10) + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_no_expiration(r): + await r.delete("test:hash") + + # # set items from mapping without expiration + assert await r.hsetex("test:hash", None, None, mapping={"1": 1, "4": b"four"}) == 1 + assert await r.httl("test:hash", "foo", "1", "4") == [-2, -1, -1] + assert await r.hgetex("test:hash", "foo", "1") == [None, b"1"] + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_expiration_ex_and_keepttl(r): + await r.delete("test:hash") + + # set items from key/value provided + # combined with mapping and items with expiration - testing ex field + assert ( + await r.hsetex( + "test:hash", + "foo", + "bar", + mapping={"1": 1, "2": "2"}, + items=["i1", 11, "i2", 22], + ex=10, + ) + == 1 + ) + test_keys = ["foo", "1", "2", "i1", "i2"] + ttls = await r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 10 + + assert await r.hgetex("test:hash", *test_keys) == [ + b"bar", + b"1", + b"2", + b"11", + b"22", + ] + await asyncio.sleep(1.1) + # validate keepttl + assert await r.hsetex("test:hash", "foo", "bar1", keepttl=True) == 1 + assert 0 < (await r.httl("test:hash", "foo"))[0] < 10 + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_expiration_px(r): + await r.delete("test:hash") + # set items from key/value provided and mapping + # with expiration - testing px field + assert ( + await r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, px=60000) + == 1 + ) + test_keys = ["foo", "1", "2"] + ttls = await r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 60 + + assert await r.hgetex("test:hash", *test_keys) == [b"bar", b"1", b"2"] + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_expiration_pxat_and_fnx(r): + await r.delete("test:hash") + assert ( + await r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, ex=30) + == 1 + ) + + expire_at = await redis_server_time(r) + timedelta(minutes=1) + assert ( + await r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"new": "ok"}, + pxat=expire_at, + data_persist_option=HashDataPersistOptions.FNX, + ) + == 0 + ) + ttls = await r.httl("test:hash", "foo", "new") + assert ttls[0] <= 30 + assert ttls[1] == -2 + + assert await r.hgetex("test:hash", "foo", "1", "new") == [b"bar", b"1", None] + assert ( + await r.hsetex( + "test:hash", + "foo_new", + "bar1", + mapping={"new": "ok"}, + pxat=expire_at, + data_persist_option=HashDataPersistOptions.FNX, + ) + == 1 + ) + ttls = await r.httl("test:hash", "foo", "new") + for ttl in ttls: + assert ttl <= 61 + assert await r.hgetex("test:hash", "foo", "foo_new", "new") == [ + b"bar", + b"bar1", + b"ok", + ] + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_expiration_exat_and_fxx(r): + await r.delete("test:hash") + assert ( + await r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, ex=30) + == 1 + ) + + expire_at = await redis_server_time(r) + timedelta(seconds=10) + assert ( + await r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"new": "ok"}, + exat=expire_at, + data_persist_option=HashDataPersistOptions.FXX, + ) + == 0 + ) + ttls = await r.httl("test:hash", "foo", "new") + assert 10 < ttls[0] <= 30 + assert ttls[1] == -2 + + assert await r.hgetex("test:hash", "foo", "1", "new") == [b"bar", b"1", None] + assert ( + await r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"1": "new_value"}, + exat=expire_at, + data_persist_option=HashDataPersistOptions.FXX, + ) + == 1 + ) + assert await r.hgetex("test:hash", "foo", "1") == [b"bar1", b"new_value"] + + +@skip_if_server_version_lt("7.9.0") +async def test_hsetex_invalid_inputs(r): + with pytest.raises(exceptions.DataError): + await r.hsetex("b", "foo", "bar", ex=10.0) + + with pytest.raises(exceptions.DataError): + await r.hsetex("b", None, None) + + with pytest.raises(exceptions.DataError): + await r.hsetex("b", "foo", "bar", items=["i1", 11, "i2"], px=6000) + + with pytest.raises(exceptions.DataError): + await r.hsetex("b", "foo", "bar", ex=10, keepttl=True) diff --git a/tests/test_asyncio/test_utils.py b/tests/test_asyncio/test_utils.py new file mode 100644 index 0000000000..05cad1bfaf --- /dev/null +++ b/tests/test_asyncio/test_utils.py @@ -0,0 +1,8 @@ +from datetime import datetime +import redis + + +async def redis_server_time(client: redis.Redis): + seconds, milliseconds = await client.time() + timestamp = float(f"{seconds}.{milliseconds}") + return datetime.fromtimestamp(timestamp) diff --git a/tests/test_commands.py b/tests/test_commands.py index 5c72a019ba..8758efa771 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -21,6 +21,7 @@ from redis.commands.json.path import Path from redis.commands.search.field import TextField from redis.commands.search.query import Query +from tests.test_utils import redis_server_time from .conftest import ( _get_client, @@ -50,12 +51,6 @@ def cleanup(): r.config_set("slowlog-max-len", 128) -def redis_server_time(client): - seconds, milliseconds = client.time() - timestamp = float(f"{seconds}.{milliseconds}") - return datetime.datetime.fromtimestamp(timestamp) - - def get_stream_message(client, stream, message_id): "Fetch a stream message and format it as a (message_id, fields) pair" response = client.xrange(stream, min=message_id, max=message_id) @@ -3393,13 +3388,8 @@ def test_hmget(self, r): assert r.hmget("a", "a", "b", "c") == [b"1", b"2", b"3"] def test_hmset(self, r): - redis_class = type(r).__name__ - warning_message = ( - r"^{0}\.hmset\(\) is deprecated\. " - r"Use {0}\.hset\(\) instead\.$".format(redis_class) - ) h = {b"a": b"1", b"b": b"2", b"c": b"3"} - with pytest.warns(DeprecationWarning, match=warning_message): + with pytest.warns(DeprecationWarning): assert r.hmset("a", h) assert r.hgetall("a") == h diff --git a/tests/test_hash.py b/tests/test_hash.py index 0422185865..c2a92fb852 100644 --- a/tests/test_hash.py +++ b/tests/test_hash.py @@ -3,7 +3,10 @@ from datetime import datetime, timedelta import pytest +from redis import exceptions +from redis.commands.core import HashDataPersistOptions from tests.conftest import skip_if_server_version_lt +from tests.test_utils import redis_server_time @skip_if_server_version_lt("7.3.240") @@ -368,3 +371,247 @@ def test_hpttl_multiple_fields_mixed_conditions(r): def test_hpttl_nonexistent_key(r): r.delete("test:hash") assert r.hpttl("test:hash", "field1", "field2", "field3") == [-2, -2, -2] + + +@skip_if_server_version_lt("7.9.0") +def test_hgetdel(r): + r.delete("test:hash") + r.hset("test:hash", "foo", "bar", mapping={"1": 1, "2": 2}) + assert r.hgetdel("test:hash", "foo", "1") == [b"bar", b"1"] + assert r.hget("test:hash", "foo") is None + assert r.hget("test:hash", "1") is None + assert r.hget("test:hash", "2") == b"2" + assert r.hgetdel("test:hash", "foo", "1") == [None, None] + assert r.hget("test:hash", "2") == b"2" + + with pytest.raises(exceptions.DataError): + r.hgetdel("test:hash") + + +@skip_if_server_version_lt("7.9.0") +def test_hgetex_no_expiration(r): + r.delete("test:hash") + r.hset("b", "foo", "bar", mapping={"1": 1, "2": 2, "3": "three", "4": b"four"}) + + assert r.hgetex("b", "foo", "1", "4") == [b"bar", b"1", b"four"] + assert r.httl("b", "foo", "1", "4") == [-1, -1, -1] + + +@skip_if_server_version_lt("7.9.0") +def test_hgetex_expiration_configs(r): + r.delete("test:hash") + r.hset("test:hash", "foo", "bar", mapping={"1": 1, "3": "three", "4": b"four"}) + test_keys = ["foo", "1", "4"] + + # test get with multiple fields with expiration set through 'ex' + assert r.hgetex("test:hash", *test_keys, ex=10) == [b"bar", b"1", b"four"] + ttls = r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 10 + + # test get with multiple fields removing expiration settings with 'persist' + assert r.hgetex("test:hash", *test_keys, persist=True) == [ + b"bar", + b"1", + b"four", + ] + assert r.httl("test:hash", *test_keys) == [-1, -1, -1] + + # test get with multiple fields with expiration set through 'px' + assert r.hgetex("test:hash", *test_keys, px=6000) == [b"bar", b"1", b"four"] + ttls = r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 6 + + # test get single field with expiration set through 'pxat' + expire_at = redis_server_time(r) + timedelta(minutes=1) + assert r.hgetex("test:hash", "foo", pxat=expire_at) == [b"bar"] + assert r.httl("test:hash", "foo")[0] <= 61 + + # test get single field with expiration set through 'exat' + expire_at = redis_server_time(r) + timedelta(seconds=10) + assert r.hgetex("test:hash", "foo", exat=expire_at) == [b"bar"] + assert r.httl("test:hash", "foo")[0] <= 10 + + +@skip_if_server_version_lt("7.9.0") +def test_hgetex_validate_expired_fields_removed(r): + r.delete("test:hash") + r.hset("test:hash", "foo", "bar", mapping={"1": 1, "3": "three", "4": b"four"}) + + test_keys = ["foo", "1", "3"] + # test get multiple fields with expiration set + # validate that expired fields are removed + assert r.hgetex("test:hash", *test_keys, ex=1) == [b"bar", b"1", b"three"] + time.sleep(1.1) + assert r.hgetex("test:hash", *test_keys) == [None, None, None] + assert r.httl("test:hash", *test_keys) == [-2, -2, -2] + assert r.hgetex("test:hash", "4") == [b"four"] + + +@skip_if_server_version_lt("7.9.0") +def test_hgetex_invalid_inputs(r): + with pytest.raises(exceptions.DataError): + r.hgetex("b", "foo", "1", "3", ex=10, persist=True) + + with pytest.raises(exceptions.DataError): + r.hgetex("b", "foo", ex=10.0, persist=True) + + with pytest.raises(exceptions.DataError): + r.hgetex("b", "foo", ex=10, px=6000) + + with pytest.raises(exceptions.DataError): + r.hgetex("b", ex=10) + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_no_expiration(r): + r.delete("test:hash") + + # # set items from mapping without expiration + assert r.hsetex("test:hash", None, None, mapping={"1": 1, "4": b"four"}) == 1 + assert r.httl("test:hash", "foo", "1", "4") == [-2, -1, -1] + assert r.hgetex("test:hash", "foo", "1") == [None, b"1"] + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_expiration_ex_and_keepttl(r): + r.delete("test:hash") + + # set items from key/value provided + # combined with mapping and items with expiration - testing ex field + assert ( + r.hsetex( + "test:hash", + "foo", + "bar", + mapping={"1": 1, "2": "2"}, + items=["i1", 11, "i2", 22], + ex=10, + ) + == 1 + ) + ttls = r.httl("test:hash", "foo", "1", "2", "i1", "i2") + for ttl in ttls: + assert pytest.approx(ttl) == 10 + + assert r.hgetex("test:hash", "foo", "1", "2", "i1", "i2") == [ + b"bar", + b"1", + b"2", + b"11", + b"22", + ] + time.sleep(1.1) + # validate keepttl + assert r.hsetex("test:hash", "foo", "bar1", keepttl=True) == 1 + assert r.httl("test:hash", "foo")[0] < 10 + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_expiration_px(r): + r.delete("test:hash") + # set items from key/value provided and mapping + # with expiration - testing px field + assert ( + r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, px=60000) == 1 + ) + test_keys = ["foo", "1", "2"] + ttls = r.httl("test:hash", *test_keys) + for ttl in ttls: + assert pytest.approx(ttl) == 60 + assert r.hgetex("test:hash", *test_keys) == [b"bar", b"1", b"2"] + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_expiration_pxat_and_fnx(r): + r.delete("test:hash") + assert r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, ex=30) == 1 + + expire_at = redis_server_time(r) + timedelta(minutes=1) + assert ( + r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"new": "ok"}, + pxat=expire_at, + data_persist_option=HashDataPersistOptions.FNX, + ) + == 0 + ) + ttls = r.httl("test:hash", "foo", "new") + assert ttls[0] <= 30 + assert ttls[1] == -2 + + assert r.hgetex("test:hash", "foo", "1", "new") == [b"bar", b"1", None] + assert ( + r.hsetex( + "test:hash", + "foo_new", + "bar1", + mapping={"new": "ok"}, + pxat=expire_at, + data_persist_option=HashDataPersistOptions.FNX, + ) + == 1 + ) + ttls = r.httl("test:hash", "foo", "new") + for ttl in ttls: + assert ttl <= 61 + assert r.hgetex("test:hash", "foo", "foo_new", "new") == [ + b"bar", + b"bar1", + b"ok", + ] + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_expiration_exat_and_fxx(r): + r.delete("test:hash") + assert r.hsetex("test:hash", "foo", "bar", mapping={"1": 1, "2": "2"}, ex=30) == 1 + + expire_at = redis_server_time(r) + timedelta(seconds=10) + assert ( + r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"new": "ok"}, + exat=expire_at, + data_persist_option=HashDataPersistOptions.FXX, + ) + == 0 + ) + ttls = r.httl("test:hash", "foo", "new") + assert 10 < ttls[0] <= 30 + assert ttls[1] == -2 + + assert r.hgetex("test:hash", "foo", "1", "new") == [b"bar", b"1", None] + assert ( + r.hsetex( + "test:hash", + "foo", + "bar1", + mapping={"1": "new_value"}, + exat=expire_at, + data_persist_option=HashDataPersistOptions.FXX, + ) + == 1 + ) + assert r.hgetex("test:hash", "foo", "1") == [b"bar1", b"new_value"] + + +@skip_if_server_version_lt("7.9.0") +def test_hsetex_invalid_inputs(r): + with pytest.raises(exceptions.DataError): + r.hsetex("b", "foo", "bar", ex=10.0) + + with pytest.raises(exceptions.DataError): + r.hsetex("b", None, None) + + with pytest.raises(exceptions.DataError): + r.hsetex("b", "foo", "bar", items=["i1", 11, "i2"], px=6000) + + with pytest.raises(exceptions.DataError): + r.hsetex("b", "foo", "bar", ex=10, keepttl=True) diff --git a/tests/test_utils.py b/tests/test_utils.py index 764ef5d0a9..75de8dbb9f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,4 @@ +from datetime import datetime import pytest from redis.utils import compare_versions @@ -25,3 +26,9 @@ ) def test_compare_versions(version1, version2, expected_res): assert compare_versions(version1, version2) == expected_res + + +def redis_server_time(client): + seconds, milliseconds = client.time() + timestamp = float(f"{seconds}.{milliseconds}") + return datetime.fromtimestamp(timestamp)