Skip to content

Commit 10089b1

Browse files
author
Enrico Rampazzo
committed
chatmemory implementation
Signed-off-by: Enrico Rampazzo <[email protected]>
1 parent 329e6c0 commit 10089b1

File tree

12 files changed

+878
-1
lines changed

12 files changed

+878
-1
lines changed

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ Refer to the xref:_retrieval_augmented_generation[Retrieval Augmented Generation
374374

375375
The interface `ChatMemory` represents a storage for chat conversation history. It provides methods to add messages to a conversation, retrieve messages from a conversation, and clear the conversation history.
376376

377-
There are currently two implementations, `InMemoryChatMemory` and `CassandraChatMemory`, that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly.
377+
There are currently three implementations, `InMemoryChatMemory`, `CassandraChatMemory` and `Neo4jChatMemory`, that provide storage for chat conversation history, in-memory, persisted with `time-to-live` in Cassandra, and persisted without `time-to-live` in Neo4j correspondingly.
378378

379379
To create a `CassandraChatMemory` with `time-to-live`:
380380

@@ -383,6 +383,24 @@ To create a `CassandraChatMemory` with `time-to-live`:
383383
CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build());
384384
----
385385

386+
The Neo4j chat memory supports the following configuration parameters:
387+
388+
[cols="2,5,1",stripes=even]
389+
|===
390+
|Property | Description | Default Value
391+
392+
| `spring.ai.chat.memory.neo4j.messageLabel` | The label for the nodes that store messages | `Message`
393+
| `spring.ai.chat.memory.neo4j.sessionLabel` | The label for the nodes that store conversation sessions | `Session`
394+
| `spring.ai.chat.memory.neo4j.toolCallLabel` | The label for nodes that store tool calls, for example
395+
in Assistant Messages | `ToolCall`
396+
| `spring.ai.chat.memory.neo4j.metadataLabel` | The label for the node that store a message metadata | `Metadata`
397+
| `spring.ai.chat.memory.neo4j.toolResponseLabel` | The label for the nodes that store tool responses | `ToolResponse`
398+
| `spring.ai.chat.memory.neo4j.mediaLabel` | The label for the nodes that store the media associated to a message | `ToolResponse`
399+
400+
401+
|===
402+
403+
386404
The following advisor implementations use the `ChatMemory` interface to advice the prompt with conversation history which differ in the details of how the memory is added to the prompt
387405

