Skip to content

Commit 1615e35

Browse files
paultiq and Copilot authored
Fix Postgres Distributed rates (#248) (#250)
* fix: make put's atomic - lock, check, put... rather than check and put * Update postgres.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: Resolve comments - wrap _create_table in a pg_advisory_xact_lock and parameterize rate.interval * fix: ruff format * add log message for locknotavailable * fix: Check sliding windows for both rates --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 817d041 commit 1615e35

File tree

2 files changed

+215
-8
lines changed

2 files changed

+215
-8
lines changed

pyrate_limiter/buckets/postgres.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class Queries:
2626
CREATE_INDEX_ON_TIMESTAMP = """
2727
CREATE INDEX IF NOT EXISTS {index} ON {table} (item_timestamp)
2828
"""
29+
LOCK_TABLE = """
30+
LOCK TABLE {table} IN EXCLUSIVE MODE NOWAIT
31+
"""
2932
COUNT = """
3033
SELECT COUNT(*) FROM {table}
3134
"""
@@ -73,6 +76,8 @@ def _get_conn(self):
7376

7477
def _create_table(self):
7578
with self._get_conn() as conn:
79+
lock_id = hash(self._full_tbl) & 0x7FFFFFFF
80+
conn.execute("SELECT pg_advisory_xact_lock(%s)", (lock_id,))
7681
conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
7782
index_name = f"timestampIndex_{self.table}"
7883
conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))
@@ -81,14 +86,38 @@ def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
8186
"""Put an item (typically the current time) in the bucket
8287
return true if successful, otherwise false
8388
"""
89+
from psycopg.errors import LockNotAvailable
90+
8491
if item.weight == 0:
8592
return True
8693

94+
item_ts_seconds = item.timestamp / 1000
95+
8796
with self._get_conn() as conn:
97+
# Acquire an EXCLUSIVE MODE lock on the bucket table using NOWAIT.
98+
# This ensures the "check current count" + "insert new items" sequence
99+
# is atomic with respect to other writers, so rate limits cannot be
100+
# exceeded due to concurrent requests interleaving.
101+
#
102+
# Because we use NOWAIT, if the table is already locked by another
103+
# transaction, PostgreSQL raises LockNotAvailable and we immediately
104+
# reject this request (return False) instead of blocking or retrying.
105+
# This provides predictable, fail-fast behavior but may limit
106+
# throughput under high contention since only one writer can perform
107+
# the check-and-put at a time.
108+
try:
109+
conn.execute(Queries.LOCK_TABLE.format(table=self._full_tbl))
110+
except LockNotAvailable:
111+
logger.debug("LockNotAvailable")
112+
self.failing_rate = self.rates[0]
113+
return False
114+
88115
for rate in self.rates:
89-
bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
90-
query = f"SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})" # noqa: S608 # FIXME: SQL Parameterization and table name sanitization
91-
cur = conn.execute(query)
116+
cur = conn.execute(
117+
f"SELECT COUNT(*) FROM {self._full_tbl} " # noqa: S608
118+
f"WHERE item_timestamp >= TO_TIMESTAMP(%s) - (%s * INTERVAL '1 milliseconds')",
119+
(item_ts_seconds, rate.interval),
120+
)
92121
count = int(cur.fetchone()[0])
93122
cur.close()
94123

@@ -97,13 +126,9 @@ def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
97126
return False
98127

99128
self.failing_rate = None
100-
101129
query = Queries.PUT.format(table=self._full_tbl)
102-
103-
# https://www.psycopg.org/docs/extras.html#fast-exec
104-
105130
for _ in range(item.weight):
106-
conn.execute(query, (item.name, item.weight, item.timestamp / 1000))
131+
conn.execute(query, (item.name, item.weight, item_ts_seconds))
107132

108133
return True
109134

