Skip to content

Commit b89f37a

Browse files
committed
slight refactor - address race conditions
well... light refactor isn't it ? ;-)
1 parent 1b17601 commit b89f37a

File tree

5 files changed

+186
-170
lines changed

5 files changed

+186
-170
lines changed

agent_memory_server/cli.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -129,49 +129,35 @@ async def run_migration():
129129

130130
redis = await get_redis_conn()
131131

132-
# Count keys by type using pipelined TYPE calls
132+
# Scan for string keys only using _type filter (much faster)
133133
string_keys = []
134-
json_keys_count = 0
135134
cursor = 0
136135
pattern = Keys.working_memory_key("*")
137136

138-
click.echo("Scanning for working memory keys...")
137+
click.echo("Scanning for working memory keys (string type only)...")
139138
scan_start = time.time()
140139

141140
while True:
142-
cursor, keys = await redis.scan(cursor, match=pattern, count=batch_size)
141+
# Use _type="string" to only get string keys directly
142+
cursor, keys = await redis.scan(
143+
cursor, match=pattern, count=batch_size, _type="string"
144+
)
143145

144146
if keys:
145-
# Pipeline TYPE calls for better performance
146-
pipe = redis.pipeline()
147-
for key in keys:
148-
pipe.type(key)
149-
types = await pipe.execute()
150-
151-
for key, key_type in zip(keys, types, strict=False):
152-
if isinstance(key_type, bytes):
153-
key_type = key_type.decode("utf-8")
154-
155-
if key_type == "string":
156-
string_keys.append(key)
157-
elif key_type == "ReJSON-RL":
158-
json_keys_count += 1
147+
string_keys.extend(keys)
159148

160149
if cursor == 0:
161150
break
162151

163152
scan_time = time.time() - scan_start
164-
total_keys = len(string_keys) + json_keys_count
165153

166154
click.echo(f"Scan completed in {scan_time:.2f}s")
167-
click.echo(f" Total keys: {total_keys}")
168-
click.echo(f" JSON format: {json_keys_count}")
169155
click.echo(f" String format (need migration): {len(string_keys)}")
170156

171157
if not string_keys:
172158
click.echo("\nNo keys need migration. All done!")
173159
# Mark migration as complete
174-
set_migration_complete()
160+
await set_migration_complete(redis)
175161
return
176162

177163
if dry_run:
@@ -263,7 +249,7 @@ async def run_migration():
263249

