|
25 | 25 | from google.adk.memory.memory_entry import MemoryEntry |
26 | 26 | from google.adk.sessions import Session |
27 | 27 | from google.genai import types |
28 | | -from pydantic import BaseModel |
| 28 | +from pydantic import BaseModel, Field |
29 | 29 | from typing_extensions import override |
30 | 30 |
|
31 | | -from veadk.database import DatabaseFactory |
32 | | -from veadk.database.database_adapter import get_long_term_memory_database_adapter |
| 31 | +from veadk.memory.long_term_memory_backends.base_backend import ( |
| 32 | + BaseLongTermMemoryBackend, |
| 33 | +) |
| 34 | +from veadk.memory.long_term_memory_backends.in_memory_backend import InMemoryLTMBackend |
| 35 | +from veadk.memory.long_term_memory_backends.opensearch_backend import ( |
| 36 | + OpensearchLTMBackend, |
| 37 | +) |
| 38 | +from veadk.memory.long_term_memory_backends.redis_backend import RedisLTMBackend |
| 39 | +from veadk.memory.long_term_memory_backends.vikingdb_memory_backend import ( |
| 40 | + VikingDBKnowledgeBackend, |
| 41 | +) |
33 | 42 | from veadk.utils.logger import get_logger |
34 | 43 |
|
35 | 44 | logger = get_logger(__name__) |
36 | 45 |
|
37 | 46 |
|
| 47 | +BACKEND_CLS = { |
| 48 | + "local": InMemoryLTMBackend, |
| 49 | + "opensearch": OpensearchLTMBackend, |
| 50 | + "viking": VikingDBKnowledgeBackend, |
| 51 | + "viking_mem": VikingDBKnowledgeBackend, |
| 52 | + "redis": RedisLTMBackend, |
| 53 | +} |
| 54 | + |
| 55 | + |
38 | 56 | def build_long_term_memory_index(app_name: str, user_id: str): |
39 | 57 | return f"{app_name}_{user_id}" |
40 | 58 |
|
41 | 59 |
|
42 | 60 | class LongTermMemory(BaseMemoryService, BaseModel): |
43 | | - backend: Literal[ |
44 | | - "local", "opensearch", "redis", "mysql", "viking", "viking_mem" |
45 | | - ] = "opensearch" |
| 61 | + backend: Literal["local", "opensearch", "redis", "viking", "viking_mem"] = ( |
| 62 | + "opensearch" |
| 63 | + ) |
| 64 | + """Long term memory backend type""" |
| 65 | + |
| 66 | + backend_config: dict = Field(default_factory=dict) |
| 67 | + """Long term memory backend configuration""" |
| 68 | + |
| 69 | + backend_instance: BaseLongTermMemoryBackend | None = None |
| 70 | + """An instance of a long term memory backend that implements the `BaseLongTermMemoryBackend` interface.""" |
| 71 | + |
46 | 72 | top_k: int = 5 |
| 73 | + """Number of top similar documents to retrieve during search.""" |
47 | 74 |
|
48 | | - def model_post_init(self, __context: Any) -> None: |
49 | | - if self.backend == "viking": |
50 | | - logger.warning( |
51 | | - "`viking` backend is deprecated, switching to `viking_mem` backend." |
52 | | - ) |
53 | | - self.backend = "viking_mem" |
| 75 | + app_name: str = "" |
54 | 76 |
|
55 | | - logger.info( |
56 | | - f"Initializing long term memory: backend={self.backend} top_k={self.top_k}" |
57 | | - ) |
| 77 | + user_id: str = "" |
58 | 78 |
|
59 | | - self._db_client = DatabaseFactory.create( |
60 | | - backend=self.backend, |
61 | | - ) |
62 | | - self._adapter = get_long_term_memory_database_adapter(self._db_client) |
| 79 | + def model_post_init(self, __context: Any) -> None: |
| 80 | + self._backend = None |
63 | 81 |
|
64 | | - logger.info( |
65 | | - f"Initialized long term memory: db_client={self._db_client.__class__.__name__} adapter={self._adapter}" |
66 | | - ) |
| 82 | + if self.backend_instance: |
| 83 | + self._backend = self.backend_instance |
| 84 | + logger.info( |
| 85 | + f"Initialized long term memory with provided backend instance {self._backend.__class__.__name__}" |
| 86 | + ) |
| 87 | + else: |
| 88 | + if self.backend_config: |
| 89 | + logger.info( |
| 90 | + f"Initialized long term memory backend {self.backend} with config." |
| 91 | + ) |
| 92 | + self._backend = BACKEND_CLS[self.backend](**self.backend_config) |
| 93 | + elif self.app_name and self.user_id: |
| 94 | + self.index = build_long_term_memory_index( |
| 95 | + app_name=self.app_name, user_id=self.user_id |
| 96 | + ) |
| 97 | + logger.info(f"Long term memory index set to {self.index}.") |
| 98 | + self._backend = BACKEND_CLS[self.backend]( |
| 99 | + **self.backend_config, index=self.index |
| 100 | + ) |
| 101 | + else: |
| 102 | + logger.warning( |
| 103 | + "Neither `backend_instance`, `backend_config`, nor `app_name`/`user_id` is provided, the long term memory storage will initialize when adding a session." |
| 104 | + ) |
67 | 105 |
|
68 | 106 | def _filter_and_convert_events(self, events: list[Event]) -> list[str]: |
69 | 107 | final_events = [] |
@@ -91,40 +129,48 @@ async def add_session_to_memory( |
91 | 129 | self, |
92 | 130 | session: Session, |
93 | 131 | ): |
| 132 | + app_name = session.app_name |
| 133 | + user_id = session.user_id |
| 134 | + |
| 135 | + if self.index != build_long_term_memory_index(app_name, user_id): |
| 136 | + logger.warning( |
| 137 | + f"The `app_name` or `user_id` is different from the initialized one, skip add session to memory. Initialized index: {self.index}, current built index: {build_long_term_memory_index(app_name, user_id)}" |
| 138 | + ) |
| 139 | + return |
| 140 | + |
| 141 | + if not self._backend: |
| 142 | + self.index = build_long_term_memory_index(app_name, user_id) |
| 143 | + self._backend = BACKEND_CLS[self.backend](index=self.index) |
| 144 | + logger.info( |
| 145 | + f"Initialize long term memory backend now, index is {self.index}" |
| 146 | + ) |
| 147 | + |
94 | 148 | event_strings = self._filter_and_convert_events(session.events) |
95 | | - index = build_long_term_memory_index(session.app_name, session.user_id) |
96 | 149 |
|
97 | 150 | logger.info( |
98 | | - f"Adding {len(event_strings)} events to long term memory: index={index}" |
| 151 | + f"Adding {len(event_strings)} events to long term memory: index={self.index}" |
99 | 152 | ) |
100 | 153 |
|
101 | | - # check if viking memory database, should give a user id: if/else |
102 | | - if self.backend == "viking_mem": |
103 | | - self._adapter.add(data=event_strings, index=index, user_id=session.user_id) |
104 | | - else: |
105 | | - self._adapter.add(data=event_strings, index=index) |
| 154 | + self._backend.save_memory(event_strings=event_strings) |
106 | 155 |
|
107 | 156 | logger.info( |
108 | | - f"Added {len(event_strings)} events to long term memory: index={index}" |
| 157 | + f"Added {len(event_strings)} events to long term memory: index={self.index}" |
109 | 158 | ) |
110 | 159 |
|
111 | 160 | @override |
112 | 161 | async def search_memory(self, *, app_name: str, user_id: str, query: str): |
113 | | - index = build_long_term_memory_index(app_name, user_id) |
114 | | - |
115 | 162 | logger.info( |
116 | | - f"Searching long term memory: query={query} index={index} top_k={self.top_k}" |
| 163 | + f"Searching long term memory: query={query} index={self.index} top_k={self.top_k}" |
117 | 164 | ) |
118 | 165 |
|
119 | | - # user id if viking memory db |
120 | | - if self.backend == "viking_mem": |
121 | | - memory_chunks = self._adapter.query( |
122 | | - query=query, index=index, top_k=self.top_k, user_id=user_id |
123 | | - ) |
124 | | - else: |
125 | | - memory_chunks = self._adapter.query( |
126 | | - query=query, index=index, top_k=self.top_k |
| 166 | + # prevent model invoke `load_memory` before add session to this memory |
| 167 | + if not self._backend: |
| 168 | + logger.error( |
| 169 | + "Long term memory backend is not initialized, cannot search memory." |
127 | 170 | ) |
| 171 | + return SearchMemoryResponse(memories=[]) |
| 172 | + |
| 173 | + memory_chunks = self._backend.search_memory(query=query, top_k=self.top_k) |
128 | 174 |
|
129 | 175 | memory_events = [] |
130 | 176 | for memory in memory_chunks: |
@@ -152,6 +198,6 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): |
152 | 198 | ) |
153 | 199 |
|
154 | 200 | logger.info( |
155 | | - f"Return {len(memory_events)} memory events for query: {query} index={index}" |
| 201 | + f"Return {len(memory_events)} memory events for query: {query} index={self.index}" |
156 | 202 | ) |
157 | 203 | return SearchMemoryResponse(memories=memory_events) |
0 commit comments