Merged
Changes from 6 commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -235,3 +235,5 @@ libs/redis/docs/.Trash*
.claude
TASK_MEMORY.md
*.code-workspace

augment*.md
171 changes: 171 additions & 0 deletions agent_memory_server/cli.py
@@ -94,6 +94,177 @@ async def run_migrations():
click.echo("Memory migrations completed successfully.")


@cli.command()
@click.option(
"--batch-size",
default=1000,
help="Number of keys to process in each batch",
)
@click.option(
"--dry-run",
is_flag=True,
help="Only count keys without migrating",
)
def migrate_working_memory(batch_size: int, dry_run: bool):
"""
Migrate working memory keys from string format to JSON format.

This command migrates all working memory keys stored in the old string
format (JSON serialized as a string) to the new native Redis JSON format.

Use --dry-run to see how many keys need migration without making changes.
"""
import asyncio
import time

from agent_memory_server.utils.keys import Keys
from agent_memory_server.working_memory import (
reset_migration_status,
set_migration_complete,
)

configure_logging()

async def run_migration():
import json as json_module

redis = await get_redis_conn()

# Count keys by type using pipelined TYPE calls
string_keys = []
json_keys_count = 0
cursor = 0
pattern = Keys.working_memory_key("*")

click.echo("Scanning for working memory keys...")
scan_start = time.time()

while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=batch_size)

if keys:
# Pipeline TYPE calls for better performance
pipe = redis.pipeline()
for key in keys:
pipe.type(key)
types = await pipe.execute()

for key, key_type in zip(keys, types):
if isinstance(key_type, bytes):
key_type = key_type.decode("utf-8")

if key_type == "string":
string_keys.append(key)
elif key_type == "ReJSON-RL":
json_keys_count += 1

if cursor == 0:
break

scan_time = time.time() - scan_start
total_keys = len(string_keys) + json_keys_count

click.echo(f"Scan completed in {scan_time:.2f}s")
click.echo(f" Total keys: {total_keys}")
click.echo(f" JSON format: {json_keys_count}")
click.echo(f" String format (need migration): {len(string_keys)}")

if not string_keys:
click.echo("\nNo keys need migration. All done!")
# Mark migration as complete
set_migration_complete()
return

if dry_run:
click.echo("\n--dry-run specified, no changes made.")
return

# Migrate keys in batches using pipeline
click.echo(f"\nMigrating {len(string_keys)} keys...")
migrate_start = time.time()
migrated = 0
errors = 0

# Process in batches
for batch_start in range(0, len(string_keys), batch_size):
batch_keys = string_keys[batch_start : batch_start + batch_size]

# First, read all string data in a pipeline
read_pipe = redis.pipeline()
for key in batch_keys:
Comment on lines 164 to 179

Copilot AI commented on Dec 5, 2025:

The migration command does not preserve TTL when migrating keys from string to JSON format. Before deleting the old key (line 193), get and preserve its TTL:

# Get TTL before deleting
ttl = await redis.ttl(key)
await redis.delete(key)
await redis.json().set(key, "$", data)
# Restore TTL if it was set (ttl > 0)
if ttl > 0:
    await redis.expire(key, ttl)

This ensures keys with expiration don't become permanent after migration.

A collaborator replied:

@copilot open a new pull request to apply changes based on this feedback

read_pipe.get(key)
string_data_list = await read_pipe.execute()

# Parse and prepare migration
migrations = [] # List of (key, data) tuples
for key, string_data in zip(batch_keys, string_data_list):
if string_data is None:
continue

try:
if isinstance(string_data, bytes):
string_data = string_data.decode("utf-8")
data = json_module.loads(string_data)
migrations.append((key, data))
except Exception as e:
errors += 1
logger.error(f"Failed to parse key {key}: {e}")

# Execute migrations in a pipeline (delete + json.set)
if migrations:
write_pipe = redis.pipeline()
for key, data in migrations:
write_pipe.delete(key)
write_pipe.json().set(key, "$", data)

try:
await write_pipe.execute()
migrated += len(migrations)
except Exception as e:
# If batch fails, try one by one
logger.warning(f"Batch migration failed, retrying individually: {e}")
for key, data in migrations:
try:
await redis.delete(key)
await redis.json().set(key, "$", data)
migrated += 1
except Exception as e2:
errors += 1
logger.error(f"Failed to migrate key {key}: {e2}")

# Progress update
total_processed = batch_start + len(batch_keys)
if total_processed % 10000 == 0 or total_processed == len(string_keys):
elapsed = time.time() - migrate_start
rate = migrated / elapsed if elapsed > 0 else 0
remaining = len(string_keys) - total_processed
eta = remaining / rate if rate > 0 else 0
click.echo(
f" Migrated {migrated}/{len(string_keys)} "
f"({rate:.0f} keys/sec, ETA: {eta:.0f}s)"
)

migrate_time = time.time() - migrate_start
rate = migrated / migrate_time if migrate_time > 0 else 0

click.echo(f"\nMigration completed in {migrate_time:.2f}s")
click.echo(f" Migrated: {migrated}")
click.echo(f" Errors: {errors}")
click.echo(f" Rate: {rate:.0f} keys/sec")

if errors == 0:
# Mark migration as complete
set_migration_complete()
click.echo("\nMigration status set to complete.")
else:
click.echo(
"\nMigration completed with errors. "
"Run again to retry failed keys."
)

asyncio.run(run_migration())
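
The review comment above notes that TTLs are lost when a key is deleted and rewritten as JSON. The suggested fix can also be folded into the batched pipelines the command already uses, rather than applied one key at a time. The following is a sketch of that variant, not part of this PR; it reuses the same redis-py calls as the command above plus TTL/EXPIRE, and the function name is made up for illustration:

import json

from redis.asyncio import Redis


async def migrate_batch_preserving_ttl(redis: Redis, batch_keys: list) -> int:
    """Sketch: migrate one batch of string keys to JSON while keeping their TTLs."""
    # Read value and TTL together for each key in the batch.
    read_pipe = redis.pipeline()
    for key in batch_keys:
        read_pipe.get(key)
        read_pipe.ttl(key)
    results = await read_pipe.execute()

    # results alternates value/TTL pairs in the same order as batch_keys.
    migrations = []
    for i, key in enumerate(batch_keys):
        string_data, ttl = results[2 * i], results[2 * i + 1]
        if string_data is None:
            continue
        if isinstance(string_data, bytes):
            string_data = string_data.decode("utf-8")
        migrations.append((key, json.loads(string_data), ttl))

    # Rewrite each key as native JSON and restore the TTL when one was set
    # (TTL returns -1 for keys without expiry, so the check skips those).
    write_pipe = redis.pipeline()
    for key, data, ttl in migrations:
        write_pipe.delete(key)
        write_pipe.json().set(key, "$", data)
        if ttl > 0:
            write_pipe.expire(key, ttl)
    await write_pipe.execute()
    return len(migrations)

The extra TTL and EXPIRE calls ride in the same pipelines as the existing GET/DELETE/JSON.SET commands, so the per-batch round-trip count stays the same.
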


@cli.command()
@click.option("--port", default=settings.port, help="Port to run the server on")
@click.option("--host", default="0.0.0.0", help="Host to run the server on")
Expand Down
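
For context on what the new command converts: the two storage formats hold the same document, and differ only in how it reaches Redis. A minimal sketch of the before/after shapes, assuming a redis.asyncio client against a Redis instance with the RedisJSON module loaded (the key name and payload are illustrative, not taken from the codebase):

import json

from redis.asyncio import Redis


async def show_formats(redis: Redis) -> None:
    key = "working_memory:example-session"  # illustrative key name
    data = {"messages": [{"role": "user", "content": "hi"}]}

    # Old format: the whole document serialized into a plain Redis string.
    await redis.set(key, json.dumps(data))
    assert await redis.type(key) == b"string"

    # New format: the same document stored natively via RedisJSON,
    # which is what migrate-working-memory rewrites each key into.
    await redis.delete(key)
    await redis.json().set(key, "$", data)
    assert await redis.type(key) == b"ReJSON-RL"

Running the command with --dry-run first reports how many keys are still in the string format without modifying anything.
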
10 changes: 7 additions & 3 deletions agent_memory_server/main.py
@@ -74,9 +74,13 @@ async def lifespan(app: FastAPI):
"Long-term memory requires OpenAI for embeddings, but OpenAI API key is not set"
)

# Set up Redis connection if long-term memory is enabled
if settings.long_term_memory:
await get_redis_conn()
# Set up Redis connection and check working memory migration status
redis_conn = await get_redis_conn()

# Check if any working memory keys need migration from string to JSON format
from agent_memory_server.working_memory import check_and_set_migration_status

await check_and_set_migration_status(redis_conn)

# Initialize Docket for background tasks if enabled
if settings.use_docket:
Expand Down
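
The check_and_set_migration_status helper imported here is not shown in this diff. A rough sketch of what it might look like, assuming it reuses the same SCAN/TYPE probe as the CLI command and tracks the result in a module-level flag (the flag and both function bodies below are guesses based on how they are called, not the PR's actual implementation):

import logging

from redis.asyncio import Redis

from agent_memory_server.utils.keys import Keys

logger = logging.getLogger(__name__)

# Assumed module-level flag; set_migration_complete()/reset_migration_status()
# presumably flip it, which would explain why the CLI calls them synchronously.
_migration_needed = False


def set_migration_complete() -> None:
    global _migration_needed
    _migration_needed = False


async def check_and_set_migration_status(redis: Redis) -> None:
    """Sketch: scan working memory keys and flag whether any are still strings."""
    global _migration_needed
    cursor = 0
    pattern = Keys.working_memory_key("*")
    while True:
        cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
        if keys:
            pipe = redis.pipeline()
            for key in keys:
                pipe.type(key)
            types = await pipe.execute()
            if any(t in (b"string", "string") for t in types):
                _migration_needed = True
                logger.warning(
                    "Working memory keys in legacy string format detected; "
                    "run the migrate-working-memory CLI command."
                )
                return
        if cursor == 0:
            break
    _migration_needed = False

Under these assumptions, startup pays a one-time pipelined scan (mirroring the one in cli.py) and then short-circuits on the flag for the rest of the process lifetime.
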