264250
if errors == 0:
265251
# Mark migration as complete
266-
set_migration_complete()
252+
await set_migration_complete(redis)
267253
click.echo("\nMigration status set to complete.")
268254
click.echo(
269255
"\n💡 Tip: Set WORKING_MEMORY_MIGRATION_COMPLETE=true to skip "

agent_memory_server/working_memory.py

Lines changed: 94 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,15 @@
2020

2121
logger = logging.getLogger(__name__)
2222

23-
# Flag to track if all string keys have been migrated to JSON
24-
# When True, we skip the type() check and go straight to json().get()
25-
_string_keys_migrated: bool = False
26-
27-
# Counter for remaining string keys (avoids re-scanning after each migration)
28-
_remaining_string_keys: int = 0
23+
# Redis keys for migration status (shared across workers, persists across restarts)
24+
MIGRATION_STATUS_KEY = "working_memory:migration:complete"
25+
MIGRATION_REMAINING_KEY = "working_memory:migration:remaining"
2926

3027

3128
async def check_and_set_migration_status(redis_client: Redis | None = None) -> bool:
3229
"""
3330
Check if any working memory keys are still in old string format.
34-
Sets the global _string_keys_migrated flag and _remaining_string_keys counter.
31+
Stores migration status in Redis for cross-worker consistency.
3532
3633
If WORKING_MEMORY_MIGRATION_COMPLETE=true is set, skips the scan entirely
3734
and assumes all keys are in JSON format.
@@ -42,121 +39,134 @@ async def check_and_set_migration_status(redis_client: Redis | None = None) -> b
4239
Returns:
4340
True if all keys are migrated (or no keys exist), False if string keys remain
4441
"""
45-
global _string_keys_migrated, _remaining_string_keys
46-
4742
# If env variable is set, skip the scan entirely
4843
if settings.working_memory_migration_complete:
4944
logger.info(
5045
"WORKING_MEMORY_MIGRATION_COMPLETE=true, skipping backward compatibility checks."
5146
)
52-
_string_keys_migrated = True
53-
_remaining_string_keys = 0
5447
return True
5548

5649
if not redis_client:
5750
redis_client = await get_redis_conn()
5851

59-
# Scan for working_memory:* keys
52+
# Check if migration status is already stored in Redis
53+
status = await redis_client.get(MIGRATION_STATUS_KEY)
54+
if status:
55+
if isinstance(status, bytes):
56+
status = status.decode("utf-8")
57+
if status == "true":
58+
logger.info(
59+
"Migration status in Redis indicates complete. Skipping type checks."
60+
)
61+
return True
62+
63+
# Scan for working_memory:* keys of type STRING only
64+
# This is much faster than scanning all keys and calling TYPE on each
6065
cursor = 0
61-
json_keys_found = 0
66+
string_keys_found = 0
6267

6368
try:
6469
while True:
70+
# Use _type="string" to only get string keys directly
6571
cursor, keys = await redis_client.scan(
66-
cursor=cursor, match="working_memory:*", count=1000
72+
cursor=cursor, match="working_memory:*", count=1000, _type="string"
6773
)
6874

6975
if keys:
70-
# Use pipeline to batch TYPE calls for better performance
71-
pipe = redis_client.pipeline()
72-
for key in keys:
73-
pipe.type(key)
74-
types = await pipe.execute()
75-
76-
for key_type in types:
77-
if isinstance(key_type, bytes):
78-
key_type = key_type.decode("utf-8")
79-
80-
if key_type == "string":
81-
# Early exit: found at least one string key, enable lazy migration
82-
logger.info(
83-
"Found working memory key in old string format. "
84-
"Lazy migration enabled. Run 'agent-memory migrate-working-memory' "
85-
"to migrate all keys at once."
86-
)
87-
_string_keys_migrated = False
88-
# We don't know the exact count, so set to -1 to indicate unknown
89-
# The counter will be managed differently in this mode
90-
_remaining_string_keys = -1
91-
return False
92-
elif key_type == "ReJSON-RL": # noqa: RET505
93-
json_keys_found += 1
76+
# Filter out migration status keys (they're also strings)
77+
keys = [
78+
k
79+
for k in keys
80+
if (k.decode("utf-8") if isinstance(k, bytes) else k)
81+
not in (MIGRATION_STATUS_KEY, MIGRATION_REMAINING_KEY)
82+
]
83+
string_keys_found += len(keys)
9484

9585
if cursor == 0:
9686
break
9787

98-
# No string keys found
99-
if json_keys_found > 0:
88+
if string_keys_found > 0:
89+
# Store the count in Redis for atomic decrement during lazy migration
90+
await redis_client.set(MIGRATION_REMAINING_KEY, str(string_keys_found))
10091
logger.info(
101-
f"All {json_keys_found} working memory keys are in JSON format. "
102-
"Skipping type checks."
92+
f"Found {string_keys_found} working memory keys in old string format. "
93+
"Lazy migration enabled."
10394
)
104-
else:
105-
logger.info("No working memory keys found. Skipping type checks.")
106-
_string_keys_migrated = True
107-
_remaining_string_keys = 0
95+
return False
96+
97+
# No string keys found - mark as complete in Redis
98+
await redis_client.set(MIGRATION_STATUS_KEY, "true")
99+
await redis_client.delete(MIGRATION_REMAINING_KEY)
100+
101+
logger.info(
102+
"No working memory string keys found. Skipping backward compatibility checks."
103+
)
108104
return True
109105
except Exception as e:
110106
logger.error(f"Failed to check migration status: {e}")
111-
_string_keys_migrated = False # Safe default
112-
_remaining_string_keys = -1
113107
return False
114108

115109

116-
def _decrement_string_key_count() -> None:
110+
async def _decrement_remaining_count(redis_client: Redis) -> None:
117111
"""
118-
Decrement the string key counter after a successful migration.
119-
120-
Note: When _remaining_string_keys is -1, we don't know the exact count
121-
(early exit mode). In this case, lazy migration stays enabled until
122-
the migration script is run.
112+
Atomically decrement the remaining string key counter.
113+
When it reaches 0, mark migration as complete.
123114
"""
124-
global _string_keys_migrated, _remaining_string_keys
115+
try:
116+
remaining = await redis_client.decr(MIGRATION_REMAINING_KEY)
117+
if remaining <= 0:
118+
await redis_client.set(MIGRATION_STATUS_KEY, "true")
119+
await redis_client.delete(MIGRATION_REMAINING_KEY)
120+
logger.info("All working memory keys have been migrated to JSON format.")
121+
except Exception as e:
122+
# Non-fatal - migration still works, just won't auto-complete
123+
logger.warning(f"Failed to decrement migration counter: {e}")
125124

126-
# If we don't know the count (-1), we can't track completion
127-
# The migration script will set the flag when done
128-
if _remaining_string_keys == -1:
129-
return
130125

131-
_remaining_string_keys -= 1
132-
if _remaining_string_keys <= 0:
133-
_remaining_string_keys = 0
134-
_string_keys_migrated = True
135-
logger.info("All working memory keys have been migrated to JSON format.")
126+
async def is_migration_complete(redis_client: Redis | None = None) -> bool:
127+
"""Check if migration is complete."""
128+
if settings.working_memory_migration_complete:
129+
return True
136130

131+
if not redis_client:
132+
redis_client = await get_redis_conn()
137133

138-
def is_migration_complete() -> bool:
139-
"""Check if migration is complete (for testing purposes)."""
140-
return _string_keys_migrated
134+
status = await redis_client.get(MIGRATION_STATUS_KEY)
135+
if status:
136+
if isinstance(status, bytes):
137+
status = status.decode("utf-8")
138+
return status == "true"
139+
return False
141140

142141

143-
def get_remaining_string_keys() -> int:
144-
"""Get the count of remaining string keys (for testing purposes)."""
145-
return _remaining_string_keys
142+
async def get_remaining_string_keys(redis_client: Redis | None = None) -> int:
143+
"""Get the count of remaining string keys (for testing/monitoring)."""
144+
if not redis_client:
145+
redis_client = await get_redis_conn()
146+
147+
remaining = await redis_client.get(MIGRATION_REMAINING_KEY)
148+
if remaining:
149+
if isinstance(remaining, bytes):
150+
remaining = remaining.decode("utf-8")
151+
return int(remaining)
152+
return 0
146153

147154

148-
def reset_migration_status() -> None:
155+
async def reset_migration_status(redis_client: Redis | None = None) -> None:
149156
"""Reset migration status (for testing purposes)."""
150-
global _string_keys_migrated, _remaining_string_keys
151-
_string_keys_migrated = False
152-
_remaining_string_keys = 0
157+
if not redis_client:
158+
redis_client = await get_redis_conn()
159+
160+
await redis_client.delete(MIGRATION_STATUS_KEY, MIGRATION_REMAINING_KEY)
153161

154162

155-
def set_migration_complete() -> None:
163+
async def set_migration_complete(redis_client: Redis | None = None) -> None:
156164
"""Mark migration as complete (called by migration script)."""
157-
global _string_keys_migrated, _remaining_string_keys
158-
_string_keys_migrated = True
159-
_remaining_string_keys = 0
165+
if not redis_client:
166+
redis_client = await get_redis_conn()
167+
168+
await redis_client.set(MIGRATION_STATUS_KEY, "true")
169+
await redis_client.delete(MIGRATION_REMAINING_KEY)
160170
logger.info("Working memory migration marked as complete.")
161171

162172

@@ -202,8 +212,8 @@ async def _migrate_string_to_json(
202212

203213
logger.info(f"Successfully migrated working memory key {key} to JSON format")
204214

205-
# Decrement the counter (O(1) instead of re-scanning all keys)
206-
_decrement_string_key_count()
215+
# Atomically decrement the remaining counter
216+
await _decrement_remaining_count(redis_client)
207217

208218
return data
209219
except json.JSONDecodeError as e:
@@ -292,7 +302,10 @@ async def get_working_memory(
292302
try:
293303
working_memory_data = None
294304

295-
if _string_keys_migrated:
305+
# Check migration status (uses Redis, shared across workers)
306+
migration_complete = await is_migration_complete(redis_client)
307+
308+
if migration_complete:
296309
# Fast path: all keys are already in JSON format
297310
working_memory_data = await redis_client.json().get(key)
298311
else:

tests/benchmarks/test_migration_benchmark.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@ async def cleanup():
7373

7474
# Clean before
7575
await cleanup()
76-
reset_migration_status()
76+
await reset_migration_status(async_redis_client)
7777

7878
yield
7979

8080
# Clean after
8181
await cleanup()
82-
reset_migration_status()
82+
await reset_migration_status(async_redis_client)
8383

8484

8585
@pytest.mark.benchmark
@@ -120,7 +120,7 @@ async def test_startup_scan_performance(
120120

121121
# Benchmark startup scan (with early exit)
122122
print("\n📊 Benchmarking startup scan (early exit on first string key)...")
123-
reset_migration_status()
123+
await reset_migration_status(async_redis_client)
124124

125125
start = time.perf_counter()
126126
result = await check_and_set_migration_status(async_redis_client)
@@ -154,7 +154,7 @@ async def test_lazy_migration_performance(
154154
await async_redis_client.set(key, json.dumps(data))
155155

156156
# Set migration status
157-
reset_migration_status()
157+
await reset_migration_status(async_redis_client)
158158
await check_and_set_migration_status(async_redis_client)
159159

160160
# Benchmark lazy migration (read each key, triggering migration)
@@ -195,7 +195,7 @@ async def test_post_migration_read_performance(
195195
await pipe.execute()
196196

197197
# Set migration as complete
198-
reset_migration_status()
198+
await reset_migration_status(async_redis_client)
199199
await check_and_set_migration_status(async_redis_client)
200200

201201
# Benchmark reads (should use fast path)
@@ -258,7 +258,7 @@ async def test_worst_case_single_string_key_at_end(
258258

259259
# Benchmark startup scan - must scan all keys to find the string one
260260
print("\n📊 Benchmarking startup scan (worst case - string key at end)...")
261-
reset_migration_status()
261+
await reset_migration_status(async_redis_client)
262262

263263
start = time.perf_counter()
264264
result = await check_and_set_migration_status(async_redis_client)

0 commit comments

Comments
 (0)