Skip to content

Commit 57b2cd3

Browse files
committed
wip
Signed-off-by: Josh Long <[email protected]>
1 parent 228ef10 commit 57b2cd3

File tree

6 files changed

+209
-7
lines changed

6 files changed

+209
-7
lines changed

spring-ai-model/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@
6464
<optional>true</optional>
6565
</dependency>
6666

67+
<dependency>
68+
<groupId>org.springframework.security</groupId>
69+
<artifactId>spring-security-config</artifactId>
70+
<optional>true</optional>
71+
</dependency>
6772

6873
<dependency>
6974
<groupId>org.springframework</groupId>

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/InMemoryChatMemoryRepository.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@
2828
* An in-memory implementation of {@link ChatMemoryRepository}.
2929
*
3030
* @author Thomas Vitale
31+
* @author Josh Long
3132
* @since 1.0.0
3233
*/
33-
public final class InMemoryChatMemoryRepository implements ChatMemoryRepository {
34+
public class InMemoryChatMemoryRepository implements ChatMemoryRepository {
3435

35-
Map<String, List<Message>> chatMemoryStore = new ConcurrentHashMap<>();
36+
private final Map<String, List<Message>> chatMemoryStore = new ConcurrentHashMap<>();
3637

3738
@Override
3839
public List<String> findConversationIds() {

spring-ai-model/src/main/java/org/springframework/ai/chat/memory/MessageWindowChatMemory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@
3636
* {@link SystemMessage} messages are preserved while evicting other types of messages.
3737
*
3838
* @author Thomas Vitale
39+
* @author Josh Long
3940
* @since 1.0.0
4041
*/
41-
public final class MessageWindowChatMemory implements ChatMemory {
42+
public class MessageWindowChatMemory implements ChatMemory {
4243

4344
private static final int DEFAULT_MAX_MESSAGES = 200;
4445

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package org.springframework.ai.chat.memory.encryption;
2+
3+
import org.jetbrains.annotations.NotNull;
4+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
5+
import org.springframework.beans.BeansException;
6+
import org.springframework.beans.factory.config.BeanPostProcessor;
7+
import org.springframework.security.crypto.encrypt.TextEncryptor;
8+
9+
/**
10+
* Uses a configured {@link TextEncryptor text encryptor} to encrypt values before writes,
11+
* and decode those values from the read operations.
12+
*
13+
* @author Josh Long
14+
*/
15+
public class EncryptingChatMemoryBeanPostProcessor implements BeanPostProcessor {
16+
17+
private final TextEncryptor encryptor;
18+
19+
public EncryptingChatMemoryBeanPostProcessor(TextEncryptor encryptor) {
20+
this.encryptor = encryptor;
21+
}
22+
23+
@Override
24+
public Object postProcessAfterInitialization(@NotNull Object bean, @NotNull String beanName) throws BeansException {
25+
26+
if (bean instanceof ChatMemoryRepository cmr && !(cmr instanceof EncryptingChatMemoryRepository)) {
27+
return new EncryptingChatMemoryRepository(cmr, encryptor);
28+
}
29+
return BeanPostProcessor.super.postProcessAfterInitialization(bean, beanName);
30+
}
31+
32+
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package org.springframework.ai.chat.memory.encryption;
2+
3+
import org.jetbrains.annotations.NotNull;
4+
import org.springframework.ai.chat.memory.ChatMemoryRepository;
5+
import org.springframework.ai.chat.messages.AssistantMessage;
6+
import org.springframework.ai.chat.messages.Message;
7+
import org.springframework.ai.chat.messages.SystemMessage;
8+
import org.springframework.ai.chat.messages.UserMessage;
9+
import org.springframework.security.crypto.encrypt.TextEncryptor;
10+
11+
import java.util.List;
12+
import java.util.function.Function;
13+
import java.util.stream.Collectors;
14+
15+
/**
16+
*
17+
* Wraps {@link ChatMemoryRepository a ChatMemoryRepository}, encrypting and decrypting
18+
* reads and writes respectively using a Spring Security {@link TextEncryptor text
19+
* encryptor}.
20+
*
21+
* @author Josh Long
22+
*/
23+
public class EncryptingChatMemoryRepository implements ChatMemoryRepository {
24+
25+
private final ChatMemoryRepository target;
26+
27+
private final TextEncryptor textEncryptor;
28+
29+
public EncryptingChatMemoryRepository(ChatMemoryRepository target, TextEncryptor textEncryptor) {
30+
this.target = target;
31+
this.textEncryptor = textEncryptor;
32+
}
33+
34+
private Message transform(Message message, Function<String, String> function) {
35+
36+
var transformedText = function.apply(message.getText());
37+
38+
// todo is there a case to be made that we should seal the message hierarchy?
39+
if (message instanceof SystemMessage systemMessage) {
40+
return systemMessage.mutate().text(transformedText).build();
41+
}
42+
43+
if (message instanceof UserMessage userMessage) {
44+
return userMessage.mutate().text(transformedText).build();
45+
}
46+
47+
if (message instanceof AssistantMessage assistantMessage) {
48+
return assistantMessage.mutate().text(transformedText).build();
49+
}
50+
51+
return message;
52+
}
53+
54+
private Message decrypt(Message message) {
55+
return this.transform(message, this.textEncryptor::decrypt);
56+
}
57+
58+
private Message encrypt(Message message) {
59+
return this.transform(message, this.textEncryptor::encrypt);
60+
}
61+
62+
@NotNull
63+
@Override
64+
public List<String> findConversationIds() {
65+
return this.target.findConversationIds();
66+
}
67+
68+
@NotNull
69+
@Override
70+
public List<Message> findByConversationId(@NotNull String conversationId) {
71+
return this.target.findByConversationId(conversationId).stream().map(this::decrypt).toList();
72+
}
73+
74+
@Override
75+
public void saveAll(@NotNull String conversationId, List<Message> messages) {
76+
this.target.saveAll(conversationId, messages.stream().map(this::encrypt).collect(Collectors.toList()));
77+
}
78+
79+
@Override
80+
public void deleteByConversationId(@NotNull String conversationId) {
81+
this.target.deleteByConversationId(conversationId);
82+
}
83+
84+
}

spring-ai-model/src/main/java/org/springframework/ai/chat/messages/AssistantMessage.java

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616

1717
package org.springframework.ai.chat.messages;
1818

19-
import java.util.List;
20-
import java.util.Map;
21-
import java.util.Objects;
22-
2319
import org.springframework.ai.content.Media;
2420
import org.springframework.ai.content.MediaContent;
21+
import org.springframework.core.io.Resource;
22+
import org.springframework.lang.Nullable;
2523
import org.springframework.util.Assert;
2624
import org.springframework.util.CollectionUtils;
25+
import org.springframework.util.StringUtils;
26+
27+
import java.util.*;
2728

2829
/**
2930
* Lets the generative know the content was generated as a response to the user. This role
@@ -33,6 +34,7 @@
3334
*
3435
* @author Mark Pollack
3536
* @author Christian Tzolov
37+
* @author Josh Long
3638
* @since 1.0.0
3739
*/
3840
public class AssistantMessage extends AbstractMessage implements MediaContent {
@@ -104,4 +106,81 @@ public record ToolCall(String id, String type, String name, String arguments) {
104106

105107
}
106108

109+
//
110+
111+
public AssistantMessage copy() {
112+
return new AssistantMessage.Builder().text(getText())
113+
.media(List.copyOf(getMedia()))
114+
.metadata(Map.copyOf(getMetadata()))
115+
.build();
116+
}
117+
118+
public AssistantMessage.Builder mutate() {
119+
return new AssistantMessage.Builder().text(getText())
120+
.media(List.copyOf(getMedia()))
121+
.metadata(Map.copyOf(getMetadata()));
122+
}
123+
124+
public static AssistantMessage.Builder builder() {
125+
return new AssistantMessage.Builder();
126+
}
127+
128+
public static class Builder {
129+
130+
@Nullable
131+
private String textContent;
132+
133+
@Nullable
134+
private Resource resource;
135+
136+
private List<Media> media = new ArrayList<>();
137+
138+
private List<ToolCall> toolCalls = new ArrayList<>();
139+
140+
private Map<String, Object> metadata = new HashMap<>();
141+
142+
public AssistantMessage.Builder text(String textContent) {
143+
this.textContent = textContent;
144+
return this;
145+
}
146+
147+
public AssistantMessage.Builder toolCalls(List<ToolCall> toolCalls) {
148+
this.toolCalls = toolCalls;
149+
return this;
150+
}
151+
152+
public AssistantMessage.Builder text(Resource resource) {
153+
this.resource = resource;
154+
return this;
155+
}
156+
157+
public AssistantMessage.Builder media(List<Media> media) {
158+
this.media = media;
159+
return this;
160+
}
161+
162+
public AssistantMessage.Builder media(@Nullable Media... media) {
163+
if (media != null) {
164+
this.media = Arrays.asList(media);
165+
}
166+
return this;
167+
}
168+
169+
public AssistantMessage.Builder metadata(Map<String, Object> metadata) {
170+
this.metadata = metadata;
171+
return this;
172+
}
173+
174+
public AssistantMessage build() {
175+
if (StringUtils.hasText(textContent) && resource != null) {
176+
throw new IllegalArgumentException("textContent and resource cannot be set at the same time");
177+
}
178+
else if (resource != null) {
179+
this.textContent = MessageUtils.readResource(resource);
180+
}
181+
return new AssistantMessage(this.textContent, this.metadata, this.toolCalls, this.media);
182+
}
183+
184+
}
185+
107186
}

0 commit comments

Comments
 (0)