12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import random |
| 15 | +import threading |
| 16 | +import time |
| 17 | +import weakref |
15 | 18 |
16 | 19 | from collections import namedtuple |
| 20 | +from enum import Enum |
17 | 21 | from functools import lru_cache |
18 | 22 | from itertools import islice, cycle, groupby, repeat |
19 | 23 | import logging |
@@ -778,6 +782,14 @@ def new_schedule(self): |
778 | 782 | raise NotImplementedError() |
779 | 783 |
780 | 784 |
| 785 | +class NoDelayReconnectionPolicy(ReconnectionPolicy): |
| 786 | + """ |
| 787 | +    A :class:`.ReconnectionPolicy` subclass which does not sleep between reconnection attempts. |
| 788 | + """ |
| 789 | + def new_schedule(self): |
| 790 | + return repeat(0) |
| 791 | + |
| 792 | + |
781 | 793 | class ConstantReconnectionPolicy(ReconnectionPolicy): |
782 | 794 | """ |
783 | 795 | A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay |
@@ -864,6 +876,146 @@ def _add_jitter(self, value): |
864 | 876 | return min(max(self.base_delay, delay), self.max_delay) |
865 | 877 |
866 | 878 |
| 879 | +class ShardReconnectionScheduler(object): |
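| | +    """ |
| | +    Base class for shard reconnection schedulers; ``schedule`` queues a |
| | +    reconnection attempt (``method`` with its arguments) for the given |
| | +    host/shard pair. |
| | +    """ |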
| 880 | + def schedule(self, host_id, shard_id, method, *args, **kwargs): |
| 881 | + raise NotImplementedError() |
| 882 | + |
| | + |
| 883 | +class ShardReconnectionPolicy(object): |
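| | +    """ |
| | +    Base class for shard reconnection policies; ``new_scheduler`` creates the |
| | +    :class:`.ShardReconnectionScheduler` used by a session. |
| | +    """ |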
| 884 | + def new_scheduler(self, session) -> ShardReconnectionScheduler: |
| 885 | + raise NotImplementedError() |
| 886 | + |
| 887 | + |
| 888 | +class NoDelayShardReconnectionPolicy(ShardReconnectionPolicy): |
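| | +    """ |
| | +    A :class:`.ShardReconnectionPolicy` which schedules shard reconnections |
| | +    immediately, allowing them to run concurrently. |
| | +    """ |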
| 889 | + def new_scheduler(self, session) -> ShardReconnectionScheduler: |
| 890 | + return NoDelayShardReconnectionScheduler(session) |
| 891 | + |
| 892 | + |
| 893 | +class NoDelayShardReconnectionScheduler(ShardReconnectionScheduler): |
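| | +    """ |
| | +    Scheduler for :class:`.NoDelayShardReconnectionPolicy`; submits each |
| | +    reconnection to the session executor right away and tracks pending |
| | +    host/shard pairs so the same shard is not scheduled twice. |
| | +    """ |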
| 894 | + def __init__(self, session): |
| 895 | + self.session = weakref.proxy(session) |
| 896 | + self.already_scheduled = {} |
| 897 | + |
| 898 | + def _execute(self, scheduled_key, method, *args, **kwargs): |
| 899 | + try: |
| 900 | + method(*args, **kwargs) |
| 901 | + finally: |
| 902 | + self.already_scheduled[scheduled_key] = False |
| 903 | + |
| 904 | + def schedule(self, host_id, shard_id, method, *args, **kwargs): |
| 905 | + scheduled_key = f'{host_id}-{shard_id}' |
| 906 | + if self.already_scheduled.get(scheduled_key): |
| 907 | + return |
| 908 | + |
| 909 | + self.already_scheduled[scheduled_key] = True |
| 910 | + if not self.session.is_shutdown: |
| 911 | + self.session.submit(self._execute, scheduled_key, method, *args, **kwargs) |
| 912 | + |
| 913 | + |
| 914 | +class ShardReconnectionPolicyScope(Enum): |
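| | +    """ |
| | +    Scope within which :class:`.NoConcurrentShardReconnectionPolicy` serializes |
| | +    shard reconnections: one at a time across the whole cluster (``Cluster``) |
| | +    or one at a time per host (``Host``). |
| | +    """ |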
| 915 | + Cluster = 0 |
| 916 | + Host = 1 |
| 917 | + |
| 918 | + |
| 919 | +class NoConcurrentShardReconnectionPolicy(ShardReconnectionPolicy): |
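| | +    """ |
| | +    A :class:`.ShardReconnectionPolicy` which runs at most one shard |
| | +    reconnection at a time within the given scope, spacing attempts according |
| | +    to the wrapped :class:`.ReconnectionPolicy`. |
| | +    """ |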
| 920 | + def __init__(self, shard_reconnection_scope, reconnection_policy): |
| 921 | + if not isinstance(shard_reconnection_scope, ShardReconnectionPolicyScope): |
| 922 | + raise ValueError("shard_reconnection_scope must be a ShardReconnectionPolicyScope") |
| 923 | + if not isinstance(reconnection_policy, ReconnectionPolicy): |
| 924 | + raise ValueError("reconnection_policy must be a ReconnectionPolicy") |
| 925 | + self.shard_reconnection_scope = shard_reconnection_scope |
| 926 | + self.reconnection_policy = reconnection_policy |
| 927 | + |
| 928 | + def new_scheduler(self, session) -> ShardReconnectionScheduler: |
| 929 | + return NoConcurrentShardReconnectionScheduler(session, self.shard_reconnection_scope, self.reconnection_policy) |
| 930 | + |
| 931 | + |
| 932 | +class _ScopeBucket(object): |
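| | +    """ |
| | +    Serializes execution of the reconnections scheduled for one scope (cluster |
| | +    or host), applying the delays produced by the reconnection policy's |
| | +    schedule between consecutive attempts. |
| | +    """ |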
| 933 | + def __init__(self, session, shard_reconnection_policy): |
| 934 | + self._items = [] |
| 935 | + self.last_run = None |
| 936 | + self.session = session |
| 937 | + self.policy = shard_reconnection_policy |
| 938 | + self.lock = threading.Lock() |
| 939 | + self.running = False |
| 940 | + self.schedule = self.policy.new_schedule() |
| 941 | + |
| 942 | + def add(self, method, *args, **kwargs): |
| 943 | + with self.lock: |
| 944 | + self._items.append([method, args, kwargs]) |
| 945 | + if not self.running: |
| 946 | + self.running = True |
| 947 | + self._schedule() |
| 948 | + |
| 949 | + def _get_delay(self): |
| 950 | + try: |
| 951 | + return next(self.schedule) |
| 952 | + except StopIteration: |
| 953 | + self.schedule = self.policy.new_schedule() |
| 954 | + return next(self.schedule) |
| 955 | + |
| 956 | + def _schedule(self): |
| 957 | + if self.session.is_shutdown: |
| 958 | + return |
| 959 | + delay = self._get_delay() |
| 960 | + if delay: |
| 961 | + self.session.cluster.scheduler.schedule(delay, self.run) |
| 962 | + else: |
| 963 | + self.session.submit(self.run) |
| 964 | + |
| 965 | + def run(self): |
| 966 | + if self.session.is_shutdown: |
| 967 | + return |
| 968 | + |
| 969 | + with self.lock: |
| 970 | + try: |
| 971 | + item = self._items.pop() |
| 972 | + except IndexError: |
| 973 | + self.running = False |
| 974 | + return |
| 975 | + |
| 976 | + method, args, kwargs = item |
| 977 | + try: |
| 978 | + method(*args, **kwargs) |
| 979 | + finally: |
| 980 | + self._schedule() |
| 981 | + |
| 982 | + |
| 983 | +class NoConcurrentShardReconnectionScheduler(ShardReconnectionScheduler): |
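| | +    """ |
| | +    Scheduler for :class:`.NoConcurrentShardReconnectionPolicy`; deduplicates |
| | +    requests for the same host/shard pair and funnels them into per-scope |
| | +    :class:`._ScopeBucket` instances. |
| | +    """ |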
| 984 | + def __init__(self, session, shard_reconnection_scope, reconnection_policy): |
| 985 | + self.already_scheduled = {} |
| 986 | + self.scopes = {} |
| 987 | + self.shard_reconnection_scope = shard_reconnection_scope |
| 988 | + self.reconnection_policy = reconnection_policy |
| 989 | + self.session = session |
| 990 | + self.lock = threading.Lock() |
| 991 | + |
| 992 | + def _execute(self, scheduled_key, method, *args, **kwargs): |
| 993 | + try: |
| 994 | + method(*args, **kwargs) |
| 995 | + finally: |
| 996 | + with self.lock: |
| 997 | + self.already_scheduled[scheduled_key] = False |
| 998 | + |
| 999 | + def schedule(self, host_id, shard_id, method, *args, **kwargs): |
| 1000 | + if self.shard_reconnection_scope == ShardReconnectionPolicyScope.Cluster: |
| 1001 | + scope_hash = "global-cluster-scope" |
| 1002 | + else: |
| 1003 | + scope_hash = host_id |
| 1004 | + scheduled_key = f'{host_id}-{shard_id}' |
| 1005 | + |
| 1006 | + with self.lock: |
| 1007 | + if self.already_scheduled.get(scheduled_key): |
| 1008 | + return False |
| 1009 | + self.already_scheduled[scheduled_key] = True |
| 1010 | + |
| 1011 | +        scope_bucket = self.scopes.get(scope_hash) |
| 1012 | +        if scope_bucket is None: |
| 1013 | +            scope_bucket = _ScopeBucket(self.session, self.reconnection_policy) |
| 1014 | +            self.scopes[scope_hash] = scope_bucket |
| 1015 | +        scope_bucket.add(self._execute, scheduled_key, method, *args, **kwargs) |
| 1016 | + return True |
| 1017 | + |
| 1018 | + |
867 | 1019 | class RetryPolicy(object): |
868 | 1020 | """ |
869 | 1021 | A policy that describes whether to retry, rethrow, or ignore coordinator |