diff --git a/benchmarks/rwlock.py b/benchmarks/rwlock.py new file mode 100644 index 0000000000..499d021545 --- /dev/null +++ b/benchmarks/rwlock.py @@ -0,0 +1,415 @@ +"""Simulation of readers and writers sharing ownership of a lock.""" + +import itertools +import json +import threading +import time +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from random import Random +from typing import Literal +from typing import Optional +from typing import Self +from typing import TypeAlias +from typing import Union + +import click +import pandas as pd + +from redis.client import Redis +from redis.exceptions import LockMaxWritersError +from redis.rwlock import RwLock + + +def _now() -> datetime: + return datetime.now() + + +AcquireStatus: TypeAlias = Union[ + Literal['success'], + Literal['timeout'], + Literal['aborted'], +] + + +@dataclass +class Parameters: + duration: float + seed: int + num_workers: int + wr_ratio: float + io_duration: float + max_writers: int + + def to_dict(self) -> dict: + return { + 'duration': self.duration, + 'seed': self.seed, + 'num_workers': self.num_workers, + 'wr_ratio': self.wr_ratio, + 'io_duration': self.io_duration, + 'max_writers': self.max_writers, + } + + def to_string(self) -> str: + return ', '.join([ + f'duration = {self.duration}', + f'seed = {self.seed}', + f'num_workers = {self.num_workers}', + f'wr_ratio = {self.wr_ratio}', + f'io_duration = {self.io_duration}', + f'max_writers = {self.max_writers}', + ]) + + +@dataclass +class InvocationMetric: + timestamp: Optional[datetime] = None + read_acquire_time: Optional[float] = None + read_acquire_status: Optional[AcquireStatus] = None + read_release_time: Optional[float] = None + write_acquire_time: Optional[float] = None + write_acquire_status: Optional[AcquireStatus] = None + write_release_time: Optional[float] = None + + +@dataclass +class TimeSeriesMetric: + timestamp: datetime + readers: int + writers: int + locked: int + + @staticmethod + def collect(lock: RwLock) -> Self: + return TimeSeriesMetric( + timestamp=_now(), + readers=int(lock.redis.zcard(lock._reader_lock_name())), + writers=int(lock.redis.zcard(lock._writer_semaphore_name())), + locked=int(lock.redis.get('locked') or 0), + ) + + def to_row(self) -> tuple[str, int, int, int]: + return (self.timestamp.isoformat(), self.readers, self.writers, self.locked) + + +class Worker: + # Keys used: + # - locked + # - readers + # - writers + # - total_reads + # - total_writes + + lock: RwLock + random: Random + wr_ratio: float + io_duration: float + metrics: list[InvocationMetric] + series: list[TimeSeriesMetric] + + def __init__( + self, + lock: RwLock, + params: Parameters, + ) -> None: + self.lock = lock + self.random = Random(params.seed or None) + self.wr_ratio = params.wr_ratio + self.io_duration = params.io_duration + self.metrics = [] + self.series = [] + + @property + def redis(self) -> Redis: + return self.lock.redis + + def rand_io_time(self) -> float: + if not self.io_duration: + return 0 + mean = self.io_duration + std = mean + shape = mean**2 / std**2 + scale = std**2 / mean + return self.random.gammavariate(shape, scale) + + def wait_for_io(self) -> None: + if self.io_duration: + time.sleep(self.rand_io_time()) + + def do_write(self, metric: InvocationMetric) -> None: + write_guard = self.lock.write() + + acquire_start = time.perf_counter() + try: + acquired = write_guard.acquire() + except LockMaxWritersError: + # Hit max workers; abort gracefully + metric.write_acquire_status = 'aborted' + return + metric.write_acquire_time = time.perf_counter() - acquire_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + if not acquired: + metric.write_acquire_status = 'timeout' + return + + self.redis.set('locked', 1) + metric.write_acquire_status = 'success' + assert int(self.redis.get('readers') or 0) == 0 + + try: + self.wait_for_io() + self.redis.incr('total_reads') + finally: + self.redis.set('locked', 0) + release_start = time.perf_counter() + write_guard.release() + metric.write_release_time = time.perf_counter() - release_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + + def do_read(self, metric: InvocationMetric) -> None: + read_guard = self.lock.read() + + acquire_start = time.perf_counter() + acquired = read_guard.acquire() + metric.read_acquire_time = time.perf_counter() - acquire_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + if not acquired: + metric.read_acquire_status = 'timeout' + return + metric.read_acquire_status = 'success' + + assert not int(self.redis.get('locked') or 0) + + self.redis.incr('readers') + + try: + self.wait_for_io() + self.redis.incr('total_writes') + finally: + self.redis.decr('readers') + release_start = time.perf_counter() + read_guard.release() + metric.read_release_time = time.perf_counter() - release_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + + def work(self) -> None: + metric = InvocationMetric() + + if self.random.random() < self.wr_ratio: + self.do_write(metric) + else: + self.do_read(metric) + + metric.timestamp = _now() + self.metrics.append(metric) + + def loop(self, stop_at: float) -> None: + while time.time() < stop_at: + self.work() + + +def plot_series(path: str): + """Import and run this inside a notebook to visualize time series.""" + import matplotlib.pyplot as plt + + data = json.load(open(path)) + + # Detect sweep parameters + values: defaultdict[str, set[Union[float, int]]] = defaultdict(lambda: set()) + params: dict[str, Union[float, int]] + for ts in data: + params = ts['params'] + for k, v in params.items(): + values[k].add(v) + sweep_params = [k for k, v in values.items() if len(v) > 1] + + for ts in data: + params = ts['params'] + df = pd.DataFrame( + ts['data'], + columns=['timestamp', 'readers', 'writers', 'locked'], + ) + df['timestamp'] = pd.to_datetime(df['timestamp'], format='ISO8601') + df.sort_values('timestamp', inplace=True) + fig, ax = plt.subplots(figsize=(10, 4)) + ax.plot(df['timestamp'], df['readers'], label='readers') + ax.plot(df['timestamp'], df['writers'], label='writers') + ax.plot(df['timestamp'], df['locked'], label='locked') + title = ', '.join([ + f'{k}={params[k]}' + for k in sweep_params + ]) + ax.set_title(title) + ax.set_xlabel('Time') + ax.set_ylabel('Count') + ax.legend() + ax.grid(alpha=0.3) + plt.show() + + +def display_metrics(invocation_metrics: list[InvocationMetric]) -> None: + inv_df = pd.DataFrame.from_records([ + { + 'timestamp': metric.timestamp.isoformat() if metric.timestamp else None, + 'read_acquire_time': metric.read_acquire_time, + 'read_release_time': metric.read_release_time, + 'write_acquire_time': metric.write_acquire_time, + 'write_release_time': metric.write_release_time, + 'read_acquire_status': metric.read_acquire_status, + 'write_acquire_status': metric.write_acquire_status, + } + for metric in invocation_metrics + ]) + metric_columns = [ + 'read_acquire_time', + 'read_release_time', + 'write_acquire_time', + 'write_release_time', + ] + + stats_df = pd.DataFrame(index=metric_columns) + inv_df[metric_columns] = inv_df[metric_columns].apply(pd.to_numeric, errors='coerce') + stats_df['min'] = inv_df[metric_columns].min() + stats_df['mean'] = inv_df[metric_columns].mean() + stats_df['p95'] = inv_df[metric_columns].quantile(0.95) + stats_df['max'] = inv_df[metric_columns].max() + + cols = ('read_acquire_status', 'write_acquire_status') + percentages = {} + for col in cols: + mask = inv_df[col].notna() + percentages[col] = inv_df[mask][col].value_counts() + status_df = pd.DataFrame(percentages).T.fillna(0) + status_df = status_df.reindex(columns=['success', 'timeout', 'aborted'], fill_value=0) + + print(stats_df.to_string(float_format=lambda x: f'{1e3 * x:.2f}ms')) + print(status_df.to_string(float_format=lambda x: f'{x:.0f}')) + print() + + +@dataclass +class Output: + params: Parameters + total_reads: int + total_writes: int + invocation_metrics: list[InvocationMetric] + time_series_metrics: list[TimeSeriesMetric] + + +def run(redis: Redis, params: Parameters) -> Output: + lock = RwLock( + redis=redis, + prefix='lock', + timeout=10, + sleep=max(params.io_duration, 1e-3), + blocking_timeout=1, + max_writers=params.max_writers, + ) + + stop_at = time.time() + params.duration + + # Spawn workers + workers = [Worker(lock, params) for _ in range(params.num_workers)] + threads = [ + threading.Thread(target=worker.loop, args=(stop_at,), daemon=True) for worker in workers + ] + for thread in threads: + thread.start() + + # Gather series metrics + time_series = [] + while time.time() < stop_at: + time_series.append(TimeSeriesMetric.collect(lock)) + time.sleep(0.01) + + # Wait for workers + for thread in threads: + thread.join() + + # Aggregate metrics + for worker in workers: + time_series.extend(worker.series) + invocation_metrics = [metric for worker in workers for metric in worker.metrics] + + total_reads = int(redis.get('total_reads') or 0) + total_writes = int(redis.get('total_writes') or 0) + + return Output( + params=params, + total_reads=total_reads, + total_writes=total_writes, + invocation_metrics=invocation_metrics, + time_series_metrics=time_series, + ) + + +@click.command() +@click.option('--duration', type=str, default='5') +@click.option('--seed', type=str, default='0') +@click.option('--num-workers', type=str, default='1') +@click.option('--wr-ratio', type=str, default='0.1') +@click.option('--io-duration', type=str, default='0.025') +@click.option('--max-writers', type=str, default='0') +def benchmark( + duration: str, + seed: str, + num_workers: str, + wr_ratio: str, + io_duration: str, + max_writers: str, +): + def parse_int(arg: str) -> list[int]: + return list([int(x) for x in arg.split(',')]) + + def parse_float(arg: str) -> list[float]: + return list([float(x) for x in arg.split(',')]) + + duration = parse_int(duration) + seed = parse_int(seed) + num_workers = parse_int(num_workers) + wr_ratio = parse_float(wr_ratio) + io_duration = parse_float(io_duration) + max_writers = parse_int(max_writers) + + time_series = [] + + for opts in itertools.product( + duration, + seed, + num_workers, + wr_ratio, + io_duration, + max_writers, + ): + redis = Redis(db=11) + redis.flushdb() + + params = Parameters(*opts) + output = run(redis, params) + + time_series.append({ + 'params': output.params.to_dict(), + 'data': [ts.to_row() for ts in output.time_series_metrics], + }) + + # Print stats + reads = output.total_reads + writes = output.total_writes + print(params.to_string()) + print(f'iops: {(reads + writes) / params.duration:.2f}') + + # Display metrics + display_metrics(output.invocation_metrics) + + out_dir = Path('.cache') + out_dir.mkdir(exist_ok=True) + out_path = out_dir / 'rwlock.json' + out_path = open(out_path, 'w') + json.dump(time_series, out_path) + + +if __name__ == '__main__': + benchmark() diff --git a/dev_requirements.txt b/dev_requirements.txt index 848d6207c4..cdfd5ffed9 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,8 +1,10 @@ build click==8.0.4 invoke==2.2.0 +matplotlib mock packaging>=20.4 +pandas pytest pytest-asyncio>=0.23.0 pytest-cov diff --git a/docs/index.rst b/docs/index.rst index 2c0557cbbe..af2909cddc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -69,6 +69,7 @@ Module Documentation exceptions backoff lock + rwlock retry lua_scripting opentelemetry diff --git a/docs/rwlock.rst b/docs/rwlock.rst new file mode 100644 index 0000000000..a35153ef72 --- /dev/null +++ b/docs/rwlock.rst @@ -0,0 +1,5 @@ +Reader-writer lock +######### + +.. automodule:: redis.rwlock + :members: \ No newline at end of file diff --git a/redis/exceptions.py b/redis/exceptions.py index 643444986b..3d66feba11 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -94,6 +94,12 @@ class LockNotOwnedError(LockError): pass +class LockMaxWritersError(LockError): + "Error trying to acquire a lock that has reached the max number of writers." + + pass + + class ChildDeadlockedError(Exception): "Error indicating that a child process is deadlocked after a fork()" diff --git a/redis/rwlock.py b/redis/rwlock.py new file mode 100644 index 0000000000..757682e831 --- /dev/null +++ b/redis/rwlock.py @@ -0,0 +1,534 @@ +from __future__ import annotations + +import abc +import time as mod_time +from types import TracebackType +from typing import TYPE_CHECKING +from typing import Any +from typing import Optional +from typing import Self +from typing import Type +from uuid import uuid4 + +from typing_extensions import override + +from redis.exceptions import LockError +from redis.exceptions import LockMaxWritersError +from redis.exceptions import LockNotOwnedError +from redis.typing import Number + +if TYPE_CHECKING: + from redis import Redis + + +# RwLock implementation +# ===================== +# +# The lock owns three keys: +# - `:write`: If this exists, holds the token of the user that +# owns the exclusive write lock. +# - `:write_semaphore`: Semaphore tracking writers that are +# waiting to acquire the lock. If any writers are waiting, readers +# block. +# - `:read`: Another semaphore, tracks readers. +# +# Semaphores are implemented as ordered sets, where score is the +# expiration time of the semaphore lease (Redis instance time). Expired +# leases are pruned before attempting to acquire the lock. +# +# We can't use built-in key expiration because individual set members +# cannot have an expiration. We can't create keys dynamically because +# this may break multi-node compatibility. +# +# The write-acquire script is careful to ensure that the writer waiting +# semaphore is only held if the caller is actually blocking; otherwise +# the caller adds contention for no reason. + + +class RwLock: + """A shared reader-writer lock. + + Unlike a standard mutex, a reader-writer lock allows multiple + readers to concurrently access the locked resource. However, writers + will wait for exclusive access to the locked resource. Writers get + priority when waiting on the lock so that readers do not starve + waiting writers. Writers are allowed to starve readers, however. + + This type of unfair lock is effective for scenarios where reads are + frequent and writes are infrequent. Because this lock relies on busy + waiting, it can be wasteful to use if your critical sections are + long and frequent. + + This lock is not fault-tolerant in a multi-node Redis setup. When a + master fails and data is lost, writer exclusivity may be violated. + In a single-node setup, the lock is sound. + + This lock is not re-entrant; attempting to acquire it twice in the + same thread may cause a deadlock until the blocking timeout ends. + """ + + lua_acquire_reader = None + lua_acquire_writer = None + lua_release_writer = None + lua_reacquire_reader = None + lua_reacquire_writer = None + + # KEYS[1] - writer lock name + # KEYS[2] - writer semaphore name + # KEYS[3] - reader lock name + # ARGV[1] - token + # ARGV[2] - expiration + # return 1 if the lock was acquired, otherwise 0 + LUA_ACQUIRE_READER_SCRIPT = """ + local token = ARGV[1] + + local timespec = redis.call('time') + local time = timespec[1] + 1e-6 * timespec[2] + + redis.call('zremrangebyscore', KEYS[2], 0, time) + redis.call('zremrangebyscore', KEYS[3], 0, time) + + local locked = redis.call('exists', KEYS[1]) > 0 or + redis.call('zcard', KEYS[2]) > 0 + if locked then + return 0 + end + + local expiry = time + ARGV[2] + redis.call('zadd', KEYS[3], expiry, token) + + return 1 + """ + + # KEYS[1] - writer lock name + # KEYS[2] - writer semaphore name + # KEYS[3] - reader lock name + # ARGV[1] - token + # ARGV[2] - expiration + # ARGV[3] - sempahore expiration (or 0 to release the sempahore) + # ARGV[4] - max writers + # + # NOTE: return codes: + # - 0: Lock was acquired + # - 1: Blocked + # - 2: Didn't block + LUA_ACQUIRE_WRITER_SCRIPT = """ + local writer_key = KEYS[1] + local semaphore_key = KEYS[2] + local reader_key = KEYS[3] + local token = ARGV[1] + local expiration = tonumber(ARGV[2]) + local semaphore_expiration = tonumber(ARGV[3]) + local max_writers = tonumber(ARGV[4]) + + local timespec = redis.call('time') + local time = timespec[1] + 1e-6 * timespec[2] + redis.call('zremrangebyscore', KEYS[2], 0, time) + redis.call('zremrangebyscore', KEYS[3], 0, time) + + local read_locked = redis.call('zcard', reader_key) > 0 + local write_locked = redis.call('exists', writer_key) > 0 + if read_locked or write_locked then + local op = 'create' + if semaphore_expiration == 0 then + op = 'delete' + elseif max_writers > 0 then + local count = redis.call('zcard', semaphore_key) + if write_locked then + count = count + 1 + end + if count >= max_writers then + op = 'update' + end + end + + local blocked + if op == 'update' then + blocked = redis.call('zadd', semaphore_key, 'XX', 'CH', time + semaphore_expiration, token) > 0 + elseif op == 'create' then + redis.call('zadd', semaphore_key, time + semaphore_expiration, token) + blocked = true + else + redis.call('zrem', semaphore_key, token) + blocked = false + end + + if blocked then + return 1 + else + return 2 + end + end + + redis.call('zrem', semaphore_key, token) + redis.call('set', writer_key, token, 'PX', math.floor(1000 * expiration)) + + return 0 + """ + + # KEYS[1] - writer lock name + # ARGV[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_WRITER_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - reader lock name + # ARGV[1] - token + # ARGV[2] - expiration + # return 1 if the lock was reacquired, otherwise 0 + LUA_REACQUIRE_READER_SCRIPT = """ + local token = ARGV[1] + + local timespec = redis.call('time') + local time = timespec[1] + 1e-6 * timespec[2] + local expiry = time + ARGV[2] + return redis.call('zadd', KEYS[1], 'XX', 'CH', expiry, token) > 0 + """ + + # KEYS[1] - writer lock name + # ARGV[1] - token + # ARGV[2] - expiration + # return 1 if the lock was reacquired, otherwise 0 + LUA_REACQUIRE_WRITER_SCRIPT = """ + local token = ARGV[1] + local value = redis.call('get', KEYS[1]) + if value == token then + redis.call('pexpire', KEYS[1], math.floor(1000 * ARGV[2])) + return 1 + else + return 0 + end + """ + + def __init__( + self, + redis: Redis, + prefix: str, + timeout: Number, + sleep: Number = 0.1, + blocking: bool = True, + blocking_timeout: Optional[Number] = None, + max_writers: Optional[Number] = None, + ) -> None: + """Construct a new lock. + + The `RwLock` object only identifies a lock; use ``read`` or + ``write`` to construct a guard that can be waited on to acquire + the lock. + + ``prefix``: The prefix to use for keys created by this lock. All + keys beginning with this prefix should be treated as *reserved* + by the lock. + + ``timeout``: The expiration on leases held by readers and + writers. This prevents deadlocks but may lead to unsynchronized + access if the timeout elapses while a process believes it is + still holding the lock. Use `reacquire()` to refresh the lease + when holding the lock for long periods of time. + + ``sleep``: The time in seconds to wait between attempts to + acquire the lock while spinning. + + ``blocking``: If `True`, then attempting to acquire the lock + will block the current thread if the lock is contended. + + ``blocking_timeout``: If ``blocking`` is `True`, the maximum + amount of time to wait for the lock to be released when + contended. If `None`, blocks forever. + + ``max_writers``: If set to a positive number, the maximum number + of writers that may contend the lock. For example, if set to 1, + then only one writer may hold or wait on the lock at a time. Any + additional writers will fail to acquire the lock and raise + ``LockMaxWritersError``. + """ + self.redis = redis + self.prefix = prefix + if timeout <= 1e-3: + raise ValueError('Timeout must be at least 1ms') + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.max_writers = max_writers or 0 + self._register_scripts() + + def _register_scripts(self) -> None: + cls = self.__class__ + client = self.redis + if cls.lua_acquire_reader is None: + cls.lua_acquire_reader = client.register_script(cls.LUA_ACQUIRE_READER_SCRIPT) + if cls.lua_acquire_writer is None: + cls.lua_acquire_writer = client.register_script(cls.LUA_ACQUIRE_WRITER_SCRIPT) + if cls.lua_release_writer is None: + cls.lua_release_writer = client.register_script(cls.LUA_RELEASE_WRITER_SCRIPT) + if cls.lua_reacquire_reader is None: + cls.lua_reacquire_reader = client.register_script(cls.LUA_REACQUIRE_READER_SCRIPT) + if cls.lua_reacquire_writer is None: + cls.lua_reacquire_writer = client.register_script(cls.LUA_REACQUIRE_WRITER_SCRIPT) + + def _reader_lock_name(self) -> str: + return f'{self.prefix}:read' + + def _writer_lock_name(self) -> str: + return f'{self.prefix}:write' + + def _writer_semaphore_name(self) -> str: + return f'{self.prefix}:write_semaphore' + + def _make_token(self) -> Any: + token = self.redis.get_encoder().encode(uuid4().hex) + return token + + def read( + self, + timeout: Optional[Number] = None, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + ) -> ReadLockGuard: + """Construct a guard that can be used to acquire the lock in + shared write mode. + + See ``RwLock`` for documentation on parameters. + """ + return ReadLockGuard( + lock=self, + token=self._make_token(), + timeout=timeout if timeout is not None else self.timeout, + sleep=sleep if sleep is not None else self.sleep, + blocking=blocking if blocking is not None else self.blocking, + blocking_timeout=blocking_timeout if blocking_timeout is not None else self.blocking_timeout, + ) + + def write( + self, + timeout: Optional[Number] = None, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + ) -> WriteLockGuard: + """Construct a guard that can be used to acquire the lock in + exclusive write mode. + + See ``RwLock`` for documentation on other parameters. + """ + return WriteLockGuard( + lock=self, + token=self._make_token(), + timeout=timeout if timeout is not None else self.timeout, + sleep=sleep if sleep is not None else self.sleep, + blocking=blocking if blocking is not None else self.blocking, + blocking_timeout=blocking_timeout if blocking_timeout is not None else self.blocking_timeout, + ) + + def is_write_locked(self) -> bool: + """Returns `True` if the lock is currently held by any writer.""" + return bool(self.redis.exists(self._writer_lock_name())) + + +class BaseLockGuard(abc.ABC): + lock: RwLock + + token: Any + + timeout: Number + sleep: Number + blocking: bool + blocking_timeout: Optional[Number] + + def __init__( + self, + lock: RwLock, + token: Any, + timeout: Number, + sleep: Number, + blocking: bool, + blocking_timeout: Optional[Number], + ) -> None: + self.lock = lock + self.token = token + if timeout <= 1e-3: + raise ValueError('Timeout must be at least 1ms') + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + + def __enter__(self) -> Self: + if not self.acquire(): + raise LockError( + "Unable to acquire lock within the time specified", + lock_name=self.lock.prefix, + ) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.release() + + @property + def redis(self) -> Redis: + return self.lock.redis + + @abc.abstractmethod + def _acquire(self, block_readers: bool) -> bool: + ... + + @abc.abstractmethod + def _release(self) -> bool: + ... + + @abc.abstractmethod + def _reacquire(self, timeout: Number) -> bool: + ... + + def acquire( + self, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + ) -> bool: + """Attempts to acquire the lock. + + See ``RwLock`` for documentation on parameters. + """ + if sleep is None: + sleep = self.sleep + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if blocking_timeout is not None: + stop_trying_at = mod_time.monotonic() + blocking_timeout + while True: + next_try_at = mod_time.monotonic() + sleep + stop_trying = not blocking or ( + stop_trying_at is not None and next_try_at > stop_trying_at + ) + if self._acquire(should_block=not stop_trying): + return True + if stop_trying: + return False + mod_time.sleep(sleep) + + def release(self) -> None: + """Releases the lease on the lock. + + Throws ``LockNotOwnedError`` if the lock is no longer owned at + the time of release. + """ + if not self._release(): + raise LockNotOwnedError( + "Cannot release a lock that's no longer owned", + lock_name=self.lock.prefix, + ) + + def reacquire(self, timeout: Optional[Number] = None) -> bool: + """Resets the TTL of an already acquired lease back to a timeout + value. + + When holding the lock for a long amount of time, call this periodically + to ensure the lease does not expire. + """ + if timeout is None: + timeout = self.timeout + if not self._reacquire(timeout): + raise LockNotOwnedError( + "Cannot reacquire a lock that's no longer owned", + lock_name=self.lock.prefix, + ) + return True + + +class ReadLockGuard(BaseLockGuard): + """A lock guard that will acquire a shared read-mode lease on a + lock. + """ + + @override + def _acquire(self, should_block: bool) -> bool: + return bool(self.lock.lua_acquire_reader( + keys=[ + self.lock._writer_lock_name(), + self.lock._writer_semaphore_name(), + self.lock._reader_lock_name(), + ], + args=[self.token, self.timeout], + client=self.redis, + )) + + @override + def _release(self) -> bool: + return self.redis.zrem(self.lock._reader_lock_name(), self.token) + + @override + def _reacquire(self, timeout: Number) -> bool: + result = self.lock.lua_reacquire_reader( + keys=[self.lock._reader_lock_name()], + args=[self.token, timeout], + client=self.redis, + ) + return bool(result) + + +class WriteLockGuard(BaseLockGuard): + """A lock guard that will acquire an exclusive write-mode lease on a + lock. + """ + + @property + def _semaphore_timeout(self) -> Optional[Number]: + # Block readers just long enough to get through a sleep cycle + return 1.1 * self.sleep + + @override + def _acquire(self, should_block: bool): + code = self.lock.lua_acquire_writer( + keys=[ + self.lock._writer_lock_name(), + self.lock._writer_semaphore_name(), + self.lock._reader_lock_name(), + ], + args=[ + self.token, + # Lock timeout + self.timeout, + # Writer semaphore timeout + self._semaphore_timeout if should_block else 0, + # Max writers + self.lock.max_writers, + ], + client=self.redis, + ) + if should_block and code == 2: + raise LockMaxWritersError + else: + return code == 0 + + @override + def _release(self) -> bool: + return bool(self.lock.lua_release_writer( + keys=[self.lock._writer_lock_name()], + args=[self.token], + client=self.redis, + )) + + @override + def _reacquire(self, timeout: Number) -> bool: + return bool(self.lock.lua_reacquire_writer( + keys=[self.lock._writer_lock_name()], + args=[self.token, timeout], + client=self.redis, + )) diff --git a/tests/test_rwlock.py b/tests/test_rwlock.py new file mode 100644 index 0000000000..d4739f92d0 --- /dev/null +++ b/tests/test_rwlock.py @@ -0,0 +1,282 @@ +import threading +import time as mod_time + +import pytest + +from redis.client import Redis +from redis.exceptions import LockError +from redis.exceptions import LockMaxWritersError +from redis.exceptions import LockNotOwnedError +from redis.rwlock import RwLock + + +class TestLock: + def test_write_lock(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard = lock.write() + assert guard.acquire(blocking=False) + assert r.get('foo:write') == guard.token + assert r.ttl('foo:write') == 10 + guard.release() + assert r.get('foo:write') is None + + def test_read_lock(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard = lock.read() + assert guard.acquire(blocking=False) + score = r.zmscore('foo:read', [guard.token])[0] + expected = mod_time.time() + 10 + assert 0.999 * expected < score and score < expected + guard.release() + assert r.zcard('foo:read') == 0 + + def test_multiple_readers(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard1 = lock.read() + guard2 = lock.read() + assert guard1.acquire() + assert guard2.acquire() + assert r.zcard('foo:read') == 2 + guard1.release() + assert r.zcard('foo:read') == 1 + guard2.release() + assert r.zcard('foo:read') == 0 + + def test_mutual_exclusion(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard1 = lock.write() + guard2 = lock.write() + assert guard1.acquire() + assert not guard2.acquire(blocking=False) + guard1.release() + assert guard2.acquire(blocking=False) + guard2.release() + + def test_writer_reader_exclusion(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + wguard = lock.write() + rguard = lock.read() + assert wguard.acquire() + assert not rguard.acquire(blocking=False) + wguard.release() + assert rguard.acquire(blocking=False) + rguard.release() + + def test_reader_writer_exclusion(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + wguard = lock.write() + rguard = lock.read() + assert rguard.acquire() + assert not wguard.acquire(blocking=False) + rguard.release() + assert wguard.acquire(blocking=False) + wguard.release() + + def test_context_manager(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + with lock.write() as guard: + assert r.get('foo:write') == guard.token + assert r.get('foo:write') is None + with lock.read(): + assert r.zcard('foo:read') == 1 + assert r.zcard('foo:read') == 0 + + def test_reader_blocking_timeout(self, r: Redis): + lock = RwLock( + r, + 'foo', + timeout=10, + sleep=0.01, + blocking=True, + blocking_timeout=0.1, + ) + with lock.write(): + start = mod_time.time() + assert not lock.read().acquire() + end = mod_time.time() + assert end - start > lock.blocking_timeout - lock.sleep + + def test_writer_blocking_timeout(self, r: Redis): + lock = RwLock( + r, + 'foo', + timeout=10, + blocking=True, + sleep=0.01, + blocking_timeout=0.1, + ) + with lock.read(): + start = mod_time.time() + assert not lock.write().acquire() + end = mod_time.time() + assert end - start > lock.blocking_timeout - lock.sleep + + def test_writer_waiting_blocks_readers(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + r.zadd('foo:write_semaphore', {'token': mod_time.time() + 10}) + assert not lock.read().acquire(blocking=False) + + def test_read_timeout_then_write(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.03) + rguard = lock.read() + wguard = lock.write() + assert rguard.acquire(blocking=False) + assert r.zcard('foo:read') == 1 + assert not wguard.acquire(blocking=False) + mod_time.sleep(0.03) + assert wguard.acquire(blocking=False) + assert r.zcard('foo:read') == 0 + + def test_read_timeout_then_read(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.03) + guard1 = lock.read() + assert guard1.acquire(blocking=False) + assert r.zcard('foo:read') == 1 + mod_time.sleep(0.03) + guard2 = lock.read() + assert guard2.acquire(blocking=False) + assert r.zcard('foo:read') == 1 + + def test_write_waiting_then_read(self, r: Redis): + lock = RwLock(r, 'foo', timeout=1) + r.zadd('foo:write_semaphore', {'token': mod_time.time()}) + assert lock.read().acquire(blocking=False) + + def test_write_waiting_then_write(self, r: Redis): + lock = RwLock(r, 'foo', timeout=1) + r.zadd('foo:write_semaphore', {'token': mod_time.time()}) + assert lock.write().acquire(blocking=False) + + def test_write_reacquire(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.write() + guard.acquire() + mod_time.sleep(0.01) + ttl = r.pttl('foo:write') + assert 89 <= ttl and ttl <= 91 + guard.reacquire() + ttl = r.pttl('foo:write') + assert ttl >= 99 + + def test_write_reacquire_failure(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard1 = lock.write() + guard1.acquire() + r.delete('foo:write') + with pytest.raises(LockNotOwnedError): + guard1.reacquire() + guard2 = lock.write() + guard2.acquire() + with pytest.raises(LockNotOwnedError): + guard1.reacquire() + + def test_read_reacquire(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.read() + guard.acquire() + + mod_time.sleep(10e-3) + expiry = r.zmscore('foo:read', [guard.token])[0] + now = mod_time.time() + assert 89e-3 <= expiry - now and expiry - now < 91e-3 + + guard.reacquire() + expiry = r.zmscore('foo:read', [guard.token])[0] + now = mod_time.time() + assert 99e-3 <= expiry - now and expiry - now < 100e-3 + + def test_read_release_without_acquire_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard = lock.read() + with pytest.raises(LockNotOwnedError): + guard.release() + + def test_read_release_no_longer_owned_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.read() + assert guard.acquire(blocking=False) + r.zrem(lock._reader_lock_name(), guard.token) + with pytest.raises(LockNotOwnedError): + guard.release() + + def test_write_release_without_acquire_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard = lock.write() + with pytest.raises(LockNotOwnedError): + guard.release() + + def test_write_release_no_longer_owned_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.write() + assert guard.acquire(blocking=False) + r.set(lock._writer_lock_name(), b'other-token') + with pytest.raises(LockNotOwnedError): + guard.release() + + def test_read_reacquire_after_expiration(self, r: Redis): + """Can reacquire a read lock after expiration of the key still""" + lock = RwLock(r, 'foo', timeout=0.01) + guard = lock.read() + assert guard.acquire(blocking=False) + mod_time.sleep(0.01) + assert guard.reacquire() + + def test_read_reacquire_no_longer_owned_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.read() + assert guard.acquire(blocking=False) + r.zrem(lock._reader_lock_name(), guard.token) + with pytest.raises(LockNotOwnedError): + guard.reacquire() + + def test_read_context_manager_blocking_timeout(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1, sleep=0.01) + wguard = lock.write() + assert wguard.acquire(blocking=False) + with pytest.raises(LockError) as excinfo: + with lock.read(blocking_timeout=0.05, sleep=0.01): + pass + assert excinfo.value.lock_name == 'foo' + wguard.release() + + def test_write_context_manager_blocking_timeout(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1, sleep=0.01) + rguard = lock.read() + assert rguard.acquire(blocking=False) + with pytest.raises(LockError): + with lock.write(blocking_timeout=0.05, sleep=0.01): + pass + rguard.release() + + def test_unique_writer(self, r: Redis): + lock = RwLock(r, 'foo', timeout=1, max_writers=1) + guard1 = lock.write() + assert guard1.acquire() + guard2 = lock.write() + with pytest.raises(LockMaxWritersError): + guard2.acquire() + + def test_max_writers(self, r: Redis): + lock = RwLock(r, 'foo', timeout=1, blocking_timeout=50e-3, sleep=10e-3, max_writers=2) + guard1 = lock.write() + assert guard1.acquire() + result = [] + + # Spawn a thread to wait on the lock + def target(): + with lock.write(): + result.append(1) + thread = threading.Thread(target=target) + thread.start() + + # Third writer fails + guard2 = lock.write() + with pytest.raises(LockMaxWritersError): + guard2.acquire() + + guard1.release() + thread.join() + + # Thread acquired after guard1 was released + assert len(result) == 1 + assert not lock.is_write_locked()