2
2
from version import __version__
3
3
import redis
4
4
from redis import Redis
5
- from typing import Optional
5
+ from redis .cluster import RedisCluster
6
+ from typing import Optional , Type , Union
6
7
from common .config import REDIS_CFG
7
8
8
9
from common .config import generate_redis_uri
@@ -15,7 +16,9 @@ class RedisConnectionManager:
15
16
def get_connection (cls , decode_responses = True ) -> Redis :
16
17
if cls ._instance is None :
17
18
try :
18
- cls ._instance = redis .Redis (
19
+ redis_class : Type [Union [Redis , RedisCluster ]] = redis .cluster .RedisCluster if REDIS_CFG ["cluster_mode" ] else redis .Redis
20
+
21
+ cls ._instance = redis_class (
19
22
host = REDIS_CFG ["host" ],
20
23
port = REDIS_CFG ["port" ],
21
24
username = REDIS_CFG ["username" ],
@@ -27,8 +30,8 @@ def get_connection(cls, decode_responses=True) -> Redis:
27
30
ssl_cert_reqs = REDIS_CFG ["ssl_cert_reqs" ],
28
31
ssl_ca_certs = REDIS_CFG ["ssl_ca_certs" ],
29
32
decode_responses = decode_responses ,
30
- max_connections = 10 ,
31
- lib_name = f"redis-py(mcp-server_v { __version__ } )"
33
+ lib_name = f"redis-py(mcp-server_v { __version__ } )" ,
34
+ ** ({ "max_connections_per_node" : 10 } if REDIS_CFG [ "cluster_mode" ] else { "max_connections" : 10 })
32
35
)
33
36
34
37
except redis .exceptions .ConnectionError :
@@ -46,6 +49,9 @@ def get_connection(cls, decode_responses=True) -> Redis:
46
49
except redis .exceptions .RedisError as e :
47
50
print (f"Redis error: { e } " , file = sys .stderr )
48
51
raise
52
+ except redis .exceptions .ClusterError as e :
53
+ print (f"Redis Cluster error: { e } " , file = sys .stderr )
54
+ raise
49
55
except Exception as e :
50
56
print (f"Unexpected error: { e } " , file = sys .stderr )
51
57
raise
0 commit comments