1
1
import abc
2
2
import socket
3
3
from time import sleep
4
- from typing import TYPE_CHECKING , Any , Callable , Iterable , Tuple , Type , TypeVar , Union
4
+ from typing import TYPE_CHECKING , Any , Callable , Generic , Iterable , Tuple , Type , TypeVar , Union
5
5
6
6
from redis .exceptions import ConnectionError , TimeoutError
7
7
8
8
T = TypeVar ("T" )
9
+ E = TypeVar ("E" , bound = Exception , covariant = True )
9
10
10
11
if TYPE_CHECKING :
11
12
from redis .backoff import AbstractBackoff
12
13
13
14
14
- class AbstractRetry (abc .ABC ):
15
+ class AbstractRetry (Generic [ E ], abc .ABC ):
15
16
"""Retry a specific number of times after a failure"""
16
17
17
- _supported_errors : Tuple [Type [Exception ], ...]
18
+ _supported_errors : Tuple [Type [E ], ...]
18
19
19
20
def __init__ (
20
21
self ,
21
22
backoff : "AbstractBackoff" ,
22
23
retries : int ,
23
- supported_errors : Union [ Tuple [Type [Exception ], ...], None ] = None ,
24
+ supported_errors : Tuple [Type [E ], ...],
24
25
):
25
26
"""
26
27
Initialize a `Retry` object with a `Backoff` object
@@ -31,8 +32,7 @@ def __init__(
31
32
"""
32
33
self ._backoff = backoff
33
34
self ._retries = retries
34
- if supported_errors :
35
- self ._supported_errors = supported_errors
35
+ self ._supported_errors = supported_errors
36
36
37
37
@abc .abstractmethod
38
38
def __eq__ (self , other : Any ) -> bool :
@@ -42,7 +42,7 @@ def __hash__(self) -> int:
42
42
return hash ((self ._backoff , self ._retries , frozenset (self ._supported_errors )))
43
43
44
44
def update_supported_errors (
45
- self , specified_errors : Iterable [Type [Exception ]]
45
+ self , specified_errors : Iterable [Type [E ]]
46
46
) -> None :
47
47
"""
48
48
Updates the supported errors with the specified error types
@@ -64,14 +64,21 @@ def update_retries(self, value: int) -> None:
64
64
self ._retries = value
65
65
66
66
67
- class Retry (AbstractRetry ):
68
- _supported_errors : Tuple [Type [Exception ], ...] = (
69
- ConnectionError ,
70
- TimeoutError ,
71
- socket .timeout ,
72
- )
67
+ class Retry (AbstractRetry [Exception ]):
73
68
__hash__ = AbstractRetry .__hash__
74
69
70
+ def __init__ (
71
+ self ,
72
+ backoff : "AbstractBackoff" ,
73
+ retries : int ,
74
+ supported_errors : Tuple [Type [Exception ], ...] = (
75
+ ConnectionError , TimeoutError , socket .timeout
76
+ ),
77
+ ):
78
+ super ().__init__ (backoff , retries , supported_errors )
79
+
80
+ __init__ .__doc__ = AbstractRetry .__init__ .__doc__
81
+
75
82
def __eq__ (self , other : Any ) -> bool :
76
83
if not isinstance (other , Retry ):
77
84
return NotImplemented
0 commit comments