diff --git a/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java index 7aa3ee29942..c914d61cbeb 100644 --- a/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java +++ b/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java @@ -1,8 +1,8 @@ package org.springframework.ai.chat.memory.neo4j; -import org.neo4j.driver.Result; import org.neo4j.driver.Session; import org.neo4j.driver.Transaction; +import org.neo4j.driver.TransactionContext; import org.springframework.ai.chat.memory.ChatMemoryRepository; import org.springframework.ai.chat.messages.*; import org.springframework.ai.content.Media; @@ -11,15 +11,17 @@ import java.net.URI; import java.util.*; +import java.util.stream.Collectors; /** * An implementation of {@link ChatMemoryRepository} for Neo4J * * @author Enrico Rampazzo + * @author Michael J. Simons * @since 1.0.0 */ -public class Neo4jChatMemoryRepository implements ChatMemoryRepository { +public final class Neo4jChatMemoryRepository implements ChatMemoryRepository { private final Neo4jChatMemoryConfig config; @@ -29,63 +31,65 @@ public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) { @Override public List findConversationIds() { - try (var session = config.getDriver().session()) { - return session.run("MATCH (conversation:%s) RETURN conversation.id".formatted(config.getSessionLabel())) - .stream() - .map(r -> r.get("conversation.id").asString()) - .toList(); - } + return config.getDriver() + .executableQuery("MATCH (conversation:$($sessionLabel)) RETURN conversation.id") + .withParameters(Map.of("sessionLabel", config.getSessionLabel())) + .execute(Collectors.mapping(r -> r.get("conversation.id").asString(), Collectors.toList())); } @Override public List findByConversationId(String conversationId) { - String statementBuilder = """ - MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s) + String statement = """ + MATCH (s:$($sessionLabel) {id:$conversationId})-[r:HAS_MESSAGE]->(m:$($messageLabel)) WITH m - OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s) - OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) WITH m, metadata, media ORDER BY media.idx ASC - OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) WITH m, metadata, media, tr ORDER BY tr.idx ASC - OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s) + OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:$($metadataLabel)) + OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:$($mediaLabel)) WITH m, metadata, media ORDER BY media.idx ASC + OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:$($toolResponseLabel)) WITH m, metadata, media, tr ORDER BY tr.idx ASC + OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:$($toolCallLabel)) WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias ORDER BY m.idx ASC - """.formatted(this.config.getSessionLabel(), this.config.getMessageLabel(), - this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(), - this.config.getToolCallLabel()); - Result res = this.config.getDriver().session().run(statementBuilder, Map.of("conversationId", conversationId)); - return res.stream().map(record -> { - Map messageMap = record.get("m").asMap(); - String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString(); - Message message = null; - List mediaList = List.of(); - if (!record.get("medias").isNull()) { - mediaList = getMedia(record); - } - if (msgType.equals(MessageType.USER.getValue())) { - message = buildUserMessage(record, messageMap, mediaList); - } - if (msgType.equals(MessageType.ASSISTANT.getValue())) { - message = buildAssistantMessage(record, messageMap, mediaList); - } - if (msgType.equals(MessageType.SYSTEM.getValue())) { - SystemMessage.Builder systemMessageBuilder = SystemMessage.builder() - .text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()); - if (!record.get("metadata").isNull()) { - Map retrievedMetadata = record.get("metadata").asMap(); - systemMessageBuilder.metadata(retrievedMetadata); + """; + + return this.config.getDriver() + .executableQuery(statement) + .withParameters(Map.of("conversationId", conversationId, "sessionLabel", this.config.getSessionLabel(), + "messageLabel", this.config.getMessageLabel(), "metadataLabel", this.config.getMetadataLabel(), + "mediaLabel", this.config.getMediaLabel(), "toolResponseLabel", this.config.getToolResponseLabel(), + "toolCallLabel", this.config.getToolCallLabel())) + .execute(Collectors.mapping(record -> { + Map messageMap = record.get("m").asMap(); + String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString(); + Message message = null; + List mediaList = List.of(); + if (!record.get("medias").isNull()) { + mediaList = getMedia(record); } - message = systemMessageBuilder.build(); - } - if (msgType.equals(MessageType.TOOL.getValue())) { - message = buildToolMessage(record); - } - if (message == null) { - throw new IllegalArgumentException("%s messages are not supported" - .formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString())); - } - message.getMetadata().put("messageType", message.getMessageType()); - return message; - }).toList(); + if (msgType.equals(MessageType.USER.getValue())) { + message = buildUserMessage(record, messageMap, mediaList); + } + if (msgType.equals(MessageType.ASSISTANT.getValue())) { + message = buildAssistantMessage(record, messageMap, mediaList); + } + if (msgType.equals(MessageType.SYSTEM.getValue())) { + SystemMessage.Builder systemMessageBuilder = SystemMessage.builder() + .text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()); + if (!record.get("metadata").isNull()) { + Map retrievedMetadata = record.get("metadata").asMap(); + systemMessageBuilder.metadata(retrievedMetadata); + } + message = systemMessageBuilder.build(); + } + if (msgType.equals(MessageType.TOOL.getValue())) { + message = buildToolMessage(record); + } + if (message == null) { + throw new IllegalArgumentException("%s messages are not supported" + .formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString())); + } + message.getMetadata().put("messageType", message.getMessageType()); + return message; + }, Collectors.toList())); } @@ -96,12 +100,11 @@ public void saveAll(String conversationId, List messages) { // Then add the new messages try (Session s = this.config.getDriver().session()) { - try (Transaction t = s.beginTransaction()) { + s.executeWriteWithoutResult(tx -> { for (Message m : messages) { - addMessageToTransaction(t, conversationId, m); + addMessageToTransaction(tx, conversationId, m); } - t.commit(); - } + }); } } @@ -196,42 +199,46 @@ else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) { return mediaList; } - private void addMessageToTransaction(Transaction t, String conversationId, Message message) { + private void addMessageToTransaction(TransactionContext t, String conversationId, Message message) { Map queryParameters = new HashMap<>(); queryParameters.put("conversationId", conversationId); StringBuilder statementBuilder = new StringBuilder(); statementBuilder.append(""" - MERGE (s:%s {id:$conversationId}) WITH s - OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s - CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties + MERGE (s:$($sessionLabel) {id:$conversationId}) WITH s + OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:$($messageLabel)) + WITH coalesce(count(countMsg), 0) as totalMsg, s + CREATE (s)-[:HAS_MESSAGE]->(msg:$($messageLabel)) SET msg = $messageProperties SET msg.idx = totalMsg + 1 - """.formatted(this.config.getSessionLabel(), this.config.getMessageLabel(), - this.config.getMessageLabel())); + """); Map attributes = new HashMap<>(); attributes.put(MessageAttributes.MESSAGE_TYPE.getValue(), message.getMessageType().getValue()); attributes.put(MessageAttributes.TEXT_CONTENT.getValue(), message.getText()); attributes.put("id", UUID.randomUUID().toString()); queryParameters.put("messageProperties", attributes); + queryParameters.put("sessionLabel", this.config.getSessionLabel()); + queryParameters.put("messageLabel", this.config.getMessageLabel()); if (!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) { statementBuilder.append(""" WITH msg - CREATE (metadataNode:%s) + CREATE (metadataNode:$($metadataLabel)) CREATE (msg)-[:HAS_METADATA]->(metadataNode) SET metadataNode = $metadata - """.formatted(this.config.getMetadataLabel())); + """); Map metadataCopy = new HashMap<>(message.getMetadata()); metadataCopy.remove("messageType"); queryParameters.put("metadata", metadataCopy); + queryParameters.put("metadataLabel", this.config.getMetadataLabel()); } if (message instanceof AssistantMessage assistantMessage) { if (assistantMessage.hasToolCalls()) { statementBuilder.append(""" WITH msg - FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc + FOREACH(tc in $toolCalls | CREATE (toolCall:$($toolLabel)) SET toolCall = tc CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall)) - """.formatted(this.config.getToolCallLabel())); + """); + queryParameters.put("toolLabel", this.config.getToolCallLabel()); List> toolCallMaps = new ArrayList<>(); for (int i = 0; i < assistantMessage.getToolCalls().size(); i++) { AssistantMessage.ToolCall tc = assistantMessage.getToolCalls().get(i); @@ -256,21 +263,23 @@ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), } statementBuilder.append(""" WITH msg - FOREACH(tr IN $toolResponses | CREATE (tm:%s) + FOREACH(tr IN $toolResponses | CREATE (tm:$($toolResponseLabel)) SET tm = tr MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm)) - """.formatted(this.config.getToolResponseLabel())); + """); queryParameters.put("toolResponses", toolResponseMaps); + queryParameters.put("toolResponseLabel", this.config.getToolResponseLabel()); } if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) { List> mediaNodes = convertMediaToMap(messageWithMedia.getMedia()); statementBuilder.append(""" WITH msg UNWIND $media AS m - CREATE (media:%s) SET media = m + CREATE (media:$($mediaLabel)) SET media = m WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media) - """.formatted(this.config.getMediaLabel())); + """); queryParameters.put("media", mediaNodes); + queryParameters.put("mediaLabel", this.config.getMediaLabel()); } t.run(statementBuilder.toString(), queryParameters); }