From 1492bb1251a6d1ead58e926c81d4390109d7eb60 Mon Sep 17 00:00:00 2001 From: Matthew McAllister Date: Sat, 27 Sep 2025 11:03:44 -0700 Subject: [PATCH 1/8] rwlock: Add RwLock implementation and tests --- redis/rwlock.py | 514 +++++++++++++++++++++++++++++++++++++++++++ tests/test_rwlock.py | 247 +++++++++++++++++++++ 2 files changed, 761 insertions(+) create mode 100644 redis/rwlock.py create mode 100644 tests/test_rwlock.py diff --git a/redis/rwlock.py b/redis/rwlock.py new file mode 100644 index 0000000000..c8f35e36ce --- /dev/null +++ b/redis/rwlock.py @@ -0,0 +1,514 @@ +from __future__ import annotations + +import abc +import time as mod_time +from types import TracebackType +from typing import TYPE_CHECKING +from typing import Any +from typing import Optional +from typing import Self +from typing import Type +from uuid import uuid1 + +from typing_extensions import override + +from redis.exceptions import LockError +from redis.exceptions import LockNotOwnedError +from redis.typing import Number + +if TYPE_CHECKING: + from redis import Redis + + +class CannotBlock(Exception): + pass + + +class RwLock: + """A shared reader-writer lock. + + Unlike a standard mutex, a reader-writer lock allows multiple + readers to concurrently access the locked resource. However, writers + will wait for exclusive access to the locked resource. Writers get + priority when waiting on the lock so that readers do not starve + waiting writers. Writers are allowed to starve readers, however. + + This type of lock is effective for scenarios where reads are + frequent and writes are infrequent. Because this lock relies on busy + waiting, it can be wasteful to use if your critical sections are + long and frequent. + + This lock is not fault-tolerant in a multi-node Redis setup. When a + master fails and data is lost, writer exclusivity may be violated. + In a single-node setup, the lock is sound. 
+ """ + + lua_acquire_reader = None + lua_acquire_writer = None + lua_release_writer = None + lua_reacquire_reader = None + lua_reacquire_writer = None + + # KEYS[1] - writer lock name + # KEYS[2] - writer semaphore name + # KEYS[3] - reader lock name + # ARGV[1] - token + # ARGV[2] - expiration + # return 1 if the lock was acquired, otherwise 0 + LUA_ACQUIRE_READER_SCRIPT = """ + local token = ARGV[1] + + local timespec = redis.call('time') + local time = timespec[1] + 1e-6 * timespec[2] + + redis.call('zremrangebyscore', KEYS[2], 0, time) + redis.call('zremrangebyscore', KEYS[3], 0, time) + + local locked = redis.call('exists', KEYS[1]) > 0 or + redis.call('zcard', KEYS[2]) > 0 + if locked then + return 0 + end + + local expiry = time + ARGV[2] + redis.call('zadd', KEYS[3], expiry, token) + + return 1 + """ + + # KEYS[1] - writer lock name + # KEYS[2] - writer semaphore name + # KEYS[3] - reader lock name + # ARGV[1] - token + # ARGV[2] - expiration + # ARGV[3] - semaphore expiration + # ARGV[4] - max writers + # + # NOTE: return codes: + # - 0: Lock was acquired + # - 1: Blocked + # - 2: Didn't block + LUA_ACQUIRE_WRITER_SCRIPT = """ + local writer_key = KEYS[1] + local semaphore_key = KEYS[2] + local reader_key = KEYS[3] + local token = ARGV[1] + local expiration = tonumber(ARGV[2]) + local semaphore_expiration = tonumber(ARGV[3]) + local max_writers = tonumber(ARGV[4]) + + local timespec = redis.call('time') + local time = timespec[1] + 1e-6 * timespec[2] + redis.call('zremrangebyscore', KEYS[2], 0, time) + redis.call('zremrangebyscore', KEYS[3], 0, time) + + local read_locked = redis.call('zcard', reader_key) > 0 + local write_locked = redis.call('exists', writer_key) > 0 + if read_locked or write_locked then + local op = 'create' + if semaphore_expiration == 0 then + op = 'delete' + elseif max_writers > 0 then + local count = redis.call('zcard', semaphore_key) + if write_locked then + count = count + 1 + end + if count == max_writers then + op = 'update' 
end + end + + local blocked + if op == 'update' then + blocked = redis.call('zadd', semaphore_key, 'XX', time + semaphore_expiration, token) > 0 + elseif op == 'create' then + redis.call('zadd', semaphore_key, time + semaphore_expiration, token) + blocked = true + else + redis.call('zrem', semaphore_key, token) + blocked = false + end + + if blocked then + return 1 + else + return 2 + end + end + + redis.call('zrem', semaphore_key, token) + redis.call('set', writer_key, token, 'PX', math.floor(1000 * expiration)) + + return 0 + """ + + # KEYS[1] - writer lock name + # ARGV[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_WRITER_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - reader lock name + # ARGV[1] - token + # ARGV[2] - expiration + # return 1 if the lock was reacquired, otherwise 0 + LUA_REACQUIRE_READER_SCRIPT = """ + local token = ARGV[1] + + local score = redis.call('zmscore', KEYS[1], token) + if not score[1] then + return 0 + end + + local timespec = redis.call('time') + local time = timespec[1] + 1e-6 * timespec[2] + local expiry = time + ARGV[2] + redis.call('zadd', KEYS[1], expiry, token) + return 1 + """ + + # KEYS[1] - writer lock name + # ARGV[1] - token + # ARGV[2] - expiration + # return 1 if the lock was reacquired, otherwise 0 + LUA_REACQUIRE_WRITER_SCRIPT = """ + local token = ARGV[1] + local value = redis.call('get', KEYS[1]) + if value == token then + redis.call('pexpire', KEYS[1], math.floor(1000 * ARGV[2])) + return 1 + else + return 0 + end + """ + + def __init__( + self, + redis: Redis, + prefix: str, + timeout: Number, + sleep: Number = 0.1, + blocking: bool = True, + blocking_timeout: Optional[Number] = None, + max_writers: Optional[Number] = None, + ) -> None: + """Construct a new lock. 
+ + The `RwLock` object only identifies a lock; use ``read`` or + ``write`` to construct a guard that can be waited on to acquire + the lock. + + ``prefix``: The prefix to use for keys created by this lock. All + keys beginning with this prefix should be treated as *reserved* + by the lock. + + ``timeout``: The expiration on leases held by readers and + writers. This prevents deadlocks but may lead to unsynchronized + access if the timeout elapses while a process believes it is + still holding the lock. Use `reacquire()` to refresh the lease + when holding the lock for long periods of time. + + ``sleep``: The time in seconds to wait between attempts to + acquire the lock while spinning. + + ``blocking``: If `True`, then attempting to acquire the lock + will block the current thread if the lock is contended. + + ``blocking_timeout``: If ``blocking`` is `True`, the maximum + amount of time to wait for the lock to be released when + contended. If `None`, blocks forever. + + ``max_writers``: If set to a positive number, the maximum number + of writers that may contend the lock. For example, if set to 1, + then only one writer may hold or wait on the lock at a time. Any + additional writers will fail to acquire the lock. 
+ """ + self.redis = redis + self.prefix = prefix + if timeout <= 1e-3: + raise ValueError('Timeout must be at least 1ms') + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.max_writers = max_writers or 0 + self._register_scripts() + + def _register_scripts(self) -> None: + cls = self.__class__ + client = self.redis + if cls.lua_acquire_reader is None: + cls.lua_acquire_reader = client.register_script(cls.LUA_ACQUIRE_READER_SCRIPT) + if cls.lua_acquire_writer is None: + cls.lua_acquire_writer = client.register_script(cls.LUA_ACQUIRE_WRITER_SCRIPT) + if cls.lua_release_writer is None: + cls.lua_release_writer = client.register_script(cls.LUA_RELEASE_WRITER_SCRIPT) + if cls.lua_reacquire_reader is None: + cls.lua_reacquire_reader = client.register_script(cls.LUA_REACQUIRE_READER_SCRIPT) + if cls.lua_reacquire_writer is None: + cls.lua_reacquire_writer = client.register_script(cls.LUA_REACQUIRE_WRITER_SCRIPT) + + def _reader_lock_name(self) -> str: + return f'{self.prefix}:read' + + def _writer_lock_name(self) -> str: + return f'{self.prefix}:write' + + def _writer_semaphore_name(self) -> str: + return f'{self.prefix}:write_semaphore' + + def _make_token(self) -> Any: + token = self.redis.get_encoder().encode(uuid1().hex) + return token + + def read( + self, + timeout: Optional[Number] = None, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + ) -> ReadLockGuard: + """Construct a guard that can be used to acquire the lock in + shared read mode. + + See ``RwLock`` for documentation on parameters. 
+ """ + return ReadLockGuard( + lock=self, + token=self._make_token(), + timeout=timeout if timeout is not None else self.timeout, + sleep=sleep if sleep is not None else self.sleep, + blocking=blocking if blocking is not None else self.blocking, + blocking_timeout=blocking_timeout if blocking_timeout is not None else self.blocking_timeout, + ) + + def write( + self, + timeout: Optional[Number] = None, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + ) -> WriteLockGuard: + """Construct a guard that can be used to acquire the lock in + exclusive write mode. + + See ``RwLock`` for documentation on other parameters. + """ + return WriteLockGuard( + lock=self, + token=self._make_token(), + timeout=timeout if timeout is not None else self.timeout, + sleep=sleep if sleep is not None else self.sleep, + blocking=blocking if blocking is not None else self.blocking, + blocking_timeout=blocking_timeout if blocking_timeout is not None else self.blocking_timeout, + ) + + +class BaseLockGuard(abc.ABC): + lock: RwLock + + token: Any + + timeout: Number + sleep: Number + blocking: bool + blocking_timeout: Optional[Number] + + def __init__( + self, + lock: RwLock, + token: Any, + timeout: Number, + sleep: Number, + blocking: bool, + blocking_timeout: Optional[Number], + ) -> None: + self.lock = lock + self.token = token + if timeout <= 1e-3: + raise ValueError('Timeout must be at least 1ms') + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + + def __enter__(self) -> Self: + if not self.acquire(): + raise LockError( + "Unable to acquire lock within the time specified", + lock_name=self.lock.prefix, + ) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.release() + + @property + def redis(self) -> Redis: + return 
self.lock.redis + + @abc.abstractmethod + def _acquire(self, block_readers: bool) -> bool: + ... + + @abc.abstractmethod + def _release(self) -> bool: + ... + + @abc.abstractmethod + def _reacquire(self, timeout: Number) -> bool: + ... + + def acquire( + self, + sleep: Optional[Number] = None, + blocking: Optional[bool] = None, + blocking_timeout: Optional[Number] = None, + ) -> bool: + """Attempts to acquire the lock. + + See ``RwLock`` for documentation on parameters. + """ + if sleep is None: + sleep = self.sleep + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if blocking_timeout is not None: + stop_trying_at = mod_time.monotonic() + blocking_timeout + while True: + next_try_at = mod_time.monotonic() + sleep + stop_trying = not blocking or ( + stop_trying_at is not None and next_try_at > stop_trying_at + ) + try: + if self._acquire(should_block=not stop_trying): + return True + except CannotBlock: + return False + if stop_trying: + return False + mod_time.sleep(sleep) + + def release(self) -> None: + """Releases the lease on the lock. + + Throws ``LockNotOwnedError`` if the lock is no longer owned at + the time of release. + """ + if not self._release(): + raise LockNotOwnedError( + "Cannot release a lock that's no longer owned", + lock_name=self.lock.prefix, + ) + + def reacquire(self, timeout: Optional[Number] = None) -> bool: + """Resets the TTL of an already acquired lease back to a timeout + value. + + When holding the lock for a long amount of time, call this periodically + to ensure the lease does not expire. + """ + if timeout is None: + timeout = self.timeout + if not self._reacquire(timeout): + raise LockNotOwnedError( + "Cannot reacquire a lock that's no longer owned", + lock_name=self.lock.prefix, + ) + return True + + +class ReadLockGuard(BaseLockGuard): + """A lock guard that will acquire a shared read-mode lease on a + lock. 
+ """ + + @override + def _acquire(self, should_block: bool) -> bool: + return bool(self.lock.lua_acquire_reader( + keys=[ + self.lock._writer_lock_name(), + self.lock._writer_semaphore_name(), + self.lock._reader_lock_name(), + ], + args=[self.token, self.timeout], + client=self.redis, + )) + + @override + def _release(self) -> bool: + return self.redis.zrem(self.lock._reader_lock_name(), self.token) + + @override + def _reacquire(self, timeout: Number) -> bool: + result = self.lock.lua_reacquire_reader( + keys=[self.lock._reader_lock_name()], + args=[self.token, timeout], + client=self.redis, + ) + return bool(result) + + +class WriteLockGuard(BaseLockGuard): + """A lock guard that will acquire an exclusive write-mode lease on a + lock. + """ + + @property + def _semaphore_timeout(self) -> Optional[Number]: + # Block readers just long enough to get through a sleep cycle + return 1.1 * self.sleep + + @override + def _acquire(self, should_block: bool): + code = self.lock.lua_acquire_writer( + keys=[ + self.lock._writer_lock_name(), + self.lock._writer_semaphore_name(), + self.lock._reader_lock_name(), + ], + args=[ + self.token, + # Lock timeout + self.timeout, + # Writer semaphore timeout + self._semaphore_timeout if should_block else 0, + # Max writers + self.lock.max_writers, + ], + client=self.redis, + ) + if should_block and code == 2: + raise CannotBlock + else: + return code == 0 + + @override + def _release(self) -> bool: + return bool(self.lock.lua_release_writer( + keys=[self.lock._writer_lock_name()], + args=[self.token], + client=self.redis, + )) + + @override + def _reacquire(self, timeout: Number) -> bool: + return bool(self.lock.lua_reacquire_writer( + keys=[self.lock._writer_lock_name()], + args=[self.token, timeout], + client=self.redis, + )) diff --git a/tests/test_rwlock.py b/tests/test_rwlock.py new file mode 100644 index 0000000000..da6ce51b85 --- /dev/null +++ b/tests/test_rwlock.py @@ -0,0 +1,247 @@ +import time as mod_time + +import pytest + 
+from redis.client import Redis +from redis.exceptions import LockError +from redis.exceptions import LockNotOwnedError +from redis.rwlock import RwLock + + +class TestLock: + def test_write_lock(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard = lock.write() + assert guard.acquire(blocking=False) + assert r.get('foo:write') == guard.token + assert r.ttl('foo:write') == 10 + guard.release() + assert r.get('foo:write') is None + + def test_read_lock(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard = lock.read() + assert guard.acquire(blocking=False) + score = r.zmscore('foo:read', [guard.token])[0] + expected = mod_time.time() + 10 + assert 0.999 * expected < score and score < expected + guard.release() + assert r.zcard('foo:read') == 0 + + def test_multiple_readers(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard1 = lock.read() + guard2 = lock.read() + assert guard1.acquire() + assert guard2.acquire() + assert r.zcard('foo:read') == 2 + guard1.release() + assert r.zcard('foo:read') == 1 + guard2.release() + assert r.zcard('foo:read') == 0 + + def test_mutual_exclusion(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard1 = lock.write() + guard2 = lock.write() + assert guard1.acquire() + assert not guard2.acquire(blocking=False) + guard1.release() + assert guard2.acquire(blocking=False) + guard2.release() + + def test_writer_reader_exclusion(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + wguard = lock.write() + rguard = lock.read() + assert wguard.acquire() + assert not rguard.acquire(blocking=False) + wguard.release() + assert rguard.acquire(blocking=False) + rguard.release() + + def test_reader_writer_exclusion(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + wguard = lock.write() + rguard = lock.read() + assert rguard.acquire() + assert not wguard.acquire(blocking=False) + rguard.release() + assert wguard.acquire(blocking=False) + wguard.release() + + def test_context_manager(self, r: Redis): + 
lock = RwLock(r, 'foo', timeout=10) + with lock.write() as guard: + assert r.get('foo:write') == guard.token + assert r.get('foo:write') is None + with lock.read(): + assert r.zcard('foo:read') == 1 + assert r.zcard('foo:read') == 0 + + def test_reader_blocking_timeout(self, r: Redis): + lock = RwLock( + r, + 'foo', + timeout=10, + sleep=0.01, + blocking=True, + blocking_timeout=0.1, + ) + with lock.write(): + start = mod_time.time() + assert not lock.read().acquire() + end = mod_time.time() + assert end - start > lock.blocking_timeout - lock.sleep + + def test_writer_blocking_timeout(self, r: Redis): + lock = RwLock( + r, + 'foo', + timeout=10, + blocking=True, + sleep=0.01, + blocking_timeout=0.1, + ) + with lock.read(): + start = mod_time.time() + assert not lock.write().acquire() + end = mod_time.time() + assert end - start > lock.blocking_timeout - lock.sleep + + def test_writer_waiting_blocks_readers(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + r.zadd('foo:write_semaphore', {'token': mod_time.time() + 10}) + assert not lock.read().acquire(blocking=False) + + def test_read_timeout_then_write(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.03) + rguard = lock.read() + wguard = lock.write() + assert rguard.acquire(blocking=False) + assert r.zcard('foo:read') == 1 + assert not wguard.acquire(blocking=False) + mod_time.sleep(0.03) + assert wguard.acquire(blocking=False) + assert r.zcard('foo:read') == 0 + + def test_read_timeout_then_read(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.03) + guard1 = lock.read() + assert guard1.acquire(blocking=False) + assert r.zcard('foo:read') == 1 + mod_time.sleep(0.03) + guard2 = lock.read() + assert guard2.acquire(blocking=False) + assert r.zcard('foo:read') == 1 + + def test_write_waiting_then_read(self, r: Redis): + lock = RwLock(r, 'foo', timeout=1) + r.zadd('foo:write_semaphore', {'token': mod_time.time()}) + assert lock.read().acquire(blocking=False) + + def test_write_waiting_then_write(self, r: 
Redis): + lock = RwLock(r, 'foo', timeout=1) + r.zadd('foo:write_semaphore', {'token': mod_time.time()}) + assert lock.write().acquire(blocking=False) + + def test_write_reacquire(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.write() + guard.acquire() + mod_time.sleep(0.01) + ttl = r.pttl('foo:write') + assert 89 <= ttl and ttl <= 91 + guard.reacquire() + ttl = r.pttl('foo:write') + assert ttl >= 99 + + def test_write_reacquire_failure(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard1 = lock.write() + guard1.acquire() + r.delete('foo:write') + with pytest.raises(LockNotOwnedError): + guard1.reacquire() + guard2 = lock.write() + guard2.acquire() + with pytest.raises(LockNotOwnedError): + guard1.reacquire() + + def test_read_reacquire(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.read() + guard.acquire() + + mod_time.sleep(10e-3) + expiry = r.zmscore('foo:read', [guard.token])[0] + now = mod_time.time() + assert 89e-3 <= expiry - now and expiry - now < 91e-3 + + guard.reacquire() + expiry = r.zmscore('foo:read', [guard.token])[0] + now = mod_time.time() + assert 99e-3 <= expiry - now and expiry - now < 100e-3 + + def test_read_release_without_acquire_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard = lock.read() + with pytest.raises(LockNotOwnedError): + guard.release() + + def test_read_release_no_longer_owned_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.read() + assert guard.acquire(blocking=False) + r.zrem(lock._reader_lock_name(), guard.token) + with pytest.raises(LockNotOwnedError): + guard.release() + + def test_write_release_without_acquire_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=10) + guard = lock.write() + with pytest.raises(LockNotOwnedError): + guard.release() + + def test_write_release_no_longer_owned_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.write() + assert 
guard.acquire(blocking=False) + r.set(lock._writer_lock_name(), b'other-token') + with pytest.raises(LockNotOwnedError): + guard.release() + + def test_read_reacquire_after_expiration(self, r: Redis): + """Can reacquire a read lock after expiration of the key still""" + lock = RwLock(r, 'foo', timeout=0.01) + guard = lock.read() + assert guard.acquire(blocking=False) + mod_time.sleep(0.01) + assert guard.reacquire() + + def test_read_reacquire_no_longer_owned_raises(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1) + guard = lock.read() + assert guard.acquire(blocking=False) + r.zrem(lock._reader_lock_name(), guard.token) + with pytest.raises(LockNotOwnedError): + guard.reacquire() + + def test_read_context_manager_blocking_timeout(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1, sleep=0.01) + wguard = lock.write() + assert wguard.acquire(blocking=False) + with pytest.raises(LockError) as excinfo: + with lock.read(blocking_timeout=0.05, sleep=0.01): + pass + assert excinfo.value.lock_name == 'foo' + wguard.release() + + def test_write_context_manager_blocking_timeout(self, r: Redis): + lock = RwLock(r, 'foo', timeout=0.1, sleep=0.01) + rguard = lock.read() + assert rguard.acquire(blocking=False) + with pytest.raises(LockError): + with lock.write(blocking_timeout=0.05, sleep=0.01): + pass + rguard.release() From a9274851f5384e4afe7230c0d7cde891ed646643 Mon Sep 17 00:00:00 2001 From: Matthew McAllister Date: Wed, 1 Oct 2025 14:33:30 -0700 Subject: [PATCH 2/8] rwlock: Add benchmark, fix bugs --- benchmarks/rwlock_cache.py | 337 +++++++++++++++++++++++++++++++++++++ dev_requirements.txt | 1 + redis/exceptions.py | 6 + redis/rwlock.py | 33 ++-- tests/test_rwlock.py | 35 ++++ 5 files changed, 392 insertions(+), 20 deletions(-) create mode 100644 benchmarks/rwlock_cache.py diff --git a/benchmarks/rwlock_cache.py b/benchmarks/rwlock_cache.py new file mode 100644 index 0000000000..7177d212b7 --- /dev/null +++ b/benchmarks/rwlock_cache.py @@ -0,0 +1,337 @@ 
+"""Simulation of a pool of workers reading a cached value which +occasionally must be replaced when it expires. +""" + +import random +import threading +import time +import uuid +from dataclasses import dataclass +from datetime import datetime +from numbers import Number +from pathlib import Path +from typing import Literal +from typing import Optional +from typing import Self +from typing import TextIO +from typing import TypeAlias +from typing import Union + +import pandas as pd + +from redis.client import Redis +from redis.exceptions import LockMaxWritersError +from redis.rwlock import RwLock + + +def _now() -> datetime: + return datetime.now() + + +AcquireStatus: TypeAlias = Union[ + Literal['success'], + Literal['timeout'], + Literal['aborted'], +] + + +@dataclass +class InvocationMetric: + timestamp: Optional[datetime] = None + read_acquire_time: Optional[float] = None + read_acquire_status: Optional[AcquireStatus] = None + read_release_time: Optional[float] = None + write_acquire_time: Optional[float] = None + write_acquire_status: Optional[AcquireStatus] = None + write_release_time: Optional[float] = None + + +@dataclass +class TimeSeriesMetric: + timestamp: datetime + num_readers: int + num_waiting_writers: int + + @staticmethod + def collect(lock: RwLock) -> Self: + metric = TimeSeriesMetric( + timestamp=_now(), + num_readers=lock.redis.zcard(lock._reader_lock_name()), + num_waiting_writers=lock.redis.zcard(lock._writer_semaphore_name()), + ) + assert metric.num_waiting_writers <= 1 + return metric + + +class Worker: + # Keys used: + # - worker_invocations: Total worker invocations + # - current_key: Holds the read counter key. Value is a random + # UUID4. Expires after `ttl` seconds. + # - previous_key: Previous value of current_key. Does not expire. + # - total: Total of all increments to the read counter. Should equal + # worker_invocations at the end. 
+ + lock: RwLock + ttl: float + io_time: float + metrics: list[InvocationMetric] + series: list[TimeSeriesMetric] + + def __init__( + self, + lock: RwLock, + ttl: float, + io_time: float, + ) -> None: + self.lock = lock + self.ttl = ttl + self.io_time = io_time + self.metrics = [] + self.series = [] + + @property + def redis(self) -> Redis: + return self.lock.redis + + def rand_io_time(self) -> float: + mean = self.io_time + std = mean + shape = mean**2 / std**2 + scale = std**2 / mean + return random.gammavariate(shape, scale) + + def replace_key(self, metric: InvocationMetric) -> None: + write_guard = self.lock.write() + + # Acquire lock for writing + acquire_start = time.perf_counter() + try: + acquired = write_guard.acquire() + except LockMaxWritersError: + # Another worker has the lock; abort + metric.write_acquire_status = 'aborted' + return + metric.write_acquire_time = time.perf_counter() - acquire_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + if not acquired: + metric.write_acquire_status = 'timeout' + return + + metric.write_acquire_status = 'success' + + def release() -> None: + release_start = time.perf_counter() + write_guard.release() + metric.write_release_time = time.perf_counter() - release_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + + if self.redis.exists('current_key'): + release() + return + + # Update total with writes to previous key + previous_key: bytes = self.redis.get('previous_key') + if previous_key: + previous_value = self.redis.get(previous_key) + if previous_value: + self.redis.incrby('total', int(previous_value)) + + # Pretend to do I/O + time.sleep(self.rand_io_time()) + + # Update keys + new_key = f'cache:{uuid.uuid4().hex}' + self.redis.set('current_key', new_key, px=int(self.ttl * 1000)) + self.redis.set('previous_key', new_key) + + release() + + def work_inner(self, metric: InvocationMetric) -> None: + read_guard = self.lock.read() + + # Acquire lock for reading + acquire_start = 
time.perf_counter() + acquired = read_guard.acquire() + metric.read_acquire_time = time.perf_counter() - acquire_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + if not acquired: + metric.read_acquire_status = 'timeout' + return + metric.read_acquire_status = 'success' + + def release() -> None: + release_start = time.perf_counter() + read_guard.release() + metric.read_release_time = time.perf_counter() - release_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + + current_key = self.redis.get('current_key') + if current_key: + # Key exists; simulate I/O and bump counters + time.sleep(self.rand_io_time()) + + self.redis.incr(current_key) + self.redis.incr('worker_invocations') + + release() + else: + # Key does not exist; release lock and try to update key + release() + self.replace_key(metric) + + def work(self) -> None: + metric = InvocationMetric() + self.work_inner(metric) + metric.timestamp = _now() + self.metrics.append(metric) + + def loop(self, stop_at: float) -> None: + while time.time() < stop_at: + self.work() + + +def write_headers(csv_file: TextIO) -> None: + headers = [ + 'timestamp', + 'num_readers', + 'num_waiting_writers', + 'num_workers', + 'ttl', + ] + df = pd.DataFrame(columns=headers) + df.to_csv(csv_file, mode='w', header=True, index=False) + + +def write_time_series( + csv_file: TextIO, + n: int, + ttl: Number, + time_series: list[TimeSeriesMetric], +) -> None: + ts_records = [ + { + 'timestamp': metric.timestamp.isoformat(), + 'num_readers': metric.num_readers, + 'num_waiting_writers': metric.num_waiting_writers, + 'num_workers': n, + 'ttl': ttl, + } + for metric in time_series + ] + ts_df = pd.DataFrame(ts_records) + ts_df.to_csv(csv_file, mode='a', header=False, index=False) + + +def display_metrics( + n: int, + ttl: Number, + invocation_metrics: list[InvocationMetric], +) -> None: + inv_df = pd.DataFrame.from_records([ + { + 'timestamp': metric.timestamp.isoformat() if metric.timestamp else None, + 
'read_acquire_time': metric.read_acquire_time, + 'read_release_time': metric.read_release_time, + 'write_acquire_time': metric.write_acquire_time, + 'write_release_time': metric.write_release_time, + 'read_acquire_status': metric.read_acquire_status, + 'write_acquire_status': metric.write_acquire_status, + } + for metric in invocation_metrics + ]) + metric_columns = [ + 'read_acquire_time', + 'read_release_time', + 'write_acquire_time', + 'write_release_time', + ] + + stats_df = pd.DataFrame(index=metric_columns) + inv_df[metric_columns] = inv_df[metric_columns].apply(pd.to_numeric, errors='coerce') + stats_df['min'] = inv_df[metric_columns].min() + stats_df['mean'] = inv_df[metric_columns].mean() + stats_df['p95'] = inv_df[metric_columns].quantile(0.95) + stats_df['max'] = inv_df[metric_columns].max() + + cols = ('read_acquire_status', 'write_acquire_status') + percentages = {} + for col in cols: + mask = inv_df[col].notna() + percentages[col] = inv_df[mask][col].value_counts() + status_df = pd.DataFrame(percentages).T.fillna(0) + status_df = status_df.reindex(columns=['success', 'timeout', 'aborted'], fill_value=0) + + print(stats_df.to_string(float_format=lambda x: f'{1e3 * x:.2f}ms')) + print(status_df.to_string(float_format=lambda x: f'{x:.0f}')) + print() + + +def main() -> None: + num_workers = [1, 2, 4, 8] + ttl_values = [0.05, 0.1, 0.25, 0.5, 1] + duration = 5 + io_time = 0.025 + cache_dir = Path('.cache') + cache_dir.mkdir(exist_ok=True) + csv_path = cache_dir / 'rwlock_cache.csv' + csv_file = open(csv_path, 'w') + write_headers(csv_file) + + for n in num_workers: + for ttl in ttl_values: + redis = Redis(db=11) + redis.flushdb() + + lock = RwLock( + redis=redis, + prefix='lock', + timeout=10, + sleep=io_time, + blocking_timeout=1, + max_writers=1, + ) + + stop_at = time.time() + duration + + # Spawn workers + workers = [Worker(lock=lock, ttl=ttl, io_time=io_time) for _ in range(n)] + threads = [ + threading.Thread(target=worker.loop, args=(stop_at,), 
daemon=True) for worker in workers + ] + for thread in threads: + thread.start() + + # Gather series metrics + time_series = [] + while time.time() < stop_at: + time_series.append(TimeSeriesMetric.collect(lock)) + time.sleep(0.01) + + # Wait for workers + for thread in threads: + thread.join() + + # Verify that total == # invocations + total = int(redis.get('total') or 0) + total += int(redis.get(redis.get('previous_key')) or 0) + worker_invocations = int(redis.get('worker_invocations') or 0) + assert worker_invocations == total + + # Write time series data + for worker in workers: + time_series.extend(worker.series) + write_time_series(csv_file, n, ttl, time_series) + + # Print stats + print(f'n = {n}, ttl = {ttl}') + writes = len(redis.keys('cache:*')) + print(f'iops: {(writes + worker_invocations) / duration:.2f}') + + # Display metrics + invocation_metrics = [metric for worker in workers for metric in worker.metrics] + display_metrics(n, ttl, invocation_metrics) + + +if __name__ == '__main__': + main() diff --git a/dev_requirements.txt b/dev_requirements.txt index 848d6207c4..b9ed5eb470 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -3,6 +3,7 @@ click==8.0.4 invoke==2.2.0 mock packaging>=20.4 +pandas pytest pytest-asyncio>=0.23.0 pytest-cov diff --git a/redis/exceptions.py b/redis/exceptions.py index 643444986b..3d66feba11 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -94,6 +94,12 @@ class LockNotOwnedError(LockError): pass +class LockMaxWritersError(LockError): + "Error trying to acquire a lock that has reached the max number of writers." 
+ + pass + + class ChildDeadlockedError(Exception): "Error indicating that a child process is deadlocked after a fork()" diff --git a/redis/rwlock.py b/redis/rwlock.py index c8f35e36ce..6bb6dbcf33 100644 --- a/redis/rwlock.py +++ b/redis/rwlock.py @@ -13,6 +13,7 @@ from typing_extensions import override from redis.exceptions import LockError +from redis.exceptions import LockMaxWritersError from redis.exceptions import LockNotOwnedError from redis.typing import Number @@ -20,10 +21,6 @@ from redis import Redis -class CannotBlock(Exception): - pass - - class RwLock: """A shared reader-writer lock. @@ -113,14 +110,14 @@ class RwLock: if write_locked then count = count + 1 end - if count == max_writers then + if count >= max_writers then op = 'update' end end local blocked if op == 'update' then - blocked = redis.call('zadd', semaphore_key, 'XX', time + semaphore_expiration, token) > 0 + blocked = redis.call('zadd', semaphore_key, 'XX', 'CH', time + semaphore_expiration, token) > 0 elseif op == 'create' then redis.call('zadd', semaphore_key, time + semaphore_expiration, token) blocked = true @@ -161,16 +158,10 @@ class RwLock: LUA_REACQUIRE_READER_SCRIPT = """ local token = ARGV[1] - local score = redis.call('zmscore', KEYS[1], token) - if not score[1] then - return 0 - end - local timespec = redis.call('time') local time = timespec[1] + 1e-6 * timespec[2] local expiry = time + ARGV[2] - redis.call('zadd', KEYS[1], expiry, token) - return 1 + return redis.call('zadd', KEYS[1], 'XX', 'CH', expiry, token) > 0 """ # KEYS[1] - writer lock name @@ -227,7 +218,8 @@ def __init__( ``max_writers``: If set to a positive number, the maximum number of writers that may contend the lock. For example, if set to 1, then only one writer may hold or wait on the lock at a time. Any - additional writers will fail to acquire the lock. + additional writers will fail to acquire the lock and raise + ``LockMaxWritersError``. 
""" self.redis = redis self.prefix = prefix @@ -309,6 +301,10 @@ def write( blocking_timeout=blocking_timeout if blocking_timeout is not None else self.blocking_timeout, ) + def is_write_locked(self) -> bool: + """Returns `True` if the lock is currently held by any writer.""" + return bool(self.redis.exists(self._writer_lock_name())) + class BaseLockGuard(abc.ABC): lock: RwLock @@ -394,11 +390,8 @@ def acquire( stop_trying = not blocking or ( stop_trying_at is not None and next_try_at > stop_trying_at ) - try: - if self._acquire(should_block=not stop_trying): - return True - except CannotBlock: - return False + if self._acquire(should_block=not stop_trying): + return True if stop_trying: return False mod_time.sleep(sleep) @@ -493,7 +486,7 @@ def _acquire(self, should_block: bool): client=self.redis, ) if should_block and code == 2: - raise CannotBlock + raise LockMaxWritersError else: return code == 0 diff --git a/tests/test_rwlock.py b/tests/test_rwlock.py index da6ce51b85..d4739f92d0 100644 --- a/tests/test_rwlock.py +++ b/tests/test_rwlock.py @@ -1,9 +1,11 @@ +import threading import time as mod_time import pytest from redis.client import Redis from redis.exceptions import LockError +from redis.exceptions import LockMaxWritersError from redis.exceptions import LockNotOwnedError from redis.rwlock import RwLock @@ -245,3 +247,36 @@ def test_write_context_manager_blocking_timeout(self, r: Redis): with lock.write(blocking_timeout=0.05, sleep=0.01): pass rguard.release() + + def test_unique_writer(self, r: Redis): + lock = RwLock(r, 'foo', timeout=1, max_writers=1) + guard1 = lock.write() + assert guard1.acquire() + guard2 = lock.write() + with pytest.raises(LockMaxWritersError): + guard2.acquire() + + def test_max_writers(self, r: Redis): + lock = RwLock(r, 'foo', timeout=1, blocking_timeout=50e-3, sleep=10e-3, max_writers=2) + guard1 = lock.write() + assert guard1.acquire() + result = [] + + # Spawn a thread to wait on the lock + def target(): + with lock.write(): 
+ result.append(1) + thread = threading.Thread(target=target) + thread.start() + + # Third writer fails + guard2 = lock.write() + with pytest.raises(LockMaxWritersError): + guard2.acquire() + + guard1.release() + thread.join() + + # Thread acquired after guard1 was released + assert len(result) == 1 + assert not lock.is_write_locked() From cb66cfcf47cdd8e6388797becb3d3b795d96ce14 Mon Sep 17 00:00:00 2001 From: Matthew McAllister Date: Wed, 1 Oct 2025 18:09:25 -0700 Subject: [PATCH 3/8] rwlock: Add visualization code to benchmark Not strictly necessary to check in but anyone who wants to run the benchmark will appreciate it. --- benchmarks/rwlock_cache.py | 20 ++++++++++++++++++++ dev_requirements.txt | 1 + 2 files changed, 21 insertions(+) diff --git a/benchmarks/rwlock_cache.py b/benchmarks/rwlock_cache.py index 7177d212b7..da3e88debd 100644 --- a/benchmarks/rwlock_cache.py +++ b/benchmarks/rwlock_cache.py @@ -223,6 +223,26 @@ def write_time_series( ts_df.to_csv(csv_file, mode='a', header=False, index=False) +def plot_series(path: str): + """Import and run this inside a notebook to visualize time series.""" + import matplotlib.pyplot as plt + + df = pd.read_csv(path) + df['timestamp'] = pd.to_datetime(df['timestamp'], format='ISO8601') + + for (workers, ttl), group in df.groupby(['num_workers', 'ttl'], sort=True): + group = group.sort_values('timestamp') + fig, ax = plt.subplots(figsize=(10, 4)) + ax.plot(group['timestamp'], group['num_readers'], label='num_readers') + ax.plot(group['timestamp'], group['num_waiting_writers'], label='num_waiting_writers') + ax.set_title(f'num_workers={workers}, ttl={ttl}') + ax.set_xlabel('Time') + ax.set_ylabel('Count') + ax.legend() + ax.grid(alpha=0.3) + plt.show() + + def display_metrics( n: int, ttl: Number, diff --git a/dev_requirements.txt b/dev_requirements.txt index b9ed5eb470..cdfd5ffed9 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,6 +1,7 @@ build click==8.0.4 invoke==2.2.0 +matplotlib mock 
packaging>=20.4 pandas From 5be07563163231fd2efbcac44600ea3964d327db Mon Sep 17 00:00:00 2001 From: Matthew McAllister Date: Thu, 2 Oct 2025 12:20:39 -0700 Subject: [PATCH 4/8] rwlock: Add RwLock to docs --- docs/index.rst | 1 + docs/rwlock.rst | 5 +++++ 2 files changed, 6 insertions(+) create mode 100644 docs/rwlock.rst diff --git a/docs/index.rst b/docs/index.rst index 2c0557cbbe..af2909cddc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -69,6 +69,7 @@ Module Documentation exceptions backoff lock + rwlock retry lua_scripting opentelemetry diff --git a/docs/rwlock.rst b/docs/rwlock.rst new file mode 100644 index 0000000000..a35153ef72 --- /dev/null +++ b/docs/rwlock.rst @@ -0,0 +1,5 @@ +Reader-writer lock +################## + +.. automodule:: redis.rwlock + :members: \ No newline at end of file From 85d5a009370060f8fd8ceca34a20b802df9557be Mon Sep 17 00:00:00 2001 From: Matthew McAllister Date: Thu, 2 Oct 2025 13:05:19 -0700 Subject: [PATCH 5/8] rwlock: Add implementation explanation --- redis/rwlock.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/redis/rwlock.py b/redis/rwlock.py index 6bb6dbcf33..52b477b239 100644 --- a/redis/rwlock.py +++ b/redis/rwlock.py @@ -21,6 +21,30 @@ from redis import Redis +# RwLock implementation +# ===================== +# +# The lock owns three keys: +# - `:write`: If this exists, holds the token of the user that +# owns the exclusive write lock. +# - `:write_semaphore`: Semaphore tracking writers that are +# waiting to acquire the lock. If any writers are waiting, readers +# block. +# - `:read`: Another semaphore, tracks readers. +# +# Semaphores are implemented as ordered sets, where score is the +# expiration time of the semaphore lease (Redis instance time). Expired +# leases are pruned before attempting to acquire the lock. +# +# We can't use built-in key expiration because individual set members +# cannot have an expiration.
We can't create keys dynamically because +# this may break multi-node compatibility. +# +# The write-acquire script is careful to ensure that the writer waiting +# semaphore is only held if the caller is actually blocking; otherwise +# the caller adds contention for no reason. + + class RwLock: """A shared reader-writer lock. @@ -30,7 +54,7 @@ class RwLock: priority when waiting on the lock so that readers do not starve waiting writers. Writers are allowed to starve readers, however. - This type of lock is effective for scenarios where reads are + This type of unfair lock is effective for scenarios where reads are frequent and writes are infrequent. Because this lock relies on busy waiting, it can be wasteful to use if your critical sections are long and frequent. @@ -38,6 +62,9 @@ class RwLock: This lock is not fault-tolerant in a multi-node Redis setup. When a master fails and data is lost, writer exclusivity may be violated. In a single-node setup, the lock is sound. + + This lock is not re-entrant; attempting to acquire it twice in the + same thread may cause a deadlock until the blocking timeout ends. """ lua_acquire_reader = None @@ -78,7 +105,7 @@ class RwLock: # KEYS[3] - reader lock name # ARGV[1] - token # ARGV[2] - expiration - # ARGV[3] - sempahore expiration + # ARGV[3] - semaphore expiration (or 0 to release the semaphore) # ARGV[4] - max writers # # NOTE: return codes: From fe2a5fad8934f1bc803fdc37e3acd6d0ca475698 Mon Sep 17 00:00:00 2001 From: Matthew McAllister Date: Thu, 2 Oct 2025 13:19:38 -0700 Subject: [PATCH 6/8] rwlock: Switch to uuid4 uuid1 just seems more liable to cause a collision for minimal benefit.
--- redis/rwlock.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/rwlock.py b/redis/rwlock.py index 52b477b239..757682e831 100644 --- a/redis/rwlock.py +++ b/redis/rwlock.py @@ -8,7 +8,7 @@ from typing import Optional from typing import Self from typing import Type -from uuid import uuid1 +from uuid import uuid4 from typing_extensions import override @@ -283,7 +283,7 @@ def _writer_semaphore_name(self) -> str: return f'{self.prefix}:write_semaphore' def _make_token(self) -> Any: - token = self.redis.get_encoder().encode(uuid1().hex) + token = self.redis.get_encoder().encode(uuid4().hex) return token def read( From 88c9f4accceafa3bfb7d770382a9994d37bafa34 Mon Sep 17 00:00:00 2001 From: Matthew McAllister Date: Thu, 2 Oct 2025 13:37:31 -0700 Subject: [PATCH 7/8] rwlock: Elaborate on benchmark --- benchmarks/rwlock_cache.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/benchmarks/rwlock_cache.py b/benchmarks/rwlock_cache.py index da3e88debd..0e10e83fef 100644 --- a/benchmarks/rwlock_cache.py +++ b/benchmarks/rwlock_cache.py @@ -1,5 +1,12 @@ """Simulation of a pool of workers reading a cached value which occasionally must be replaced when it expires. + +The workers coordinate so that only a single writer takes responsibility +for updating the cached value when it expires. This minimizes contention +and latency when the cache expires. + +The simulation also attempts to detect if any workers read a stale value +by rotating out read counters when the cache key is updated. 
""" import random From 86ab631d13d8b7fcceea66b05055f87593cc064f Mon Sep 17 00:00:00 2001 From: Matthew McAllister Date: Mon, 6 Oct 2025 18:17:23 -0700 Subject: [PATCH 8/8] rwlock: Better benchmark --- benchmarks/rwlock.py | 415 +++++++++++++++++++++++++++++++++++++ benchmarks/rwlock_cache.py | 364 -------------------------------- 2 files changed, 415 insertions(+), 364 deletions(-) create mode 100644 benchmarks/rwlock.py delete mode 100644 benchmarks/rwlock_cache.py diff --git a/benchmarks/rwlock.py b/benchmarks/rwlock.py new file mode 100644 index 0000000000..499d021545 --- /dev/null +++ b/benchmarks/rwlock.py @@ -0,0 +1,415 @@ +"""Simulation of readers and writers sharing ownership of a lock.""" + +import itertools +import json +import threading +import time +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from random import Random +from typing import Literal +from typing import Optional +from typing import Self +from typing import TypeAlias +from typing import Union + +import click +import pandas as pd + +from redis.client import Redis +from redis.exceptions import LockMaxWritersError +from redis.rwlock import RwLock + + +def _now() -> datetime: + return datetime.now() + + +AcquireStatus: TypeAlias = Union[ + Literal['success'], + Literal['timeout'], + Literal['aborted'], +] + + +@dataclass +class Parameters: + duration: float + seed: int + num_workers: int + wr_ratio: float + io_duration: float + max_writers: int + + def to_dict(self) -> dict: + return { + 'duration': self.duration, + 'seed': self.seed, + 'num_workers': self.num_workers, + 'wr_ratio': self.wr_ratio, + 'io_duration': self.io_duration, + 'max_writers': self.max_writers, + } + + def to_string(self) -> str: + return ', '.join([ + f'duration = {self.duration}', + f'seed = {self.seed}', + f'num_workers = {self.num_workers}', + f'wr_ratio = {self.wr_ratio}', + f'io_duration = {self.io_duration}', + f'max_writers = 
{self.max_writers}', + ]) + + +@dataclass +class InvocationMetric: + timestamp: Optional[datetime] = None + read_acquire_time: Optional[float] = None + read_acquire_status: Optional[AcquireStatus] = None + read_release_time: Optional[float] = None + write_acquire_time: Optional[float] = None + write_acquire_status: Optional[AcquireStatus] = None + write_release_time: Optional[float] = None + + +@dataclass +class TimeSeriesMetric: + timestamp: datetime + readers: int + writers: int + locked: int + + @staticmethod + def collect(lock: RwLock) -> Self: + return TimeSeriesMetric( + timestamp=_now(), + readers=int(lock.redis.zcard(lock._reader_lock_name())), + writers=int(lock.redis.zcard(lock._writer_semaphore_name())), + locked=int(lock.redis.get('locked') or 0), + ) + + def to_row(self) -> tuple[str, int, int, int]: + return (self.timestamp.isoformat(), self.readers, self.writers, self.locked) + + +class Worker: + # Keys used: + # - locked + # - readers + # - writers + # - total_reads + # - total_writes + + lock: RwLock + random: Random + wr_ratio: float + io_duration: float + metrics: list[InvocationMetric] + series: list[TimeSeriesMetric] + + def __init__( + self, + lock: RwLock, + params: Parameters, + ) -> None: + self.lock = lock + self.random = Random(params.seed or None) + self.wr_ratio = params.wr_ratio + self.io_duration = params.io_duration + self.metrics = [] + self.series = [] + + @property + def redis(self) -> Redis: + return self.lock.redis + + def rand_io_time(self) -> float: + if not self.io_duration: + return 0 + mean = self.io_duration + std = mean + shape = mean**2 / std**2 + scale = std**2 / mean + return self.random.gammavariate(shape, scale) + + def wait_for_io(self) -> None: + if self.io_duration: + time.sleep(self.rand_io_time()) + + def do_write(self, metric: InvocationMetric) -> None: + write_guard = self.lock.write() + + acquire_start = time.perf_counter() + try: + acquired = write_guard.acquire() + except LockMaxWritersError: + # Hit max 
workers; abort gracefully + metric.write_acquire_status = 'aborted' + return + metric.write_acquire_time = time.perf_counter() - acquire_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + if not acquired: + metric.write_acquire_status = 'timeout' + return + + self.redis.set('locked', 1) + metric.write_acquire_status = 'success' + assert int(self.redis.get('readers') or 0) == 0 + + try: + self.wait_for_io() + self.redis.incr('total_reads') + finally: + self.redis.set('locked', 0) + release_start = time.perf_counter() + write_guard.release() + metric.write_release_time = time.perf_counter() - release_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + + def do_read(self, metric: InvocationMetric) -> None: + read_guard = self.lock.read() + + acquire_start = time.perf_counter() + acquired = read_guard.acquire() + metric.read_acquire_time = time.perf_counter() - acquire_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + if not acquired: + metric.read_acquire_status = 'timeout' + return + metric.read_acquire_status = 'success' + + assert not int(self.redis.get('locked') or 0) + + self.redis.incr('readers') + + try: + self.wait_for_io() + self.redis.incr('total_writes') + finally: + self.redis.decr('readers') + release_start = time.perf_counter() + read_guard.release() + metric.read_release_time = time.perf_counter() - release_start + self.series.append(TimeSeriesMetric.collect(self.lock)) + + def work(self) -> None: + metric = InvocationMetric() + + if self.random.random() < self.wr_ratio: + self.do_write(metric) + else: + self.do_read(metric) + + metric.timestamp = _now() + self.metrics.append(metric) + + def loop(self, stop_at: float) -> None: + while time.time() < stop_at: + self.work() + + +def plot_series(path: str): + """Import and run this inside a notebook to visualize time series.""" + import matplotlib.pyplot as plt + + data = json.load(open(path)) + + # Detect sweep parameters + values: defaultdict[str, 
set[Union[float, int]]] = defaultdict(lambda: set()) + params: dict[str, Union[float, int]] + for ts in data: + params = ts['params'] + for k, v in params.items(): + values[k].add(v) + sweep_params = [k for k, v in values.items() if len(v) > 1] + + for ts in data: + params = ts['params'] + df = pd.DataFrame( + ts['data'], + columns=['timestamp', 'readers', 'writers', 'locked'], + ) + df['timestamp'] = pd.to_datetime(df['timestamp'], format='ISO8601') + df.sort_values('timestamp', inplace=True) + fig, ax = plt.subplots(figsize=(10, 4)) + ax.plot(df['timestamp'], df['readers'], label='readers') + ax.plot(df['timestamp'], df['writers'], label='writers') + ax.plot(df['timestamp'], df['locked'], label='locked') + title = ', '.join([ + f'{k}={params[k]}' + for k in sweep_params + ]) + ax.set_title(title) + ax.set_xlabel('Time') + ax.set_ylabel('Count') + ax.legend() + ax.grid(alpha=0.3) + plt.show() + + +def display_metrics(invocation_metrics: list[InvocationMetric]) -> None: + inv_df = pd.DataFrame.from_records([ + { + 'timestamp': metric.timestamp.isoformat() if metric.timestamp else None, + 'read_acquire_time': metric.read_acquire_time, + 'read_release_time': metric.read_release_time, + 'write_acquire_time': metric.write_acquire_time, + 'write_release_time': metric.write_release_time, + 'read_acquire_status': metric.read_acquire_status, + 'write_acquire_status': metric.write_acquire_status, + } + for metric in invocation_metrics + ]) + metric_columns = [ + 'read_acquire_time', + 'read_release_time', + 'write_acquire_time', + 'write_release_time', + ] + + stats_df = pd.DataFrame(index=metric_columns) + inv_df[metric_columns] = inv_df[metric_columns].apply(pd.to_numeric, errors='coerce') + stats_df['min'] = inv_df[metric_columns].min() + stats_df['mean'] = inv_df[metric_columns].mean() + stats_df['p95'] = inv_df[metric_columns].quantile(0.95) + stats_df['max'] = inv_df[metric_columns].max() + + cols = ('read_acquire_status', 'write_acquire_status') + percentages = {} + 
for col in cols: + mask = inv_df[col].notna() + percentages[col] = inv_df[mask][col].value_counts() + status_df = pd.DataFrame(percentages).T.fillna(0) + status_df = status_df.reindex(columns=['success', 'timeout', 'aborted'], fill_value=0) + + print(stats_df.to_string(float_format=lambda x: f'{1e3 * x:.2f}ms')) + print(status_df.to_string(float_format=lambda x: f'{x:.0f}')) + print() + + +@dataclass +class Output: + params: Parameters + total_reads: int + total_writes: int + invocation_metrics: list[InvocationMetric] + time_series_metrics: list[TimeSeriesMetric] + + +def run(redis: Redis, params: Parameters) -> Output: + lock = RwLock( + redis=redis, + prefix='lock', + timeout=10, + sleep=max(params.io_duration, 1e-3), + blocking_timeout=1, + max_writers=params.max_writers, + ) + + stop_at = time.time() + params.duration + + # Spawn workers + workers = [Worker(lock, params) for _ in range(params.num_workers)] + threads = [ + threading.Thread(target=worker.loop, args=(stop_at,), daemon=True) for worker in workers + ] + for thread in threads: + thread.start() + + # Gather series metrics + time_series = [] + while time.time() < stop_at: + time_series.append(TimeSeriesMetric.collect(lock)) + time.sleep(0.01) + + # Wait for workers + for thread in threads: + thread.join() + + # Aggregate metrics + for worker in workers: + time_series.extend(worker.series) + invocation_metrics = [metric for worker in workers for metric in worker.metrics] + + total_reads = int(redis.get('total_reads') or 0) + total_writes = int(redis.get('total_writes') or 0) + + return Output( + params=params, + total_reads=total_reads, + total_writes=total_writes, + invocation_metrics=invocation_metrics, + time_series_metrics=time_series, + ) + + +@click.command() +@click.option('--duration', type=str, default='5') +@click.option('--seed', type=str, default='0') +@click.option('--num-workers', type=str, default='1') +@click.option('--wr-ratio', type=str, default='0.1') +@click.option('--io-duration', 
type=str, default='0.025') +@click.option('--max-writers', type=str, default='0') +def benchmark( + duration: str, + seed: str, + num_workers: str, + wr_ratio: str, + io_duration: str, + max_writers: str, +): + def parse_int(arg: str) -> list[int]: + return list([int(x) for x in arg.split(',')]) + + def parse_float(arg: str) -> list[float]: + return list([float(x) for x in arg.split(',')]) + + duration = parse_int(duration) + seed = parse_int(seed) + num_workers = parse_int(num_workers) + wr_ratio = parse_float(wr_ratio) + io_duration = parse_float(io_duration) + max_writers = parse_int(max_writers) + + time_series = [] + + for opts in itertools.product( + duration, + seed, + num_workers, + wr_ratio, + io_duration, + max_writers, + ): + redis = Redis(db=11) + redis.flushdb() + + params = Parameters(*opts) + output = run(redis, params) + + time_series.append({ + 'params': output.params.to_dict(), + 'data': [ts.to_row() for ts in output.time_series_metrics], + }) + + # Print stats + reads = output.total_reads + writes = output.total_writes + print(params.to_string()) + print(f'iops: {(reads + writes) / params.duration:.2f}') + + # Display metrics + display_metrics(output.invocation_metrics) + + out_dir = Path('.cache') + out_dir.mkdir(exist_ok=True) + out_path = out_dir / 'rwlock.json' + out_path = open(out_path, 'w') + json.dump(time_series, out_path) + + +if __name__ == '__main__': + benchmark() diff --git a/benchmarks/rwlock_cache.py b/benchmarks/rwlock_cache.py deleted file mode 100644 index 0e10e83fef..0000000000 --- a/benchmarks/rwlock_cache.py +++ /dev/null @@ -1,364 +0,0 @@ -"""Simulation of a pool of workers reading a cached value which -occasionally must be replaced when it expires. - -The workers coordinate so that only a single writer takes responsibility -for updating the cached value when it expires. This minimizes contention -and latency when the cache expires. 
- -The simulation also attempts to detect if any workers read a stale value -by rotating out read counters when the cache key is updated. -""" - -import random -import threading -import time -import uuid -from dataclasses import dataclass -from datetime import datetime -from numbers import Number -from pathlib import Path -from typing import Literal -from typing import Optional -from typing import Self -from typing import TextIO -from typing import TypeAlias -from typing import Union - -import pandas as pd - -from redis.client import Redis -from redis.exceptions import LockMaxWritersError -from redis.rwlock import RwLock - - -def _now() -> datetime: - return datetime.now() - - -AcquireStatus: TypeAlias = Union[ - Literal['success'], - Literal['timeout'], - Literal['aborted'], -] - - -@dataclass -class InvocationMetric: - timestamp: Optional[datetime] = None - read_acquire_time: Optional[float] = None - read_acquire_status: Optional[AcquireStatus] = None - read_release_time: Optional[float] = None - write_acquire_time: Optional[float] = None - write_acquire_status: Optional[AcquireStatus] = None - write_release_time: Optional[float] = None - - -@dataclass -class TimeSeriesMetric: - timestamp: datetime - num_readers: int - num_waiting_writers: int - - @staticmethod - def collect(lock: RwLock) -> Self: - metric = TimeSeriesMetric( - timestamp=_now(), - num_readers=lock.redis.zcard(lock._reader_lock_name()), - num_waiting_writers=lock.redis.zcard(lock._writer_semaphore_name()), - ) - assert metric.num_waiting_writers <= 1 - return metric - - -class Worker: - # Keys used: - # - worker_invocations: Total worker invocations - # - current_key: Holds the read counter key. Value is a random - # UUID4. Expires after `ttl` seconds. - # - previous_key: Previous value of current_key. Does not expire. - # - total: Total of all increments to the read counter. Should equal - # worker_invocations at the end. 
- - lock: RwLock - ttl: float - io_time: float - metrics: list[InvocationMetric] - series: list[TimeSeriesMetric] - - def __init__( - self, - lock: RwLock, - ttl: float, - io_time: float, - ) -> None: - self.lock = lock - self.ttl = ttl - self.io_time = io_time - self.metrics = [] - self.series = [] - - @property - def redis(self) -> Redis: - return self.lock.redis - - def rand_io_time(self) -> float: - mean = self.io_time - std = mean - shape = mean**2 / std**2 - scale = std**2 / mean - return random.gammavariate(shape, scale) - - def replace_key(self, metric: InvocationMetric) -> None: - write_guard = self.lock.write() - - # Acquire lock for writing - acquire_start = time.perf_counter() - try: - acquired = write_guard.acquire() - except LockMaxWritersError: - # Another worker has the lock; abort - metric.write_acquire_status = 'aborted' - return - metric.write_acquire_time = time.perf_counter() - acquire_start - self.series.append(TimeSeriesMetric.collect(self.lock)) - if not acquired: - metric.write_acquire_status = 'timeout' - return - - metric.write_acquire_status = 'success' - - def release() -> None: - release_start = time.perf_counter() - write_guard.release() - metric.write_release_time = time.perf_counter() - release_start - self.series.append(TimeSeriesMetric.collect(self.lock)) - - if self.redis.exists('current_key'): - release() - return - - # Update total with writes to previous key - previous_key: bytes = self.redis.get('previous_key') - if previous_key: - previous_value = self.redis.get(previous_key) - if previous_value: - self.redis.incrby('total', int(previous_value)) - - # Pretend to do I/O - time.sleep(self.rand_io_time()) - - # Update keys - new_key = f'cache:{uuid.uuid4().hex}' - self.redis.set('current_key', new_key, px=int(self.ttl * 1000)) - self.redis.set('previous_key', new_key) - - release() - - def work_inner(self, metric: InvocationMetric) -> None: - read_guard = self.lock.read() - - # Acquire lock for reading - acquire_start = 
time.perf_counter() - acquired = read_guard.acquire() - metric.read_acquire_time = time.perf_counter() - acquire_start - self.series.append(TimeSeriesMetric.collect(self.lock)) - if not acquired: - metric.read_acquire_status = 'timeout' - return - metric.read_acquire_status = 'success' - - def release() -> None: - release_start = time.perf_counter() - read_guard.release() - metric.read_release_time = time.perf_counter() - release_start - self.series.append(TimeSeriesMetric.collect(self.lock)) - - current_key = self.redis.get('current_key') - if current_key: - # Key exists; simulate I/O and bump counters - time.sleep(self.rand_io_time()) - - self.redis.incr(current_key) - self.redis.incr('worker_invocations') - - release() - else: - # Key does not exist; release lock and try to update key - release() - self.replace_key(metric) - - def work(self) -> None: - metric = InvocationMetric() - self.work_inner(metric) - metric.timestamp = _now() - self.metrics.append(metric) - - def loop(self, stop_at: float) -> None: - while time.time() < stop_at: - self.work() - - -def write_headers(csv_file: TextIO) -> None: - headers = [ - 'timestamp', - 'num_readers', - 'num_waiting_writers', - 'num_workers', - 'ttl', - ] - df = pd.DataFrame(columns=headers) - df.to_csv(csv_file, mode='w', header=True, index=False) - - -def write_time_series( - csv_file: TextIO, - n: int, - ttl: Number, - time_series: list[TimeSeriesMetric], -) -> None: - ts_records = [ - { - 'timestamp': metric.timestamp.isoformat(), - 'num_readers': metric.num_readers, - 'num_waiting_writers': metric.num_waiting_writers, - 'num_workers': n, - 'ttl': ttl, - } - for metric in time_series - ] - ts_df = pd.DataFrame(ts_records) - ts_df.to_csv(csv_file, mode='a', header=False, index=False) - - -def plot_series(path: str): - """Import and run this inside a notebook to visualize time series.""" - import matplotlib.pyplot as plt - - df = pd.read_csv(path) - df['timestamp'] = pd.to_datetime(df['timestamp'], format='ISO8601') - 
- for (workers, ttl), group in df.groupby(['num_workers', 'ttl'], sort=True): - group = group.sort_values('timestamp') - fig, ax = plt.subplots(figsize=(10, 4)) - ax.plot(group['timestamp'], group['num_readers'], label='num_readers') - ax.plot(group['timestamp'], group['num_waiting_writers'], label='num_waiting_writers') - ax.set_title(f'num_workers={workers}, ttl={ttl}') - ax.set_xlabel('Time') - ax.set_ylabel('Count') - ax.legend() - ax.grid(alpha=0.3) - plt.show() - - -def display_metrics( - n: int, - ttl: Number, - invocation_metrics: list[InvocationMetric], -) -> None: - inv_df = pd.DataFrame.from_records([ - { - 'timestamp': metric.timestamp.isoformat() if metric.timestamp else None, - 'read_acquire_time': metric.read_acquire_time, - 'read_release_time': metric.read_release_time, - 'write_acquire_time': metric.write_acquire_time, - 'write_release_time': metric.write_release_time, - 'read_acquire_status': metric.read_acquire_status, - 'write_acquire_status': metric.write_acquire_status, - } - for metric in invocation_metrics - ]) - metric_columns = [ - 'read_acquire_time', - 'read_release_time', - 'write_acquire_time', - 'write_release_time', - ] - - stats_df = pd.DataFrame(index=metric_columns) - inv_df[metric_columns] = inv_df[metric_columns].apply(pd.to_numeric, errors='coerce') - stats_df['min'] = inv_df[metric_columns].min() - stats_df['mean'] = inv_df[metric_columns].mean() - stats_df['p95'] = inv_df[metric_columns].quantile(0.95) - stats_df['max'] = inv_df[metric_columns].max() - - cols = ('read_acquire_status', 'write_acquire_status') - percentages = {} - for col in cols: - mask = inv_df[col].notna() - percentages[col] = inv_df[mask][col].value_counts() - status_df = pd.DataFrame(percentages).T.fillna(0) - status_df = status_df.reindex(columns=['success', 'timeout', 'aborted'], fill_value=0) - - print(stats_df.to_string(float_format=lambda x: f'{1e3 * x:.2f}ms')) - print(status_df.to_string(float_format=lambda x: f'{x:.0f}')) - print() - - -def main() 
-> None: - num_workers = [1, 2, 4, 8] - ttl_values = [0.05, 0.1, 0.25, 0.5, 1] - duration = 5 - io_time = 0.025 - cache_dir = Path('.cache') - cache_dir.mkdir(exist_ok=True) - csv_path = cache_dir / 'rwlock_cache.csv' - csv_file = open(csv_path, 'w') - write_headers(csv_file) - - for n in num_workers: - for ttl in ttl_values: - redis = Redis(db=11) - redis.flushdb() - - lock = RwLock( - redis=redis, - prefix='lock', - timeout=10, - sleep=io_time, - blocking_timeout=1, - max_writers=1, - ) - - stop_at = time.time() + duration - - # Spawn workers - workers = [Worker(lock=lock, ttl=ttl, io_time=io_time) for _ in range(n)] - threads = [ - threading.Thread(target=worker.loop, args=(stop_at,), daemon=True) for worker in workers - ] - for thread in threads: - thread.start() - - # Gather series metrics - time_series = [] - while time.time() < stop_at: - time_series.append(TimeSeriesMetric.collect(lock)) - time.sleep(0.01) - - # Wait for workers - for thread in threads: - thread.join() - - # Verify that total == # invocations - total = int(redis.get('total') or 0) - total += int(redis.get(redis.get('previous_key')) or 0) - worker_invocations = int(redis.get('worker_invocations') or 0) - assert worker_invocations == total - - # Write time series data - for worker in workers: - time_series.extend(worker.series) - write_time_series(csv_file, n, ttl, time_series) - - # Print stats - print(f'n = {n}, ttl = {ttl}') - writes = len(redis.keys('cache:*')) - print(f'iops: {(writes + worker_invocations) / duration:.2f}') - - # Display metrics - invocation_metrics = [metric for worker in workers for metric in worker.metrics] - display_metrics(n, ttl, invocation_metrics) - - -if __name__ == '__main__': - main()