11package org .springframework .ai .chat .memory .neo4j ;
22
3- import org .neo4j .driver .Result ;
43import org .neo4j .driver .Session ;
54import org .neo4j .driver .Transaction ;
5+ import org .neo4j .driver .TransactionContext ;
66import org .springframework .ai .chat .memory .ChatMemoryRepository ;
77import org .springframework .ai .chat .messages .*;
88import org .springframework .ai .content .Media ;
1111
1212import java .net .URI ;
1313import java .util .*;
14+ import java .util .stream .Collectors ;
1415
1516/**
1617 * An implementation of {@link ChatMemoryRepository} for Neo4J
1718 *
1819 * @author Enrico Rampazzo
20+ * @author Michael J. Simons
1921 * @since 1.0.0
2022 */
2123
22- public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
24+ public final class Neo4jChatMemoryRepository implements ChatMemoryRepository {
2325
2426 private final Neo4jChatMemoryConfig config ;
2527
@@ -29,63 +31,65 @@ public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) {
2931
3032 @ Override
3133 public List <String > findConversationIds () {
32- try (var session = config .getDriver ().session ()) {
33- return session .run ("MATCH (conversation:%s) RETURN conversation.id" .formatted (config .getSessionLabel ()))
34- .stream ()
35- .map (r -> r .get ("conversation.id" ).asString ())
36- .toList ();
37- }
34+ return config .getDriver ()
35+ .executableQuery ("MATCH (conversation:$($sessionLabel)) RETURN conversation.id" )
36+ .withParameters (Map .of ("sessionLabel" , config .getSessionLabel ()))
37+ .execute (Collectors .mapping (r -> r .get ("conversation.id" ).asString (), Collectors .toList ()));
3838 }
3939
4040 @ Override
4141 public List <Message > findByConversationId (String conversationId ) {
42- String statementBuilder = """
43- MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s )
42+ String statement = """
43+ MATCH (s:$($sessionLabel) {id:$conversationId})-[r:HAS_MESSAGE]->(m:$($messageLabel) )
4444 WITH m
45- OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s )
46- OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s ) WITH m, metadata, media ORDER BY media.idx ASC
47- OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s ) WITH m, metadata, media, tr ORDER BY tr.idx ASC
48- OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s )
45+ OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:$($metadataLabel) )
46+ OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:$($mediaLabel) ) WITH m, metadata, media ORDER BY media.idx ASC
47+ OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:$($toolResponseLabel) ) WITH m, metadata, media, tr ORDER BY tr.idx ASC
48+ OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:$($toolCallLabel) )
4949 WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC
5050 RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias
5151 ORDER BY m.idx ASC
52- """ .formatted (this .config .getSessionLabel (), this .config .getMessageLabel (),
53- this .config .getMetadataLabel (), this .config .getMediaLabel (), this .config .getToolResponseLabel (),
54- this .config .getToolCallLabel ());
55- Result res = this .config .getDriver ().session ().run (statementBuilder , Map .of ("conversationId" , conversationId ));
56- return res .stream ().map (record -> {
57- Map <String , Object > messageMap = record .get ("m" ).asMap ();
58- String msgType = messageMap .get (MessageAttributes .MESSAGE_TYPE .getValue ()).toString ();
59- Message message = null ;
60- List <Media > mediaList = List .of ();
61- if (!record .get ("medias" ).isNull ()) {
62- mediaList = getMedia (record );
63- }
64- if (msgType .equals (MessageType .USER .getValue ())) {
65- message = buildUserMessage (record , messageMap , mediaList );
66- }
67- if (msgType .equals (MessageType .ASSISTANT .getValue ())) {
68- message = buildAssistantMessage (record , messageMap , mediaList );
69- }
70- if (msgType .equals (MessageType .SYSTEM .getValue ())) {
71- SystemMessage .Builder systemMessageBuilder = SystemMessage .builder ()
72- .text (messageMap .get (MessageAttributes .TEXT_CONTENT .getValue ()).toString ());
73- if (!record .get ("metadata" ).isNull ()) {
74- Map <String , Object > retrievedMetadata = record .get ("metadata" ).asMap ();
75- systemMessageBuilder .metadata (retrievedMetadata );
52+ """ ;
53+
54+ return this .config .getDriver ()
55+ .executableQuery (statement )
56+ .withParameters (Map .of ("conversationId" , conversationId , "sessionLabel" , this .config .getSessionLabel (),
57+ "messageLabel" , this .config .getMessageLabel (), "metadataLabel" , this .config .getMetadataLabel (),
58+ "mediaLabel" , this .config .getMediaLabel (), "toolResponseLabel" , this .config .getToolResponseLabel (),
59+ "toolCallLabel" , this .config .getToolCallLabel ()))
60+ .execute (Collectors .mapping (record -> {
61+ Map <String , Object > messageMap = record .get ("m" ).asMap ();
62+ String msgType = messageMap .get (MessageAttributes .MESSAGE_TYPE .getValue ()).toString ();
63+ Message message = null ;
64+ List <Media > mediaList = List .of ();
65+ if (!record .get ("medias" ).isNull ()) {
66+ mediaList = getMedia (record );
7667 }
77- message = systemMessageBuilder .build ();
78- }
79- if (msgType .equals (MessageType .TOOL .getValue ())) {
80- message = buildToolMessage (record );
81- }
82- if (message == null ) {
83- throw new IllegalArgumentException ("%s messages are not supported"
84- .formatted (record .get (MessageAttributes .MESSAGE_TYPE .getValue ()).asString ()));
85- }
86- message .getMetadata ().put ("messageType" , message .getMessageType ());
87- return message ;
88- }).toList ();
68+ if (msgType .equals (MessageType .USER .getValue ())) {
69+ message = buildUserMessage (record , messageMap , mediaList );
70+ }
71+ if (msgType .equals (MessageType .ASSISTANT .getValue ())) {
72+ message = buildAssistantMessage (record , messageMap , mediaList );
73+ }
74+ if (msgType .equals (MessageType .SYSTEM .getValue ())) {
75+ SystemMessage .Builder systemMessageBuilder = SystemMessage .builder ()
76+ .text (messageMap .get (MessageAttributes .TEXT_CONTENT .getValue ()).toString ());
77+ if (!record .get ("metadata" ).isNull ()) {
78+ Map <String , Object > retrievedMetadata = record .get ("metadata" ).asMap ();
79+ systemMessageBuilder .metadata (retrievedMetadata );
80+ }
81+ message = systemMessageBuilder .build ();
82+ }
83+ if (msgType .equals (MessageType .TOOL .getValue ())) {
84+ message = buildToolMessage (record );
85+ }
86+ if (message == null ) {
87+ throw new IllegalArgumentException ("%s messages are not supported"
88+ .formatted (record .get (MessageAttributes .MESSAGE_TYPE .getValue ()).asString ()));
89+ }
90+ message .getMetadata ().put ("messageType" , message .getMessageType ());
91+ return message ;
92+ }, Collectors .toList ()));
8993
9094 }
9195
@@ -96,12 +100,11 @@ public void saveAll(String conversationId, List<Message> messages) {
96100
97101 // Then add the new messages
98102 try (Session s = this .config .getDriver ().session ()) {
99- try ( Transaction t = s . beginTransaction ()) {
103+ s . executeWriteWithoutResult ( tx -> {
100104 for (Message m : messages ) {
101- addMessageToTransaction (t , conversationId , m );
105+ addMessageToTransaction (tx , conversationId , m );
102106 }
103- t .commit ();
104- }
107+ });
105108 }
106109 }
107110
@@ -196,42 +199,46 @@ else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) {
196199 return mediaList ;
197200 }
198201
199- private void addMessageToTransaction (Transaction t , String conversationId , Message message ) {
202+ private void addMessageToTransaction (TransactionContext t , String conversationId , Message message ) {
200203 Map <String , Object > queryParameters = new HashMap <>();
201204 queryParameters .put ("conversationId" , conversationId );
202205 StringBuilder statementBuilder = new StringBuilder ();
203206 statementBuilder .append ("""
204- MERGE (s:%s {id:$conversationId}) WITH s
205- OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s
206- CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties
207+ MERGE (s:$($sessionLabel) {id:$conversationId}) WITH s
208+ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:$($messageLabel))
209+ WITH coalesce(count(countMsg), 0) as totalMsg, s
210+ CREATE (s)-[:HAS_MESSAGE]->(msg:$($messageLabel)) SET msg = $messageProperties
207211 SET msg.idx = totalMsg + 1
208- """ .formatted (this .config .getSessionLabel (), this .config .getMessageLabel (),
209- this .config .getMessageLabel ()));
212+ """ );
210213 Map <String , Object > attributes = new HashMap <>();
211214
212215 attributes .put (MessageAttributes .MESSAGE_TYPE .getValue (), message .getMessageType ().getValue ());
213216 attributes .put (MessageAttributes .TEXT_CONTENT .getValue (), message .getText ());
214217 attributes .put ("id" , UUID .randomUUID ().toString ());
215218 queryParameters .put ("messageProperties" , attributes );
219+ queryParameters .put ("sessionLabel" , this .config .getSessionLabel ());
220+ queryParameters .put ("messageLabel" , this .config .getMessageLabel ());
216221
217222 if (!Optional .ofNullable (message .getMetadata ()).orElse (Map .of ()).isEmpty ()) {
218223 statementBuilder .append ("""
219224 WITH msg
220- CREATE (metadataNode:%s )
225+ CREATE (metadataNode:$($metadataLabel) )
221226 CREATE (msg)-[:HAS_METADATA]->(metadataNode)
222227 SET metadataNode = $metadata
223- """ . formatted ( this . config . getMetadataLabel ()) );
228+ """ );
224229 Map <String , Object > metadataCopy = new HashMap <>(message .getMetadata ());
225230 metadataCopy .remove ("messageType" );
226231 queryParameters .put ("metadata" , metadataCopy );
232+ queryParameters .put ("metadataLabel" , this .config .getMetadataLabel ());
227233 }
228234 if (message instanceof AssistantMessage assistantMessage ) {
229235 if (assistantMessage .hasToolCalls ()) {
230236 statementBuilder .append ("""
231237 WITH msg
232- FOREACH(tc in $toolCalls | CREATE (toolCall:%s ) SET toolCall = tc
238+ FOREACH(tc in $toolCalls | CREATE (toolCall:$($toolLabel) ) SET toolCall = tc
233239 CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall))
234- """ .formatted (this .config .getToolCallLabel ()));
240+ """ );
241+ queryParameters .put ("toolLabel" , this .config .getToolCallLabel ());
235242 List <Map <String , Object >> toolCallMaps = new ArrayList <>();
236243 for (int i = 0 ; i < assistantMessage .getToolCalls ().size (); i ++) {
237244 AssistantMessage .ToolCall tc = assistantMessage .getToolCalls ().get (i );
@@ -256,21 +263,23 @@ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg),
256263 }
257264 statementBuilder .append ("""
258265 WITH msg
259- FOREACH(tr IN $toolResponses | CREATE (tm:%s )
266+ FOREACH(tr IN $toolResponses | CREATE (tm:$($toolResponseLabel) )
260267 SET tm = tr
261268 MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm))
262- """ . formatted ( this . config . getToolResponseLabel ()) );
269+ """ );
263270 queryParameters .put ("toolResponses" , toolResponseMaps );
271+ queryParameters .put ("toolResponseLabel" , this .config .getToolResponseLabel ());
264272 }
265273 if (message instanceof MediaContent messageWithMedia && !messageWithMedia .getMedia ().isEmpty ()) {
266274 List <Map <String , Object >> mediaNodes = convertMediaToMap (messageWithMedia .getMedia ());
267275 statementBuilder .append ("""
268276 WITH msg
269277 UNWIND $media AS m
270- CREATE (media:%s ) SET media = m
278+ CREATE (media:$($mediaLabel) ) SET media = m
271279 WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media)
272- """ . formatted ( this . config . getMediaLabel ()) );
280+ """ );
273281 queryParameters .put ("media" , mediaNodes );
282+ queryParameters .put ("mediaLabel" , this .config .getMediaLabel ());
274283 }
275284 t .run (statementBuilder .toString (), queryParameters );
276285 }
0 commit comments