Skip to content

Commit 901340c

Browse files
authored
Merge pull request #1884 from ejstuart/feature/redis-memory-fix
Use upstream LangChain4j mixins for codec
2 parents ef8d11b + db566da commit 901340c

File tree

2 files changed

+123
-2
lines changed

2 files changed

+123
-2
lines changed

core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
import java.util.regex.Pattern;
1111

1212
import com.fasterxml.jackson.annotation.JsonAutoDetect;
13+
import com.fasterxml.jackson.annotation.JsonCreator;
1314
import com.fasterxml.jackson.annotation.JsonInclude;
15+
import com.fasterxml.jackson.annotation.JsonProperty;
16+
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
17+
import com.fasterxml.jackson.annotation.JsonSubTypes;
18+
import com.fasterxml.jackson.annotation.JsonTypeInfo;
1419
import com.fasterxml.jackson.annotation.PropertyAccessor;
1520
import com.fasterxml.jackson.core.JsonParseException;
1621
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -21,9 +26,18 @@
2126
import com.fasterxml.jackson.databind.ObjectWriter;
2227
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
2328
import com.fasterxml.jackson.databind.SerializationFeature;
29+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
2430
import com.fasterxml.jackson.databind.module.SimpleDeserializers;
2531
import com.fasterxml.jackson.databind.module.SimpleModule;
2632

33+
import dev.langchain4j.agent.tool.ToolExecutionRequest;
34+
import dev.langchain4j.data.message.AiMessage;
35+
import dev.langchain4j.data.message.ChatMessage;
36+
import dev.langchain4j.data.message.ChatMessageType;
37+
import dev.langchain4j.data.message.CustomMessage;
38+
import dev.langchain4j.data.message.SystemMessage;
39+
import dev.langchain4j.data.message.ToolExecutionResultMessage;
40+
import dev.langchain4j.data.message.UserMessage;
2741
import dev.langchain4j.internal.Json;
2842
import dev.langchain4j.spi.json.JsonCodecFactory;
2943
import io.quarkiverse.langchain4j.runtime.jackson.CustomLocalDateDeserializer;
@@ -104,16 +118,89 @@ public static class ObjectMapperHolder {
104118
public static final ObjectWriter WRITER;
105119

106120
static {
121+
// Start with Arc container ObjectMapper to preserve Quarkus integration
107122
MAPPER = Arc.container().instance(ObjectMapper.class).get()
108123
.copy()
109124
.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE)
110125
.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY)
111-
.configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS.mappedFeature(), true)
112-
.registerModule(SnakeCaseObjectMapperHolder.QuarkusLangChain4jModule.INSTANCE);
126+
.configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS.mappedFeature(), true);
127+
128+
// Add chat message mixins to preserve thinking field deserialization
129+
MAPPER.addMixIn(ChatMessage.class, ChatMessageMixin.class);
130+
MAPPER.addMixIn(AiMessage.class, AiMessageMixin.class);
131+
MAPPER.addMixIn(UserMessage.class, UserMessageMixin.class);
132+
MAPPER.addMixIn(SystemMessage.class, SystemMessageMixin.class);
133+
MAPPER.addMixIn(ToolExecutionResultMessage.class, ToolExecutionResultMessageMixin.class);
134+
MAPPER.addMixIn(CustomMessage.class, CustomMessageMixin.class);
135+
MAPPER.addMixIn(ToolExecutionRequest.class, ToolExecutionRequestMixin.class);
136+
137+
// Register Quarkus-specific module
138+
MAPPER.registerModule(SnakeCaseObjectMapperHolder.QuarkusLangChain4jModule.INSTANCE);
139+
113140
WRITER = MAPPER.writerWithDefaultPrettyPrinter();
114141
}
115142
}
116143

