Skip to content

Commit 19ee7ba

Browse files
committed
feat: support nebular database
1 parent 9eb6c5c commit 19ee7ba

File tree

2 files changed

+118
-69
lines changed

2 files changed

+118
-69
lines changed

src/memos/graph_dbs/nebular.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -281,29 +281,25 @@ def get_memory_count(self, memory_type: str) -> int:
281281
query += "\nRETURN COUNT(n) AS count"
282282

283283
try:
284-
print(f"\n ======> query: {query}\n")
285284
result = self.client.execute(query)
286-
print(result.one_or_none()["count"].value)
287285
return result.one_or_none()["count"].value
288286
except Exception as e:
289287
logger.error(f"[get_memory_count] Failed: {e}")
290288
return -1
291289

292-
# TODO
293290
def count_nodes(self, scope: str) -> int:
294291
query = f"""
295292
MATCH (n@Memory)
296293
WHERE n.memory_type = {scope}
297294
"""
298295
if not self.config.use_multi_db and self.config.user_name:
299296
user_name = self.config.user_name
300-
query += f"\nAND n.user_name = {user_name}"
297+
query += f"\nAND n.user_name = '{user_name}'"
301298
query += "\nRETURN count(n) AS count"
302299

303300
result = self.client.execute(query)
304-
return result.one_or_none().values()["count"]
301+
return result.one_or_none()["count"].value
305302

306-
# TODO
307303
def edge_exists(
308304
self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING"
309305
) -> bool:
@@ -319,26 +315,27 @@ def edge_exists(
319315
True if the edge exists, otherwise False.
320316
"""
321317
# Prepare the relationship pattern
322-
rel = "r" if type == "ANY" else f"r:{type}"
318+
rel = "r" if type == "ANY" else f"r@{type}"
323319

324320
# Prepare the match pattern with direction
325321
if direction == "OUTGOING":
326-
pattern = f"(a@Memory {{id: {source_id}}})-[{rel}]->(b@Memory {{id: {target_id}}})"
322+
pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})"
327323
elif direction == "INCOMING":
328-
pattern = f"(a@Memory {{id: {source_id}}})<-[{rel}]-(b@Memory {{id: {target_id}}})"
324+
pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})"
329325
elif direction == "ANY":
330-
pattern = f"(a@Memory {{id: {source_id}}})-[{rel}]-(b@Memory {{id: {target_id}}})"
326+
pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})"
331327
else:
332328
raise ValueError(
333329
f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
334330
)
335331
query = f"MATCH {pattern}"
336332
if not self.config.use_multi_db and self.config.user_name:
337333
user_name = self.config.user_name
338-
query += f"\nWHERE a.user_name = {user_name} AND b.user_name = {user_name}"
334+
query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
339335
query += "\nRETURN r"
340336

341337
# Run the Cypher query
338+
print("\n ======> query: ", query)
342339
result = self.client.execute(query)
343340
return result.one_or_none().values() is not None
344341

@@ -374,7 +371,6 @@ def get_node(self, id: str) -> dict[str, Any] | None:
374371
logger.error(f"[get_node] Failed to retrieve node '{id}': {e}")
375372
return None
376373

377-
# TODO
378374
def get_nodes(self, ids: list[str]) -> list[dict[str, Any]]:
379375
"""
380376
Retrieve the metadata and memory of a list of nodes.
@@ -392,14 +388,13 @@ def get_nodes(self, ids: list[str]) -> list[dict[str, Any]]:
392388

393389
where_user = ""
394390
if not self.config.use_multi_db and self.config.user_name:
395-
where_user = f" AND n.user_name = {self.config.user_name}"
391+
where_user = f" AND n.user_name = '{self.config.user_name}'"
396392

397393
query = f"MATCH (n@Memory) WHERE n.id IN {ids} {where_user} RETURN n"
398394

399395
results = self.client.execute(query)
400-
return [self._parse_node(dict(record.values()["n"])) for record in results]
396+
return [self._parse_node(record["n"]) for record in results]
401397

402-
# TODO
403398
def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]:
404399
"""
405400
Get edges connected to a node, with optional type and direction filter.
@@ -417,28 +412,28 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[
417412
]
418413
"""
419414
# Build relationship type filter
420-
rel_type = "" if type == "ANY" else f":{type}"
415+
rel_type = "" if type == "ANY" else f"@{type}"
421416

422417
# Build Cypher pattern based on direction
423418
if direction == "OUTGOING":
424419
pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)"
425-
where_clause = f"a.id = {id}"
420+
where_clause = f"a.id = '{id}'"
426421
elif direction == "INCOMING":
427422
pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)"
428-
where_clause = f"a.id = {id}"
423+
where_clause = f"a.id = '{id}'"
429424
elif direction == "ANY":
430425
pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)"
431426
where_clause = f"a.id = {id} OR b.id = {id}"
432427
else:
433428
raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
434429

