Skip to content

Commit e846aab

Browse files
committed
[Async] Multi Exec on cluster
1 parent 0cc7e85 commit e846aab

File tree

2 files changed

+290
-37
lines changed

2 files changed

+290
-37
lines changed

redis/asyncio/cluster.py

Lines changed: 85 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import random
44
import socket
55
import threading
6+
import time
67
import warnings
78
from abc import ABC, abstractmethod
89
from copy import copy
@@ -19,7 +20,9 @@
1920
Tuple,
2021
Type,
2122
TypeVar,
22-
Union, Set,
23+
Union,
24+
Set,
25+
Coroutine,
2326
)
2427

2528
from redis._parsers import AsyncCommandsParser, Encoder
@@ -65,7 +68,11 @@
6568
ResponseError,
6669
SlotNotCoveredError,
6770
TimeoutError,
68-
TryAgainError, CrossSlotTransactionError, WatchError, ExecAbortError, InvalidPipelineStack,
71+
TryAgainError,
72+
CrossSlotTransactionError,
73+
WatchError,
74+
ExecAbortError,
75+
InvalidPipelineStack,
6976
)
7077
from redis.typing import AnyKeyT, EncodableT, KeyT
7178
from redis.utils import (
@@ -947,6 +954,30 @@ def lock(
947954
raise_on_release_error=raise_on_release_error,
948955
)
949956

957+
async def transaction(
958+
self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs
959+
):
960+
"""
961+
Convenience method for executing the callable `func` as a transaction
962+
while watching all keys specified in `watches`. The 'func' callable
963+
should expect a single argument which is a Pipeline object.
964+
"""
965+
shard_hint = kwargs.pop("shard_hint", None)
966+
value_from_callable = kwargs.pop("value_from_callable", False)
967+
watch_delay = kwargs.pop("watch_delay", None)
968+
async with self.pipeline(True, shard_hint) as pipe:
969+
while True:
970+
try:
971+
if watches:
972+
await pipe.watch(*watches)
973+
func_value = await func(pipe)
974+
exec_value = await pipe.execute()
975+
return func_value if value_from_callable else exec_value
976+
except WatchError:
977+
if watch_delay is not None and watch_delay > 0:
978+
time.sleep(watch_delay)
979+
continue
980+
950981

951982
class ClusterNode:
952983
"""
@@ -1508,17 +1539,17 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
15081539
| Existing :class:`~.RedisCluster` client
15091540
"""
15101541

1511-
__slots__ = ("_command_stack", "cluster_client")
1542+
__slots__ = ("cluster_client",)
15121543

15131544
def __init__(
1514-
self,
1515-
client: RedisCluster,
1516-
transaction: Optional[bool] = None
1545+
self, client: RedisCluster, transaction: Optional[bool] = None
15171546
) -> None:
15181547
self.cluster_client = client
15191548
self._transaction = transaction
15201549
self._execution_strategy: ExecutionStrategy = (
1521-
PipelineStrategy(self) if not self._transaction else TransactionStrategy(self)
1550+
PipelineStrategy(self)
1551+
if not self._transaction
1552+
else TransactionStrategy(self)
15221553
)
15231554

15241555
async def initialize(self) -> "ClusterPipeline":
@@ -1585,7 +1616,9 @@ async def execute(
15851616
can't be mapped to a slot
15861617
"""
15871618
try:
1588-
return await self._execution_strategy.execute(raise_on_error, allow_redirections)
1619+
return await self._execution_strategy.execute(
1620+
raise_on_error, allow_redirections
1621+
)
15891622
finally:
15901623
await self.reset()
15911624

@@ -1628,7 +1661,6 @@ async def unlink(self, *names):
16281661
def mset_nonatomic(
16291662
self, mapping: Mapping[AnyKeyT, EncodableT]
16301663
) -> "ClusterPipeline":
1631-
16321664
return self._execution_strategy.mset_nonatomic(mapping)
16331665

16341666

@@ -1663,7 +1695,7 @@ async def initialize(self) -> "ClusterPipeline":
16631695

16641696
@abstractmethod
16651697
def execute_command(
1666-
self, *args: Union[KeyT, EncodableT], **kwargs: Any
1698+
self, *args: Union[KeyT, EncodableT], **kwargs: Any
16671699
) -> "ClusterPipeline":
16681700
"""
16691701
Append a raw command to the pipeline.
@@ -1748,7 +1780,6 @@ async def unlink(self, *names):
17481780

17491781

17501782
class AbstractStrategy(ExecutionStrategy):
1751-
17521783
def __init__(self, pipe: ClusterPipeline) -> None:
17531784
self._pipe: ClusterPipeline = pipe
17541785
self._command_queue: List["PipelineCommand"] = []
@@ -1779,7 +1810,9 @@ async def initialize(self) -> "ClusterPipeline":
17791810
self._command_queue = []
17801811
return self._pipe
17811812

1782-
def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "ClusterPipeline":
1813+
def execute_command(
1814+
self, *args: Union[KeyT, EncodableT], **kwargs: Any
1815+
) -> "ClusterPipeline":
17831816
self._command_queue.append(
17841817
PipelineCommand(len(self._command_queue), *args, **kwargs)
17851818
)
@@ -1797,11 +1830,15 @@ def _annotate_exception(self, exception, number, command):
17971830
exception.args = (msg,) + exception.args[1:]
17981831

17991832
@abstractmethod
1800-
def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> "ClusterPipeline":
1833+
def mset_nonatomic(
1834+
self, mapping: Mapping[AnyKeyT, EncodableT]
1835+
) -> "ClusterPipeline":
18011836
pass
18021837

18031838
@abstractmethod
1804-
async def execute(self, raise_on_error: bool = True, allow_redirections: bool = True) -> List[Any]:
1839+
async def execute(
1840+
self, raise_on_error: bool = True, allow_redirections: bool = True
1841+
) -> List[Any]:
18051842
pass
18061843

18071844
@abstractmethod
@@ -1828,11 +1865,14 @@ async def discard(self):
18281865
async def unlink(self, *names):
18291866
pass
18301867

1868+
18311869
class PipelineStrategy(AbstractStrategy):
18321870
def __init__(self, pipe: ClusterPipeline) -> None:
18331871
super().__init__(pipe)
18341872

1835-
def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> "ClusterPipeline":
1873+
def mset_nonatomic(
1874+
self, mapping: Mapping[AnyKeyT, EncodableT]
1875+
) -> "ClusterPipeline":
18361876
encoder = self._pipe.cluster_client.encoder
18371877

18381878
slots_pairs = {}
@@ -1845,7 +1885,9 @@ def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> "ClusterPipel
18451885

18461886
return self._pipe
18471887

1848-
async def execute(self, raise_on_error: bool = True, allow_redirections: bool = True) -> List[Any]:
1888+
async def execute(
1889+
self, raise_on_error: bool = True, allow_redirections: bool = True
1890+
) -> List[Any]:
18491891
if not self._command_queue:
18501892
return []
18511893

@@ -1993,6 +2035,7 @@ async def unlink(self, *names):
19932035

19942036
return self.execute_command("UNLINK", names[0])
19952037

2038+
19962039
class TransactionStrategy(AbstractStrategy):
19972040
NO_SLOTS_COMMANDS = {"UNWATCH"}
19982041
IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"}
@@ -2018,7 +2061,9 @@ def __init__(self, pipe: ClusterPipeline) -> None:
20182061
RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS
20192062
)
20202063

2021-
def _get_client_and_connection_for_transaction(self) -> Tuple[ClusterNode, Connection]:
2064+
def _get_client_and_connection_for_transaction(
2065+
self,
2066+
) -> Tuple[ClusterNode, Connection]:
20222067
"""
20232068
Find a connection for a pipeline transaction.
20242069
@@ -2065,7 +2110,9 @@ def runner():
20652110

20662111
return response
20672112

2068-
async def _execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> Any:
2113+
async def _execute_command(
2114+
self, *args: Union[KeyT, EncodableT], **kwargs: Any
2115+
) -> Any:
20692116
if self._pipe.cluster_client._initialize:
20702117
await self._pipe.cluster_client.initialize()
20712118

@@ -2074,7 +2121,7 @@ async def _execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any)
20742121
slot_number = await self._pipe.cluster_client._determine_slot(*args)
20752122

20762123
if (
2077-
self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
2124+
self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS
20782125
) and not self._explicit_transaction:
20792126
if args[0] == "WATCH":
20802127
self._validate_watch()
@@ -2118,7 +2165,12 @@ async def _get_connection_and_send_command(self, *args, **options):
21182165
)
21192166

21202167
async def _send_command_parse_response(
2121-
self, connection: Connection, redis_node: ClusterNode, command_name, *args, **options
2168+
self,
2169+
connection: Connection,
2170+
redis_node: ClusterNode,
2171+
command_name,
2172+
*args,
2173+
**options,
21222174
):
21232175
"""
21242176
Send a command and parse the response
@@ -2145,8 +2197,10 @@ async def _reinitialize_on_error(self, error):
21452197

21462198
self._pipe.cluster_client.reinitialize_counter += 1
21472199
if (
2148-
self._pipe.cluster_client.reinitialize_steps
2149-
and self._pipe.cluster_client.reinitialize_counter % self._pipe.cluster_client.reinitialize_steps == 0
2200+
self._pipe.cluster_client.reinitialize_steps
2201+
and self._pipe.cluster_client.reinitialize_counter
2202+
% self._pipe.cluster_client.reinitialize_steps
2203+
== 0
21502204
):
21512205
await self._pipe.cluster_client.nodes_manager.initialize()
21522206
self.reinitialize_counter = 0
@@ -2164,10 +2218,14 @@ def _raise_first_error(self, responses, stack):
21642218
self._annotate_exception(r, cmd.position + 1, cmd.args)
21652219
raise r
21662220

2167-
def mset_nonatomic(self, mapping: Mapping[AnyKeyT, EncodableT]) -> "ClusterPipeline":
2168-
raise NotImplementedError('Method is not supported in transactional context.')
2221+
def mset_nonatomic(
2222+
self, mapping: Mapping[AnyKeyT, EncodableT]
2223+
) -> "ClusterPipeline":
2224+
raise NotImplementedError("Method is not supported in transactional context.")
21692225

2170-
async def execute(self, raise_on_error: bool = True, allow_redirections: bool = True) -> List[Any]:
2226+
async def execute(
2227+
self, raise_on_error: bool = True, allow_redirections: bool = True
2228+
) -> List[Any]:
21712229
stack = self._command_queue
21722230
if not stack and (not self._watching or not self._pipeline_slots):
21732231
return []
@@ -2197,7 +2255,7 @@ async def _execute_transaction(
21972255
stack = chain(
21982256
[PipelineCommand(0, "MULTI")],
21992257
stack,
2200-
[PipelineCommand(0, 'EXEC')],
2258+
[PipelineCommand(0, "EXEC")],
22012259
)
22022260
commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs]
22032261
packed_commands = connection.pack_commands(commands)
@@ -2334,4 +2392,4 @@ async def discard(self):
23342392
await self.reset()
23352393

23362394
async def unlink(self, *names):
2337-
return self.execute_command("UNLINK", *names)
2395+
return self.execute_command("UNLINK", *names)

0 commit comments

Comments
 (0)