388406
* `MessageChatMemoryAdvisor` : Memory is retrieved and added as a collection of messages to the prompt
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.autoconfigure.chat.memory.neo4j;
18+
19+
import org.neo4j.driver.Driver;
20+
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
21+
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
22+
import org.springframework.boot.autoconfigure.AutoConfiguration;
23+
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
24+
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
25+
import org.springframework.boot.autoconfigure.neo4j.Neo4jAutoConfiguration;
26+
import org.springframework.boot.context.properties.EnableConfigurationProperties;
27+
import org.springframework.context.annotation.Bean;
28+
29+
/**
30+
* {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemory}.
31+
*
32+
* @author Enrico Rampazzo
33+
* @since 1.0.0
34+
*/
35+
@AutoConfiguration(after = Neo4jAutoConfiguration.class)
36+
@ConditionalOnClass({ Neo4jChatMemory.class, Driver.class })
37+
@EnableConfigurationProperties(Neo4jChatMemoryProperties.class)
38+
public class Neo4jChatMemoryAutoConfiguration {
39+
40+
@Bean
41+
@ConditionalOnMissingBean
42+
public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver driver) {
43+
44+
var builder = Neo4jChatMemoryConfig.builder().withMediaLabel(properties.getMediaLabel())
45+
.withMessageLabel(properties.getMessageLabel()).withMetadataLabel(properties.getMetadataLabel())
46+
.withSessionLabel(properties.getSessionLabel()).withToolCallLabel(properties.getToolCallLabel())
47+
.withToolResponseLabel(properties.getToolResponseLabel())
48+
.withDriver(driver);
49+
50+
return Neo4jChatMemory.create(builder.build());
51+
}
52+
53+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.autoconfigure.chat.memory.neo4j;
18+
19+
import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties;
20+
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
21+
import org.springframework.boot.context.properties.ConfigurationProperties;
22+
23+
/**
24+
* Configuration properties for Neo4j chat memory.
25+
*
26+
* @author Enrico Rampazzo
27+
*/
28+
@ConfigurationProperties(Neo4jChatMemoryProperties.CONFIG_PREFIX)
29+
public class Neo4jChatMemoryProperties {
30+
31+
public static final String CONFIG_PREFIX = "spring.ai.chat.memory.neo4j";
32+
private String sessionLabel = Neo4jChatMemoryConfig.DEFAULT_SESSION_LABEL;
33+
private String toolCallLabel = Neo4jChatMemoryConfig.DEFAULT_TOOL_CALL_LABEL;
34+
private String metadataLabel = Neo4jChatMemoryConfig.DEFAULT_METADATA_LABEL;
35+
private String messageLabel = Neo4jChatMemoryConfig.DEFAULT_MESSAGE_LABEL;
36+
private String toolResponseLabel = Neo4jChatMemoryConfig.DEFAULT_TOOL_RESPONSE_LABEL;
37+
private String mediaLabel = Neo4jChatMemoryConfig.DEFAULT_MEDIA_LABEL;
38+
39+
public String getSessionLabel() {
40+
return sessionLabel;
41+
}
42+
43+
public void setSessionLabel(String sessionLabel) {
44+
this.sessionLabel = sessionLabel;
45+
}
46+
47+
public String getToolCallLabel() {
48+
return toolCallLabel;
49+
}
50+
51+
public String getMetadataLabel() {
52+
return metadataLabel;
53+
}
54+
55+
public String getMessageLabel() {
56+
return messageLabel;
57+
}
58+
59+
public String getToolResponseLabel() {
60+
return toolResponseLabel;
61+
}
62+
63+
public String getMediaLabel() {
64+
return mediaLabel;
65+
}
66+
67+
public void setToolCallLabel(String toolCallLabel) {
68+
this.toolCallLabel = toolCallLabel;
69+
}
70+
71+
public void setMetadataLabel(String metadataLabel) {
72+
this.metadataLabel = metadataLabel;
73+
}
74+
75+
public void setMessageLabel(String messageLabel) {
76+
this.messageLabel = messageLabel;
77+
}
78+
79+
public void setToolResponseLabel(String toolResponseLabel) {
80+
this.toolResponseLabel = toolResponseLabel;
81+
}
82+
83+
public void setMediaLabel(String mediaLabel) {
84+
this.mediaLabel = mediaLabel;
85+
}
86+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.autoconfigure.chat.memory.neo4j;
18+
19+
import com.datastax.driver.core.utils.UUIDs;
20+
import org.junit.jupiter.api.Test;
21+
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
22+
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
23+
import org.springframework.ai.chat.messages.*;
24+
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
25+
import org.springframework.ai.model.Media;
26+
import org.springframework.boot.autoconfigure.AutoConfigurations;
27+
import org.springframework.boot.autoconfigure.neo4j.Neo4jAutoConfiguration;
28+
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
29+
import org.springframework.util.MimeType;
30+
import org.testcontainers.containers.Neo4jContainer;
31+
import org.testcontainers.junit.jupiter.Container;
32+
import org.testcontainers.junit.jupiter.Testcontainers;
33+
import org.testcontainers.utility.DockerImageName;
34+
35+
import java.net.URI;
36+
import java.nio.charset.StandardCharsets;
37+
import java.util.List;
38+
import java.util.Map;
39+
40+
import static org.assertj.core.api.Assertions.assertThat;
41+
42+
/**
43+
* @author Mick Semb Wever
44+
* @author Jihoon Kim
45+
* @since 1.0.0
46+
*/
47+
@Testcontainers
48+
class Neo4jChatMemoryAutoConfigurationIT {
49+
50+
static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j");
51+
52+
@SuppressWarnings({"rawtypes", "resource"})
53+
@Container
54+
static Neo4jContainer neo4jContainer = (Neo4jContainer) new Neo4jContainer(DEFAULT_IMAGE_NAME.withTag("5")).withoutAuthentication().withExposedPorts(7474,7687);
55+
56+
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
57+
.withConfiguration(
58+
AutoConfigurations.of(Neo4jChatMemoryAutoConfiguration.class, Neo4jAutoConfiguration.class));
59+
60+
61+
@Test
62+
void addAndGet() {
63+
this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl())
64+
.run(context -> {
65+
Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class);
66+
67+
String sessionId = UUIDs.timeBased().toString();
68+
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
69+
70+
UserMessage userMessage = new UserMessage("test question");
71+
72+
73+
memory.add(sessionId, userMessage);
74+
List<Message> messages = memory.get(sessionId, Integer.MAX_VALUE);
75+
assertThat(messages).hasSize(1);
76+
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage);
77+
78+
memory.clear(sessionId);
79+
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();
80+
81+
AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(),
82+
List.of(new AssistantMessage.ToolCall(
83+
"id", "type", "name", "arguments")));
84+
85+
memory.add(sessionId, List.of(userMessage, assistantMessage));
86+
messages = memory.get(sessionId, Integer.MAX_VALUE);
87+
assertThat(messages).hasSize(2);
88+
assertThat(messages.get(1)).isEqualTo(userMessage);
89+
90+
assertThat(messages.get(0)).isEqualTo(assistantMessage);
91+
memory.clear(sessionId);
92+
MimeType textPlain = MimeType.valueOf("text/plain");
93+
List<Media> media = List.of(Media.builder().name("some media").id(UUIDs.random().toString())
94+
.mimeType(textPlain).data("hello".getBytes(StandardCharsets.UTF_8)).build(),
95+
Media.builder().data(URI.create("http://www.google.com").toURL()).mimeType(textPlain).build());
96+
UserMessage userMessageWithMedia = new UserMessage("Message with media", media);
97+
memory.add(sessionId, userMessageWithMedia);
98+
99+
messages = memory.get(sessionId, Integer.MAX_VALUE);
100+
assertThat(messages.size()).isEqualTo(1);
101+
assertThat(messages.get(0)).isEqualTo(userMessageWithMedia);
102+
assertThat(((UserMessage)messages.get(0)).getMedia()).hasSize(2);
103+
assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator().isEqualTo(media);
104+
memory.clear(sessionId);
105+
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List.of(
106+
new ToolResponse("id", "name", "responseData"),
107+
new ToolResponse("id2", "name2", "responseData2")),
108+
Map.of("id", "id", "metadataKey", "metadata"));
109+
memory.add(sessionId, toolResponseMessage);
110+
messages = memory.get(sessionId, Integer.MAX_VALUE);
111+
assertThat(messages.size()).isEqualTo(1);
112+
assertThat(messages.get(0)).isEqualTo(toolResponseMessage);
113+
114+
memory.clear(sessionId);
115+
SystemMessage sm = new SystemMessage("this is a System message");
116+
memory.add(sessionId, sm);
117+
messages = memory.get(sessionId, Integer.MAX_VALUE);
118+
assertThat(messages).hasSize(1);
119+
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm);
120+
});
121+
}
122+
@Test
123+
void setCustomConfiguration(){
124+
final String sessionLabel = "LabelSession";
125+
final String toolCallLabel = "LabelToolCall";
126+
final String metadataLabel = "LabelMetadata";
127+
final String messageLabel = "LabelMessage";
128+
final String toolResponseLabel = "LabelToolResponse";
129+
final String mediaLabel = "LabelMedia";
130+
131+
final String propertyBase = "spring.ai.chat.memory.neo4j.%s=%s";
132+
this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl(),
133+
propertyBase.formatted("sessionlabel", sessionLabel),
134+
propertyBase.formatted("toolcallLabel", toolCallLabel),
135+
propertyBase.formatted("metadatalabel", metadataLabel),
136+
propertyBase.formatted("messagelabel", messageLabel),
137+
propertyBase.formatted("toolresponselabel", toolResponseLabel),
138+
propertyBase.formatted("medialabel", mediaLabel))
139+
.run(context -> {
140+
Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class);
141+
Neo4jChatMemoryConfig config = chatMemory.getConfig();
142+
assertThat(config.getMessageLabel()).isEqualTo(messageLabel);
143+
assertThat(config.getMediaLabel()).isEqualTo(mediaLabel);
144+
assertThat(config.getMetadataLabel()).isEqualTo(metadataLabel);
145+
assertThat(config.getSessionLabel()).isEqualTo(sessionLabel);
146+
assertThat(config.getToolResponseLabel()).isEqualTo(toolResponseLabel);
147+
assertThat(config.getToolCallLabel()).isEqualTo(toolCallLabel);
148+
});
149+
}
150+
151+
152+
153+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.autoconfigure.chat.memory.neo4j;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.springframework.ai.autoconfigure.chat.memory.cassandra.CassandraChatMemoryProperties;
21+
import org.springframework.ai.chat.memory.cassandra.CassandraChatMemoryConfig;
22+
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
23+
24+
import java.time.Duration;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/**
29+
* @author Enrico Rampazzo
30+
* @since 1.0.0
31+
*/
32+
class Neo4jChatMemoryPropertiesTest {
33+
34+
@Test
35+
void defaultValues() {
36+
var props = new Neo4jChatMemoryProperties();
37+
assertThat(props.getMediaLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_MEDIA_LABEL);
38+
assertThat(props.getMessageLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_MESSAGE_LABEL);
39+
assertThat(props.getMetadataLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_METADATA_LABEL);
40+
assertThat(props.getSessionLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_SESSION_LABEL);
41+
assertThat(props.getToolCallLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_TOOL_CALL_LABEL);
42+
assertThat(props.getToolResponseLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_TOOL_RESPONSE_LABEL);
43+
}
44+
45+
@Test
46+
void customValues() {
47+
var props = new CassandraChatMemoryProperties();
48+
props.setKeyspace("my_keyspace");
49+
props.setTable("my_table");
50+
props.setAssistantColumn("my_assistant_column");
51+
props.setUserColumn("my_user_column");
52+
props.setTimeToLive(Duration.ofDays(1));
53+
props.setInitializeSchema(false);
54+
55+
assertThat(props.getKeyspace()).isEqualTo("my_keyspace");
56+
assertThat(props.getTable()).isEqualTo("my_table");
57+
assertThat(props.getAssistantColumn()).isEqualTo("my_assistant_column");
58+
assertThat(props.getUserColumn()).isEqualTo("my_user_column");
59+
assertThat(props.getTimeToLive()).isEqualTo(Duration.ofDays(1));
60+
assertThat(props.isInitializeSchema()).isFalse();
61+
}
62+
63+
}

0 commit comments

Comments
 (0)