Skip to content

Commit f311845

Browse files
committed
fix state_service redis
1 parent 96a1084 commit f311845

File tree

3 files changed

+399
-7
lines changed

3 files changed

+399
-7
lines changed

src/agentscope_runtime/engine/services/agent_state/redis_state_service.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,68 @@ def __init__(
2121
self,
2222
redis_url: str = "redis://localhost:6379/0",
2323
redis_client: Optional[aioredis.Redis] = None,
24+
socket_timeout: Optional[float] = 5.0,
25+
socket_connect_timeout: Optional[float] = 5.0,
26+
max_connections: Optional[int] = 50,
27+
retry_on_timeout: bool = True,
28+
ttl_seconds: Optional[int] = 3600, # 1 hour in seconds
29+
health_check_interval: Optional[float] = 30.0,
30+
socket_keepalive: bool = True,
2431
):
32+
"""
33+
Initialize RedisStateService.
34+
35+
Args:
36+
redis_url: Redis connection URL
37+
redis_client: Optional pre-configured Redis client
38+
socket_timeout: Socket timeout in seconds (default: 5.0)
39+
socket_connect_timeout: Socket connect timeout in seconds
40+
(default: 5.0)
41+
max_connections: Maximum number of connections in the pool
42+
(default: 50)
43+
retry_on_timeout: Whether to retry on timeout (default: True)
44+
ttl_seconds: Time-to-live in seconds for state data. If None,
45+
data never expires (default: 3600, i.e., 1 hour)
46+
health_check_interval: Interval in seconds for health checks on
47+
idle connections (default: 30.0).
48+
Connections idle longer than this will be checked before reuse.
49+
Set to 0 to disable.
50+
socket_keepalive: Enable TCP keepalive to prevent
51+
silent disconnections (default: True)
52+
"""
2553
self._redis_url = redis_url
2654
self._redis = redis_client
27-
self._health = False
55+
self._socket_timeout = socket_timeout
56+
self._socket_connect_timeout = socket_connect_timeout
57+
self._max_connections = max_connections
58+
self._retry_on_timeout = retry_on_timeout
59+
self._ttl_seconds = ttl_seconds
60+
self._health_check_interval = health_check_interval
61+
self._socket_keepalive = socket_keepalive
2862

2963
async def start(self) -> None:
30-
"""Initialize the Redis connection."""
64+
"""Starts the Redis connection with proper timeout and connection
65+
pool settings."""
3166
if self._redis is None:
3267
self._redis = aioredis.from_url(
3368
self._redis_url,
3469
decode_responses=True,
70+
socket_timeout=self._socket_timeout,
71+
socket_connect_timeout=self._socket_connect_timeout,
72+
max_connections=self._max_connections,
73+
retry_on_timeout=self._retry_on_timeout,
74+
health_check_interval=self._health_check_interval,
75+
socket_keepalive=self._socket_keepalive,
3576
)
36-
self._health = True
3777

3878
async def stop(self) -> None:
39-
"""Close the Redis connection."""
79+
"""Closes the Redis connection."""
4080
if self._redis:
4181
await self._redis.close()
4282
self._redis = None
43-
self._health = False
4483

4584
async def health(self) -> bool:
46-
"""Service health check."""
85+
"""Checks the health of the service."""
4786
if not self._redis:
4887
return False
4988
try:
@@ -81,6 +120,11 @@ async def save_state(
81120
round_id = 1
82121

83122
await self._redis.hset(key, round_id, json.dumps(state))
123+
124+
# Set TTL for the state key if configured
125+
if self._ttl_seconds is not None:
126+
await self._redis.expire(key, self._ttl_seconds)
127+
84128
return round_id
85129

86130
async def export_state(
@@ -110,4 +154,9 @@ async def export_state(
110154

111155
if state_json is None:
112156
return None
157+
158+
# Refresh TTL when accessing the state
159+
if self._ttl_seconds is not None:
160+
await self._redis.expire(key, self._ttl_seconds)
161+
113162
return json.loads(state_json)

src/agentscope_runtime/engine/services/memory/redis_memory_service.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ async def search_memory(
197197
await self._redis.expire(key, ttl_seconds)
198198

199199
return result
200+
200201
async def get_query_text(self, message: Message) -> str:
201202
if message:
202203
if message.type == MessageType.MESSAGE:
@@ -233,7 +234,8 @@ async def list_memory(
233234
# we need all previous messages for proper ordering)
234235
# For now, we keep loading all for correctness
235236

236-
# Refresh TTL on active use to keep memory alive, mirroring get_session behavior
237+
# Refresh TTL on active use to keep memory alive,
238+
# mirroring get_session behavior
237239
if getattr(self, "_ttl_seconds", None):
238240
await self._redis.expire(key, self._ttl_seconds)
239241
return all_msgs[start_index:end_index]

0 commit comments

Comments
 (0)