|
1 | 1 | import asyncio |
2 | | -import math |
3 | 2 | import time |
| 3 | +import uuid |
| 4 | + |
| 5 | + |
| 6 | +class _QuerySet(object): |
| 7 | + def __init__(self): |
| 8 | + self._items = list() |
| 9 | + self._index = dict() |
| 10 | + |
| 11 | + def add(self, item: int) -> None: |
| 12 | + if item in self._index: |
| 13 | + return |
| 14 | + |
| 15 | + self._items.append(item) |
| 16 | + self._index[item] = len(self._items) - 1 |
| 17 | + |
| 18 | + def remove(self, item: int) -> None: |
| 19 | + if item not in self._index: |
| 20 | + return |
| 21 | + |
| 22 | + index = self._index[item] |
| 23 | + self._items[index] = self._items[-1] |
| 24 | + self._index[self._items[-1]] = index |
| 25 | + self._items.pop() |
| 26 | + del self._index[item] |
| 27 | + |
| 28 | + def query(self, item: int) -> int: |
| 29 | + return self._index[item] |
4 | 30 |
|
5 | 31 |
|
6 | 32 | class ModelBreaker(object): |
7 | 33 | def __init__(self): |
8 | 34 | # 初始化 allow_time 为当前时间 |
9 | 35 | self._allow_time = time.perf_counter() |
| 36 | + self._waiters = _QuerySet() |
10 | 37 |
|
11 | | - self._waiters = 0 |
12 | | - self._permits = 0 |
13 | | - self._base = 0 |
14 | | - |
15 | | - def _allow(self) -> bool: |
16 | | - # 检查当前时间是否在 allow_time 之后 |
17 | | - return time.perf_counter() > self._allow_time |
| 38 | + def _allow(self, id: int) -> bool: |
| 39 | + cur = time.perf_counter() |
| 40 | + # 如果当前时间小于等于 allow_time,不允许通过 |
| 41 | + if cur <= self._allow_time: |
| 42 | + return 0 |
| 43 | + # 如果当前时间与 allow_time 的差值大于 10,允许通过 |
| 44 | + if cur - self._allow_time > 10: |
| 45 | + return True |
| 46 | + # 如果当前时间与 allow_time 的差值小于等于 10,慢启动通过 |
| 47 | + return self._waiters.query(id) < 2 ** (cur - self._allow_time) |
18 | 48 |
|
19 | 49 | def _get_allowed_duration(self) -> float: |
20 | 50 | # 计算当前时间与 allow_time 之间的持续时间 |
21 | 51 | allow_duration = self._allow_time - time.perf_counter() |
22 | 52 |
|
23 | | - # 如果持续时间为负,返回零 |
24 | | - if allow_duration < 0: |
25 | | - return 0 |
26 | | - return allow_duration |
| 53 | + # 至少有 1 秒的等待时间 |
| 54 | + return max(allow_duration, 1) |
27 | 55 |
|
28 | 56 | def _acquire(self) -> int: |
29 | | - self._waiters += 1 |
30 | | - return self._waiters |
31 | | - |
32 | | - def _release(self) -> None: |
33 | | - self._waiters -= 1 |
34 | | - self._permits += 1 |
| 57 | + id = uuid.uuid4().int |
| 58 | + self._waiters.add(id) |
| 59 | + return id |
35 | 60 |
|
36 | | - def _jitter(self, i: int) -> float: |
37 | | - if i <= self._base: |
38 | | - return 0 |
39 | | - return math.log2(i - self._base) |
| 61 | + def _release(self, id: int) -> None: |
| 62 | + self._waiters.remove(id) |
40 | 63 |
|
41 | 64 | def reset(self, duration: float) -> None: |
42 | 65 | # 将 allow_time 重置为当前时间加上指定的持续时间 |
43 | 66 | self._allow_time = time.perf_counter() + duration |
44 | | - self._base = self._permits |
45 | 67 |
|
46 | 68 | def wait(self) -> None: |
47 | | - i = self._acquire() |
48 | | - while not self._allow(): |
49 | | - time.sleep(self._get_allowed_duration() + self._jitter(i)) |
50 | | - self._release() |
| 69 | + id = self._acquire() |
| 70 | + while not self._allow(id): |
| 71 | + time.sleep(self._get_allowed_duration()) |
| 72 | + self._release(id) |
51 | 73 |
|
52 | 74 | async def asyncwait(self) -> None: |
53 | | - i = self._acquire() |
54 | | - while not self._allow(): |
55 | | - await asyncio.sleep(self._get_allowed_duration() + self._jitter(i)) |
56 | | - self._release() |
| 75 | + id = self._acquire() |
| 76 | + while not self._allow(id): |
| 77 | + await asyncio.sleep(self._get_allowed_duration()) |
| 78 | + self._release(id) |
0 commit comments