144+
/**
145+
* Jackson mixins for chat message deserialization.
146+
* These enable proper deserialization of chat messages including the thinking field in AiMessage.
147+
* Based on mixins from dev.langchain4j.data.message.JacksonChatMessageJsonCodec.
148+
*/
149+
@JsonInclude(JsonInclude.Include.NON_NULL)
150+
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type")
151+
@JsonSubTypes({
152+
@JsonSubTypes.Type(value = SystemMessage.class, name = "SYSTEM"),
153+
@JsonSubTypes.Type(value = UserMessage.class, name = "USER"),
154+
@JsonSubTypes.Type(value = AiMessage.class, name = "AI"),
155+
@JsonSubTypes.Type(value = ToolExecutionResultMessage.class, name = "TOOL_EXECUTION_RESULT"),
156+
@JsonSubTypes.Type(value = CustomMessage.class, name = "CUSTOM"),
157+
})
158+
private abstract static class ChatMessageMixin {
159+
@JsonProperty
160+
public abstract ChatMessageType type();
161+
}
162+
163+
@JsonInclude(JsonInclude.Include.NON_NULL)
164+
private abstract static class SystemMessageMixin {
165+
@JsonCreator
166+
public SystemMessageMixin(@JsonProperty("text") String text) {
167+
}
168+
}
169+
170+
@JsonInclude(JsonInclude.Include.NON_NULL)
171+
@JsonDeserialize(builder = UserMessage.Builder.class)
172+
private abstract static class UserMessageMixin {
173+
}
174+
175+
@JsonInclude(JsonInclude.Include.NON_NULL)
176+
@JsonDeserialize(builder = AiMessage.Builder.class)
177+
@JsonPropertyOrder({ "toolExecutionRequests", "text", "attributes", "type" })
178+
private abstract static class AiMessageMixin {
179+
}
180+
181+
@JsonInclude(JsonInclude.Include.NON_NULL)
182+
@JsonPropertyOrder({ "text", "id", "toolName", "type" })
183+
private static class ToolExecutionResultMessageMixin {
184+
@JsonCreator
185+
public ToolExecutionResultMessageMixin(
186+
@JsonProperty("id") String id,
187+
@JsonProperty("toolName") String toolName,
188+
@JsonProperty("text") String text) {
189+
}
190+
}
191+
192+
@JsonInclude(JsonInclude.Include.NON_NULL)
193+
private static class CustomMessageMixin {
194+
@JsonCreator
195+
public CustomMessageMixin(@JsonProperty("attributes") Map<String, Object> attributes) {
196+
}
197+
}
198+
199+
@JsonInclude(JsonInclude.Include.NON_NULL)
200+
@JsonDeserialize(builder = ToolExecutionRequest.Builder.class)
201+
private abstract static class ToolExecutionRequestMixin {
202+
}
203+
117204
public static class SnakeCaseObjectMapperHolder {
118205
public static final ObjectMapper MAPPER = Arc.container().instance(ObjectMapper.class).get()
119206
.copy()

memory-stores/memory-store-redis/deployment/src/test/java/io/quarkiverse/langchain4j/memorystore/redis/test/RedisChatMemoryStoreTest.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.junit.jupiter.api.Test;
1919
import org.junit.jupiter.api.extension.RegisterExtension;
2020

21+
import dev.langchain4j.data.message.AiMessage;
2122
import dev.langchain4j.data.message.ChatMessage;
2223
import dev.langchain4j.service.MemoryId;
2324
import dev.langchain4j.service.UserMessage;
@@ -167,4 +168,37 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
167168
assertThat(redisDataSource.key().exists("" + FIRST_MEMORY_ID, "" + SECOND_MEMORY_ID)).isEqualTo(0);
168169
}
169170

171+
@Test
172+
void should_persist_ai_message_thinking_field() {
173+
// assert the bean type is correct
174+
assertThat(chatMemoryStore).isInstanceOf(RedisChatMemoryStore.class);
175+
176+
// Create an AiMessage with thinking field
177+
AiMessage messageWithThinking = AiMessage.builder()
178+
.text("The answer is 42")
179+
.thinking("Let me reason through this carefully. The question asks about the meaning of life...")
180+
.build();
181+
182+
int memoryId = 999;
183+
184+
// Store the message
185+
chatMemoryStore.updateMessages(memoryId, List.of(messageWithThinking));
186+
187+
// Retrieve the messages
188+
List<ChatMessage> retrievedMessages = chatMemoryStore.getMessages(memoryId);
189+
190+
// Assert the message was retrieved and the thinking field is preserved
191+
assertThat(retrievedMessages).hasSize(1);
192+
ChatMessage retrievedMessage = retrievedMessages.get(0);
193+
assertThat(retrievedMessage).isInstanceOf(AiMessage.class);
194+
195+
AiMessage retrievedAiMessage = (AiMessage) retrievedMessage;
196+
assertThat(retrievedAiMessage.text()).isEqualTo("The answer is 42");
197+
assertThat(retrievedAiMessage.thinking()).isEqualTo(
198+
"Let me reason through this carefully. The question asks about the meaning of life...");
199+
200+
// Clean up
201+
chatMemoryStore.deleteMessages(memoryId);
202+
}
203+
170204
}

0 commit comments

Comments
 (0)