tests/test_postgres_concurrent.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import threading
2+
3+
import pytest
4+
5+
from pyrate_limiter import Duration, PostgresBucket, Rate
6+
from pyrate_limiter.abstracts import RateItem
7+
8+
9+
@pytest.mark.postgres
class TestPostgresConcurrent:
    """Concurrency tests for PostgresBucket.

    Many threads race ``put()`` against the same bucket table; the
    configured rates must never be exceeded even under contention.
    Requires a local Postgres reachable at
    ``postgresql://postgres:postgres@localhost:5432``.
    """

    @pytest.fixture
    def pg_pool(self):
        # One shared pool so every worker thread contends on the same DB.
        from psycopg_pool import ConnectionPool

        pool = ConnectionPool(
            "postgresql://postgres:postgres@localhost:5432",
            min_size=4,
            max_size=10,
            open=True,
        )
        yield pool
        pool.close()

    @pytest.fixture
    def clean_table(self, pg_pool):
        # Unique table name per test run; the bucket prefixes it with
        # "ratelimit___", so tear-down drops that full name.
        from pyrate_limiter import id_generator

        name = f"test_concurrent_{id_generator()}"
        yield name
        with pg_pool.connection() as conn:
            conn.execute(f"DROP TABLE IF EXISTS ratelimit___{name}")

    def _hammer(self, pg_pool, table, rates, num_threads, attempts, weight=1):
        """Fan out ``num_threads`` workers, each doing ``attempts`` puts of
        the given weight through its own PostgresBucket.

        Returns the merged list of ``(timestamp, success, weight)`` tuples.
        """
        outcomes = []
        guard = threading.Lock()

        def worker(worker_id: int):
            bucket = PostgresBucket(pg_pool, table, rates)
            local = []
            for _ in range(attempts):
                now = bucket.now()
                item = RateItem(f"thread_{worker_id}", now, weight=weight)
                local.append((now, bucket.put(item), weight))
            with guard:
                outcomes.extend(local)

        workers = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        return outcomes

    def test_concurrent_put(self, pg_pool, clean_table):
        limit = 5
        rates = [Rate(limit, Duration.SECOND)]

        outcomes = self._hammer(pg_pool, clean_table, rates, num_threads=8, attempts=10)

        # Verify db state
        full_table = f"ratelimit___{clean_table}"
        with pg_pool.connection() as conn:
            cur = conn.execute(f"SELECT COUNT(*) FROM {full_table}")  # noqa: S608
            stored_count = cur.fetchone()[0]
            cur.close()

            # Pull back the stored timestamps (truncated to whole seconds)
            # so we can re-check the windows directly against the DB.
            cur = conn.execute(
                f"SELECT EXTRACT(EPOCH FROM item_timestamp)::bigint as ts FROM {full_table}"  # noqa: S608
            )
            stored_ts = [row[0] for row in cur.fetchall()]
            cur.close()

        # Check in sliding windows to make sure rate didn't exceed
        for ts in stored_ts:
            same_window = sum(1 for t in stored_ts if ts - 1 < t <= ts)
            assert same_window <= limit, (
                f"Rate limit exceeded in DB: {same_window} items in 1-second window ending at {ts}"
            )

        # Every reported success must correspond to exactly one DB row.
        accepted = sum(1 for _, ok, _ in outcomes if ok)
        assert accepted > 0, "No successful acquisitions"
        assert accepted == stored_count, (
            f"Mismatch: {accepted} reported successes but {stored_count} items in DB"
        )

        # Under this much contention some attempts must have been refused.
        refused = sum(1 for _, ok, _ in outcomes if not ok)
        assert refused > 0, (
            "No rejections occurred - rate limiting may not be working"
        )

    def test_concurrent_put_multiple_rates(self, pg_pool, clean_table):
        rates = [
            Rate(3, 500),  # 3 per 500ms
            Rate(5, 1000),  # 5 per second
        ]

        outcomes = self._hammer(pg_pool, clean_table, rates, num_threads=4, attempts=5)
        accepted = sorted(ts for ts, ok, _ in outcomes if ok)

        # Check sliding windows for both rates
        for ts in accepted:
            # 1-second sliding window
            in_second = sum(1 for t in accepted if ts - 1000 <= t <= ts)
            assert in_second <= 5, f"1-second rate exceeded: {in_second} items in window ending at {ts}"

            # 500ms sliding window
            in_half_second = sum(1 for t in accepted if ts - 500 <= t <= ts)
            assert in_half_second <= 3, f"500ms rate exceeded: {in_half_second} items in window ending at {ts}"

    def test_concurrent_put_weighted(self, pg_pool, clean_table):
        limit = 10
        rates = [Rate(limit, Duration.SECOND)]
        item_weight = 3

        outcomes = self._hammer(
            pg_pool, clean_table, rates, num_threads=4, attempts=5, weight=item_weight
        )
        accepted = sorted((ts, w) for ts, ok, w in outcomes if ok)

        # Check sliding windows for weighted items: the summed weight in any
        # 1-second window must stay within the limit.
        for ts, _ in accepted:
            window_weight = sum(w for t, w in accepted if ts - 1000 <= t <= ts)
            assert window_weight <= limit, (
                f"Rate limit exceeded: weight {window_weight} in 1-second window ending at {ts}"
            )

0 commit comments

Comments (0)