Skip to content

Commit 6e60546

Browse files
authored
fix: Convert memory backends to async/await (#2185)
Merging PR #2185 - All tests passing (133+), CI green, philosophy grade 10/10
1 parent 6b035ae commit 6e60546

File tree

10 files changed

+456
-126
lines changed

10 files changed

+456
-126
lines changed

src/amplihack/memory/backends/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import logging
1717
import os
1818
from enum import Enum
19-
from pathlib import Path
2019
from typing import Any
2120

2221
from .base import BackendCapabilities, MemoryBackend
@@ -100,10 +99,10 @@ def create_backend(backend_type: str | BackendType | None = None, **config: Any)
10099
logger.info(f"Kùzu not available ({e}), using SQLite backend")
101100

102101
# Create backend instance
102+
# Note: Caller must call await backend.initialize() after creation
103103
if backend_type == BackendType.SQLITE:
104104
db_path = config.get("db_path")
105105
backend = SQLiteBackend(db_path=db_path)
106-
backend.initialize()
107106
return backend
108107

109108
if backend_type == BackendType.KUZU:
@@ -116,7 +115,6 @@ def create_backend(backend_type: str | BackendType | None = None, **config: Any)
116115

117116
db_path = config.get("db_path")
118117
backend = KuzuBackend(db_path=db_path)
119-
backend.initialize()
120118
return backend
121119

122120
raise ValueError(f"Unknown backend type: {backend_type}")

src/amplihack/memory/backends/base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,15 @@ def get_capabilities(self) -> BackendCapabilities:
5454
"""
5555
...
5656

57-
def initialize(self) -> None:
57+
async def initialize(self) -> None:
5858
"""Initialize backend (create schema, indexes, etc).
5959
6060
MUST be called before first use.
6161
Should be idempotent (safe to call multiple times).
6262
"""
6363
...
6464

65-
def store_memory(self, memory: MemoryEntry) -> bool:
65+
async def store_memory(self, memory: MemoryEntry) -> bool:
6666
"""Store a memory entry.
6767
6868
Args:
@@ -75,7 +75,7 @@ def store_memory(self, memory: MemoryEntry) -> bool:
7575
"""
7676
...
7777

78-
def retrieve_memories(self, query: MemoryQuery) -> list[MemoryEntry]:
78+
async def retrieve_memories(self, query: MemoryQuery) -> list[MemoryEntry]:
7979
"""Retrieve memories matching the query.
8080
8181
Args:
@@ -88,7 +88,7 @@ def retrieve_memories(self, query: MemoryQuery) -> list[MemoryEntry]:
8888
"""
8989
...
9090

91-
def get_memory_by_id(self, memory_id: str) -> MemoryEntry | None:
91+
async def get_memory_by_id(self, memory_id: str) -> MemoryEntry | None:
9292
"""Get a specific memory by ID.
9393
9494
Args:
@@ -101,7 +101,7 @@ def get_memory_by_id(self, memory_id: str) -> MemoryEntry | None:
101101
"""
102102
...
103103

104-
def delete_memory(self, memory_id: str) -> bool:
104+
async def delete_memory(self, memory_id: str) -> bool:
105105
"""Delete a memory entry.
106106
107107
Args:
@@ -114,7 +114,7 @@ def delete_memory(self, memory_id: str) -> bool:
114114
"""
115115
...
116116

117-
def cleanup_expired(self) -> int:
117+
async def cleanup_expired(self) -> int:
118118
"""Remove expired memory entries.
119119
120120
Returns:
@@ -124,7 +124,7 @@ def cleanup_expired(self) -> int:
124124
"""
125125
...
126126

127-
def get_session_info(self, session_id: str) -> SessionInfo | None:
127+
async def get_session_info(self, session_id: str) -> SessionInfo | None:
128128
"""Get information about a session.
129129
130130
Args:
@@ -137,7 +137,7 @@ def get_session_info(self, session_id: str) -> SessionInfo | None:
137137
"""
138138
...
139139

140-
def list_sessions(self, limit: int | None = None) -> list[SessionInfo]:
140+
async def list_sessions(self, limit: int | None = None) -> list[SessionInfo]:
141141
"""List all sessions ordered by last accessed.
142142
143143
Args:
@@ -150,7 +150,7 @@ def list_sessions(self, limit: int | None = None) -> list[SessionInfo]:
150150
"""
151151
...
152152

153-
def get_stats(self) -> dict[str, Any]:
153+
async def get_stats(self) -> dict[str, Any]:
154154
"""Get database statistics.
155155
156156
Returns:
@@ -160,7 +160,7 @@ def get_stats(self) -> dict[str, Any]:
160160
"""
161161
...
162162

163-
def close(self) -> None:
163+
async def close(self) -> None:
164164
"""Close backend connection and cleanup resources.
165165
166166
Should be idempotent (safe to call multiple times).

0 commit comments

Comments
 (0)