@@ -21,29 +21,68 @@ def __init__(
2121 self ,
2222 redis_url : str = "redis://localhost:6379/0" ,
2323 redis_client : Optional [aioredis .Redis ] = None ,
24+ socket_timeout : Optional [float ] = 5.0 ,
25+ socket_connect_timeout : Optional [float ] = 5.0 ,
26+ max_connections : Optional [int ] = 50 ,
27+ retry_on_timeout : bool = True ,
28+ ttl_seconds : Optional [int ] = 3600 , # 1 hour in seconds
29+ health_check_interval : Optional [float ] = 30.0 ,
30+ socket_keepalive : bool = True ,
2431 ):
32+ """
33+ Initialize RedisStateService.
34+
35+ Args:
36+ redis_url: Redis connection URL
37+ redis_client: Optional pre-configured Redis client
38+ socket_timeout: Socket timeout in seconds (default: 5.0)
39+ socket_connect_timeout: Socket connect timeout in seconds
40+ (default: 5.0)
41+ max_connections: Maximum number of connections in the pool
42+ (default: 50)
43+ retry_on_timeout: Whether to retry on timeout (default: True)
44+ ttl_seconds: Time-to-live in seconds for state data. If None,
45+ data never expires (default: 3600, i.e., 1 hour)
46+ health_check_interval: Interval in seconds for health checks on
47+ idle connections (default: 30.0).
48+ Connections idle longer than this will be checked before reuse.
49+ Set to 0 to disable.
50+ socket_keepalive: Enable TCP keepalive to prevent
51+ silent disconnections (default: True)
52+ """
2553 self ._redis_url = redis_url
2654 self ._redis = redis_client
27- self ._health = False
55+ self ._socket_timeout = socket_timeout
56+ self ._socket_connect_timeout = socket_connect_timeout
57+ self ._max_connections = max_connections
58+ self ._retry_on_timeout = retry_on_timeout
59+ self ._ttl_seconds = ttl_seconds
60+ self ._health_check_interval = health_check_interval
61+ self ._socket_keepalive = socket_keepalive
2862
2963 async def start (self ) -> None :
30- """Initialize the Redis connection."""
64+ """Starts the Redis connection with proper timeout and connection
65+ pool settings."""
3166 if self ._redis is None :
3267 self ._redis = aioredis .from_url (
3368 self ._redis_url ,
3469 decode_responses = True ,
70+ socket_timeout = self ._socket_timeout ,
71+ socket_connect_timeout = self ._socket_connect_timeout ,
72+ max_connections = self ._max_connections ,
73+ retry_on_timeout = self ._retry_on_timeout ,
74+ health_check_interval = self ._health_check_interval ,
75+ socket_keepalive = self ._socket_keepalive ,
3576 )
36- self ._health = True
3777
3878 async def stop (self ) -> None :
39- """Close the Redis connection."""
79+ """Closes the Redis connection."""
4080 if self ._redis :
4181 await self ._redis .close ()
4282 self ._redis = None
43- self ._health = False
4483
4584 async def health (self ) -> bool :
46- """Service health check ."""
85+ """Checks the health of the service ."""
4786 if not self ._redis :
4887 return False
4988 try :
@@ -81,6 +120,11 @@ async def save_state(
81120 round_id = 1
82121
83122 await self ._redis .hset (key , round_id , json .dumps (state ))
123+
124+ # Set TTL for the state key if configured
125+ if self ._ttl_seconds is not None :
126+ await self ._redis .expire (key , self ._ttl_seconds )
127+
84128 return round_id
85129
86130 async def export_state (
@@ -110,4 +154,9 @@ async def export_state(
110154
111155 if state_json is None :
112156 return None
157+
158+ # Refresh TTL when accessing the state
159+ if self ._ttl_seconds is not None :
160+ await self ._redis .expire (key , self ._ttl_seconds )
161+
113162 return json .loads (state_json )
0 commit comments