435430
if not self.config.use_multi_db and self.config.user_name:
436-
where_clause += f" AND a.user_name = {self.config.user_name} AND b.user_name = {self.config.user_name}"
431+
where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'"
437432

438433
query = f"""
439434
MATCH {pattern}
440435
WHERE {where_clause}
441-
RETURN a.id AS from_id, b.id AS to_id, type(r) AS type
436+
RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
442437
"""
443438

444439
result = self.client.execute(query)
@@ -448,11 +443,12 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[
448443
{
449444
"from": record["from_id"].value,
450445
"to": record["to_id"].value,
451-
"type": record["type"].value,
446+
"type": record["edge_type"].value,
452447
}
453448
)
454449
return edges
455450

451+
# TODO
456452
def get_neighbors_by_tag(
457453
self,
458454
tags: list[str],
@@ -544,7 +540,7 @@ def get_subgraph(
544540
collect(EDGES(p)) AS edge_chains
545541
"""
546542

547-
result = self.client.execute(gql).one_or_none() # 执行查询
543+
result = self.client.execute(gql).one_or_none()
548544
if not result or result.size == 0:
549545
return {"core_node": None, "neighbors": [], "edges": []}
550546

@@ -734,7 +730,6 @@ def get_grouped_counts(
734730

735731
return output
736732

737-
# TODO
738733
def clear(self) -> None:
739734
"""
740735
Clear the entire graph if the target database exists.
@@ -745,13 +740,11 @@ def clear(self) -> None:
745740
else:
746741
query = "MATCH (n) DETACH DELETE n"
747742

748-
# Step 2: Clear the graph in that database
749743
self.client.execute(query)
750744
logger.info("Cleared all nodes from database.")
751745

752746
except Exception as e:
753747
logger.error(f"[ERROR] Failed to clear database: {e}")
754-
raise
755748

756749
def export_graph(self) -> dict[str, Any]:
757750
"""

tests/graph_dbs/test_nebular.py

Lines changed: 100 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
"base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
2626
},
2727
}
28-
2928
nebular_config = {
3029
"hosts": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
3130
"user_name": os.getenv("NEBULAR_USER", "root"),
@@ -35,6 +34,8 @@
3534
"embedding_dimension": 3072,
3635
"use_multi_db": False,
3736
}
37+
38+
3839
embedder_config = EmbedderConfigFactory.model_validate(gpt_config)
3940
embedder = EmbedderFactory.from_config(embedder_config)
4041

@@ -46,60 +47,95 @@ def embed_memory_item(memory: str) -> list[float]:
4647
return embedding_list
4748

4849

50+
now = datetime.now(timezone.utc).isoformat()
51+
test_node1 = TextualMemoryItem(
52+
memory="This is a test node",
53+
metadata=TreeNodeTextualMemoryMetadata(
54+
memory_type="LongTermMemory",
55+
key="Research Topic",
56+
hierarchy_level="topic",
57+
type="fact",
58+
memory_time="2024-01-01",
59+
status="activated",
60+
visibility="public",
61+
updated_at=now,
62+
embedding=embed_memory_item("This is a test node"),
63+
),
64+
)
65+
66+
test_node2 = TextualMemoryItem(
67+
memory="This is another test node",
68+
metadata=TreeNodeTextualMemoryMetadata(
69+
memory_type="LongTermMemory",
70+
key="Research Topic",
71+
hierarchy_level="topic",
72+
type="fact",
73+
memory_time="2024-01-01",
74+
status="activated",
75+
visibility="public",
76+
updated_at=now,
77+
embedding=embed_memory_item("This is another test node"),
78+
),
79+
)
80+
81+
4982
def test_get_memory_count():
5083
config = GraphDBConfigFactory(backend="nebular", config=nebular_config)
51-
5284
graph = GraphStoreFactory.from_config(config)
5385
graph.clear()
54-
now = datetime.now(timezone.utc).isoformat()
5586

56-
# Insert memory
57-
mem = TextualMemoryItem(
58-
memory="Example A",
59-
metadata=TreeNodeTextualMemoryMetadata(
60-
memory_type="LongTermMemory",
61-
key="Research Topic",
62-
hierarchy_level="topic",
63-
type="fact",
64-
memory_time="2024-01-01",
65-
status="activated",
66-
visibility="public",
67-
updated_at=now,
68-
embedding=embed_memory_item("Example A"),
69-
),
70-
)
87+
mem = test_node1
7188
graph.add_node(mem.id, mem.memory, mem.metadata.model_dump(exclude_none=True))
7289

7390
count = graph.get_memory_count('"LongTermMemory"') # quoting string literal for Cypher
7491
print("Memory Count:", count)
7592
assert count == 1
7693

7794

95+
def test_count_nodes():
96+
graph = GraphStoreFactory.from_config(
97+
GraphDBConfigFactory(
98+
backend="nebular",
99+
config=nebular_config,
100+
)
101+
)
102+
graph.clear()
103+
104+
# Insert two nodes
105+
for i in range(2):
106+
mem = TextualMemoryItem(
107+
memory=f"Memory {i}",
108+
metadata=TreeNodeTextualMemoryMetadata(
109+
memory_type="LongTermMemory",
110+
key="Research Topic",
111+
hierarchy_level="topic",
112+
type="fact",
113+
memory_time="2024-01-01",
114+
status="activated",
115+
visibility="public",
116+
updated_at=now,
117+
embedding=embed_memory_item(f"Memory {i}"),
118+
),
119+
)
120+
graph.add_node(mem.id, mem.memory, mem.metadata.model_dump(exclude_none=True))
121+
122+
count = graph.count_nodes('"LongTermMemory"')
123+
print("Node Count:", count)
124+
assert count == 2
125+
126+
78127
def test_get_nodes():
79128
graph = GraphStoreFactory.from_config(
80129
GraphDBConfigFactory(backend="nebular", config=nebular_config)
81130
)
82131
graph.clear()
83132

84-
now = datetime.now(timezone.utc).isoformat()
85-
mem = TextualMemoryItem(
86-
memory="Test node",
87-
metadata=TreeNodeTextualMemoryMetadata(
88-
memory_type="LongTermMemory",
89-
key="Research Topic",
90-
hierarchy_level="topic",
91-
type="fact",
92-
memory_time="2024-01-01",
93-
status="activated",
94-
visibility="public",
95-
updated_at=now,
96-
embedding=embed_memory_item("Test node"),
97-
),
98-
)
133+
mem = test_node1
99134
graph.add_node(mem.id, mem.memory, mem.metadata.model_dump(exclude_none=True))
135+
100136
nodes = graph.get_nodes([mem.id])
101137
assert len(nodes) == 1
102-
assert nodes[0]["id"] == mem.id
138+
assert nodes[0]["properties"]["id"] == mem.id
103139

104140

105141
def test_edge_exists():
@@ -112,8 +148,13 @@ def test_edge_exists():
112148
memory="Edge topic",
113149
metadata=TreeNodeTextualMemoryMetadata(
114150
memory_type="LongTermMemory",
115-
key="topic",
116-
updated_at=datetime.now().isoformat(),
151+
key="Research Topic",
152+
hierarchy_level="topic",
153+
type="fact",
154+
memory_time="2024-01-01",
155+
status="activated",
156+
visibility="public",
157+
updated_at=now,
117158
embedding=embed_memory_item("Edge topic"),
118159
),
119160
)
@@ -122,8 +163,13 @@ def test_edge_exists():
122163
memory="Edge concept",
123164
metadata=TreeNodeTextualMemoryMetadata(
124165
memory_type="LongTermMemory",
125-
key="concept",
126-
updated_at=datetime.now().isoformat(),
166+
key="Research Topic",
167+
hierarchy_level="topic",
168+
type="fact",
169+
memory_time="2024-01-01",
170+
status="activated",
171+
visibility="public",
172+
updated_at=now,
127173
embedding=embed_memory_item("Edge concept"),
128174
),
129175
)
@@ -144,18 +190,28 @@ def test_get_edges():
144190
source = TextualMemoryItem(
145191
memory="Source",
146192
metadata=TreeNodeTextualMemoryMetadata(
147-
memory_type="WorkingMemory",
148-
key="src",
149-
updated_at=datetime.now().isoformat(),
193+
memory_type="LongTermMemory",
194+
key="Research Topic",
195+
hierarchy_level="topic",
196+
type="fact",
197+
memory_time="2024-01-01",
198+
status="activated",
199+
visibility="public",
200+
updated_at=now,
150201
embedding=embed_memory_item("Source"),
151202
),
152203
)
153204
target = TextualMemoryItem(
154205
memory="Target",
155206
metadata=TreeNodeTextualMemoryMetadata(
156-
memory_type="WorkingMemory",
157-
key="tgt",
158-
updated_at=datetime.now().isoformat(),
207+
memory_type="LongTermMemory",
208+
key="Research Topic",
209+
hierarchy_level="topic",
210+
type="fact",
211+
memory_time="2024-01-01",
212+
status="activated",
213+
visibility="public",
214+
updated_at=now,
159215
embedding=embed_memory_item("Target"),
160216
),
161217
)

0 commit comments

Comments
 (0)