From 818b8e4334ef4adcd09c514ca154699217baa390 Mon Sep 17 00:00:00 2001 From: Michael Simons Date: Wed, 7 May 2025 17:01:16 +0200 Subject: [PATCH] Polish `Neo4jChatMemoryRepository`. This change turns all the labels into parameters, avoiding the possibility of Cypher injection as the config does not do any sanitization. In addition, the interaction with the driver is changed so that it uses transactional functions, which are retried when any communication with the Neo4j DBMS fails. We can do this here as the repository is not subject to application wide transactions. An alternative to the parameters for labels would be using Cypher-DSL as we did in other parts of this project to sanitize labels proper. Signed-off-by: Michael Simons --- .../neo4j/Neo4jChatMemoryRepository.java | 145 ++++++++++-------- 1 file changed, 77 insertions(+), 68 deletions(-) diff --git a/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java b/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java index 7aa3ee29942..c914d61cbeb 100644 --- a/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java +++ b/memory/spring-ai-model-chat-memory-neo4j/src/main/java/org/springframework/ai/chat/memory/neo4j/Neo4jChatMemoryRepository.java @@ -1,8 +1,8 @@ package org.springframework.ai.chat.memory.neo4j; -import org.neo4j.driver.Result; import org.neo4j.driver.Session; import org.neo4j.driver.Transaction; +import org.neo4j.driver.TransactionContext; import org.springframework.ai.chat.memory.ChatMemoryRepository; import org.springframework.ai.chat.messages.*; import org.springframework.ai.content.Media; @@ -11,15 +11,17 @@ import java.net.URI; import java.util.*; +import java.util.stream.Collectors; /** * An implementation of {@link ChatMemoryRepository} for Neo4J * * @author Enrico Rampazzo + * @author Michael J. Simons * @since 1.0.0 */ -public class Neo4jChatMemoryRepository implements ChatMemoryRepository { +public final class Neo4jChatMemoryRepository implements ChatMemoryRepository { private final Neo4jChatMemoryConfig config; @@ -29,63 +31,65 @@ public Neo4jChatMemoryRepository(Neo4jChatMemoryConfig config) { @Override public List findConversationIds() { - try (var session = config.getDriver().session()) { - return session.run("MATCH (conversation:%s) RETURN conversation.id".formatted(config.getSessionLabel())) - .stream() - .map(r -> r.get("conversation.id").asString()) - .toList(); - } + return config.getDriver() + .executableQuery("MATCH (conversation:$($sessionLabel)) RETURN conversation.id") + .withParameters(Map.of("sessionLabel", config.getSessionLabel())) + .execute(Collectors.mapping(r -> r.get("conversation.id").asString(), Collectors.toList())); } @Override public List findByConversationId(String conversationId) { - String statementBuilder = """ - MATCH (s:%s {id:$conversationId})-[r:HAS_MESSAGE]->(m:%s) + String statement = """ + MATCH (s:$($sessionLabel) {id:$conversationId})-[r:HAS_MESSAGE]->(m:$($messageLabel)) WITH m - OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s) - OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s) WITH m, metadata, media ORDER BY media.idx ASC - OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s) WITH m, metadata, media, tr ORDER BY tr.idx ASC - OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s) + OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:$($metadataLabel)) + OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:$($mediaLabel)) WITH m, metadata, media ORDER BY media.idx ASC + OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:$($toolResponseLabel)) WITH m, metadata, media, tr ORDER BY tr.idx ASC + OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:$($toolCallLabel)) WITH m, metadata, media, tr, tc ORDER BY tc.idx ASC RETURN m, metadata, collect(tr) as toolResponses, collect(tc) as toolCalls, collect(media) as medias ORDER BY m.idx ASC - """.formatted(this.config.getSessionLabel(), this.config.getMessageLabel(), - this.config.getMetadataLabel(), this.config.getMediaLabel(), this.config.getToolResponseLabel(), - this.config.getToolCallLabel()); - Result res = this.config.getDriver().session().run(statementBuilder, Map.of("conversationId", conversationId)); - return res.stream().map(record -> { - Map messageMap = record.get("m").asMap(); - String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString(); - Message message = null; - List mediaList = List.of(); - if (!record.get("medias").isNull()) { - mediaList = getMedia(record); - } - if (msgType.equals(MessageType.USER.getValue())) { - message = buildUserMessage(record, messageMap, mediaList); - } - if (msgType.equals(MessageType.ASSISTANT.getValue())) { - message = buildAssistantMessage(record, messageMap, mediaList); - } - if (msgType.equals(MessageType.SYSTEM.getValue())) { - SystemMessage.Builder systemMessageBuilder = SystemMessage.builder() - .text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()); - if (!record.get("metadata").isNull()) { - Map retrievedMetadata = record.get("metadata").asMap(); - systemMessageBuilder.metadata(retrievedMetadata); + """; + + return this.config.getDriver() + .executableQuery(statement) + .withParameters(Map.of("conversationId", conversationId, "sessionLabel", this.config.getSessionLabel(), + "messageLabel", this.config.getMessageLabel(), "metadataLabel", this.config.getMetadataLabel(), + "mediaLabel", this.config.getMediaLabel(), "toolResponseLabel", this.config.getToolResponseLabel(), + "toolCallLabel", this.config.getToolCallLabel())) + .execute(Collectors.mapping(record -> { + Map messageMap = record.get("m").asMap(); + String msgType = messageMap.get(MessageAttributes.MESSAGE_TYPE.getValue()).toString(); + Message message = null; + List mediaList = List.of(); + if (!record.get("medias").isNull()) { + mediaList = getMedia(record); } - message = systemMessageBuilder.build(); - } - if (msgType.equals(MessageType.TOOL.getValue())) { - message = buildToolMessage(record); - } - if (message == null) { - throw new IllegalArgumentException("%s messages are not supported" - .formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString())); - } - message.getMetadata().put("messageType", message.getMessageType()); - return message; - }).toList(); + if (msgType.equals(MessageType.USER.getValue())) { + message = buildUserMessage(record, messageMap, mediaList); + } + if (msgType.equals(MessageType.ASSISTANT.getValue())) { + message = buildAssistantMessage(record, messageMap, mediaList); + } + if (msgType.equals(MessageType.SYSTEM.getValue())) { + SystemMessage.Builder systemMessageBuilder = SystemMessage.builder() + .text(messageMap.get(MessageAttributes.TEXT_CONTENT.getValue()).toString()); + if (!record.get("metadata").isNull()) { + Map retrievedMetadata = record.get("metadata").asMap(); + systemMessageBuilder.metadata(retrievedMetadata); + } + message = systemMessageBuilder.build(); + } + if (msgType.equals(MessageType.TOOL.getValue())) { + message = buildToolMessage(record); + } + if (message == null) { + throw new IllegalArgumentException("%s messages are not supported" + .formatted(record.get(MessageAttributes.MESSAGE_TYPE.getValue()).asString())); + } + message.getMetadata().put("messageType", message.getMessageType()); + return message; + }, Collectors.toList())); } @@ -96,12 +100,11 @@ public void saveAll(String conversationId, List messages) { // Then add the new messages try (Session s = this.config.getDriver().session()) { - try (Transaction t = s.beginTransaction()) { + s.executeWriteWithoutResult(tx -> { for (Message m : messages) { - addMessageToTransaction(t, conversationId, m); + addMessageToTransaction(tx, conversationId, m); } - t.commit(); - } + }); } } @@ -196,42 +199,46 @@ else if (mediaMap.get(MediaAttributes.DATA.getValue()).getClass().isArray()) { return mediaList; } - private void addMessageToTransaction(Transaction t, String conversationId, Message message) { + private void addMessageToTransaction(TransactionContext t, String conversationId, Message message) { Map queryParameters = new HashMap<>(); queryParameters.put("conversationId", conversationId); StringBuilder statementBuilder = new StringBuilder(); statementBuilder.append(""" - MERGE (s:%s {id:$conversationId}) WITH s - OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), 0) as totalMsg, s - CREATE (s)-[:HAS_MESSAGE]->(msg:%s) SET msg = $messageProperties + MERGE (s:$($sessionLabel) {id:$conversationId}) WITH s + OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:$($messageLabel)) + WITH coalesce(count(countMsg), 0) as totalMsg, s + CREATE (s)-[:HAS_MESSAGE]->(msg:$($messageLabel)) SET msg = $messageProperties SET msg.idx = totalMsg + 1 - """.formatted(this.config.getSessionLabel(), this.config.getMessageLabel(), - this.config.getMessageLabel())); + """); Map attributes = new HashMap<>(); attributes.put(MessageAttributes.MESSAGE_TYPE.getValue(), message.getMessageType().getValue()); attributes.put(MessageAttributes.TEXT_CONTENT.getValue(), message.getText()); attributes.put("id", UUID.randomUUID().toString()); queryParameters.put("messageProperties", attributes); + queryParameters.put("sessionLabel", this.config.getSessionLabel()); + queryParameters.put("messageLabel", this.config.getMessageLabel()); if (!Optional.ofNullable(message.getMetadata()).orElse(Map.of()).isEmpty()) { statementBuilder.append(""" WITH msg - CREATE (metadataNode:%s) + CREATE (metadataNode:$($metadataLabel)) CREATE (msg)-[:HAS_METADATA]->(metadataNode) SET metadataNode = $metadata - """.formatted(this.config.getMetadataLabel())); + """); Map metadataCopy = new HashMap<>(message.getMetadata()); metadataCopy.remove("messageType"); queryParameters.put("metadata", metadataCopy); + queryParameters.put("metadataLabel", this.config.getMetadataLabel()); } if (message instanceof AssistantMessage assistantMessage) { if (assistantMessage.hasToolCalls()) { statementBuilder.append(""" WITH msg - FOREACH(tc in $toolCalls | CREATE (toolCall:%s) SET toolCall = tc + FOREACH(tc in $toolCalls | CREATE (toolCall:$($toolLabel)) SET toolCall = tc CREATE (msg)-[:HAS_TOOL_CALL]->(toolCall)) - """.formatted(this.config.getToolCallLabel())); + """); + queryParameters.put("toolLabel", this.config.getToolCallLabel()); List> toolCallMaps = new ArrayList<>(); for (int i = 0; i < assistantMessage.getToolCalls().size(); i++) { AssistantMessage.ToolCall tc = assistantMessage.getToolCalls().get(i); @@ -256,21 +263,23 @@ OPTIONAL MATCH (s)-[:HAS_MESSAGE]->(countMsg:%s) WITH coalesce(count(countMsg), } statementBuilder.append(""" WITH msg - FOREACH(tr IN $toolResponses | CREATE (tm:%s) + FOREACH(tr IN $toolResponses | CREATE (tm:$($toolResponseLabel)) SET tm = tr MERGE (msg)-[:HAS_TOOL_RESPONSE]->(tm)) - """.formatted(this.config.getToolResponseLabel())); + """); queryParameters.put("toolResponses", toolResponseMaps); + queryParameters.put("toolResponseLabel", this.config.getToolResponseLabel()); } if (message instanceof MediaContent messageWithMedia && !messageWithMedia.getMedia().isEmpty()) { List> mediaNodes = convertMediaToMap(messageWithMedia.getMedia()); statementBuilder.append(""" WITH msg UNWIND $media AS m - CREATE (media:%s) SET media = m + CREATE (media:$($mediaLabel)) SET media = m WITH msg, media CREATE (msg)-[:HAS_MEDIA]->(media) - """.formatted(this.config.getMediaLabel())); + """); queryParameters.put("media", mediaNodes); + queryParameters.put("mediaLabel", this.config.getMediaLabel()); } t.run(statementBuilder.toString(), queryParameters); }