Skip to content

Commit 479282f

Browse files
committed
fix: ensure system role is first in advised request messages
Signed-off-by: Alexandros Pappas <[email protected]>
1 parent d25d37a commit 479282f

File tree

2 files changed

+96
-5
lines changed

2 files changed

+96
-5
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
* @author Christian Tzolov
6161
* @author Thomas Vitale
6262
* @author Ilayaperumal Gopinathan
63+
* @author Alexandros Pappas
6364
* @since 1.0.0
6465
*/
6566
public record AdvisedRequest(
@@ -147,18 +148,25 @@ public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Ob
147148
}
148149

149150
public Prompt toPrompt() {
150-
var messages = new ArrayList<>(this.messages());
151+
List<Message> promptMessages = new ArrayList<>();
151152

153+
// 1. Always add the SystemMessage first (if present)
152154
String processedSystemText = this.systemText();
153155
if (StringUtils.hasText(processedSystemText)) {
154156
if (!CollectionUtils.isEmpty(this.systemParams())) {
155157
processedSystemText = new PromptTemplate(processedSystemText, this.systemParams()).render();
156158
}
157-
messages.add(new SystemMessage(processedSystemText));
159+
promptMessages.add(new SystemMessage(processedSystemText));
158160
}
159161

160-
String formatParam = (String) this.adviseContext().get("formatParam");
162+
// 2. Add any existing conversation messages
163+
List<Message> existingMessages = this.messages();
164+
if (!CollectionUtils.isEmpty(existingMessages)) {
165+
promptMessages.addAll(existingMessages);
166+
}
161167

168+
// 3. Process and append the UserMessage (if present)
169+
String formatParam = (String) this.adviseContext().get("formatParam");
162170
var processedUserText = StringUtils.hasText(formatParam)
163171
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText();
164172

@@ -170,9 +178,10 @@ public Prompt toPrompt() {
170178
if (!CollectionUtils.isEmpty(userParams)) {
171179
processedUserText = new PromptTemplate(processedUserText, userParams).render();
172180
}
173-
messages.add(new UserMessage(processedUserText, this.media()));
181+
promptMessages.add(new UserMessage(processedUserText, this.media()));
174182
}
175183

184+
// 4. Configure function-calling options
176185
if (this.chatOptions() instanceof FunctionCallingOptions functionCallingOptions) {
177186
if (!this.functionNames().isEmpty()) {
178187
functionCallingOptions.setFunctions(new HashSet<>(this.functionNames()));
@@ -185,7 +194,7 @@ public Prompt toPrompt() {
185194
}
186195
}
187196

188-
return new Prompt(messages, this.chatOptions());
197+
return new Prompt(promptMessages, chatOptions);
189198
}
190199

191200
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package org.springframework.ai.integration.tests.client.advisor;
2+
3+
import org.junit.jupiter.api.Test;
4+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
5+
import org.springframework.ai.chat.client.ChatClient;
6+
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
7+
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
8+
import org.springframework.ai.chat.memory.InMemoryChatMemory;
9+
import org.springframework.ai.chat.model.ChatResponse;
10+
import org.springframework.ai.openai.OpenAiChatModel;
11+
import org.springframework.beans.factory.annotation.Autowired;
12+
import org.springframework.boot.test.context.SpringBootTest;
13+
14+
import static org.assertj.core.api.Assertions.assertThat;
15+
16+
/**
17+
* Integration tests for {@link MessageChatMemoryAdvisor}.
18+
*
19+
* @author Alexandros Pappas
20+
*/
21+
@SpringBootTest
22+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
23+
class MessageChatMemoryAdvisorIT {
24+
25+
@Autowired
26+
OpenAiChatModel openAiChatModel;
27+
28+
@Test
29+
void chatMemoryStoresAndRecallsConversation() {
30+
var chatMemory = new InMemoryChatMemory();
31+
var conversationId = "test-conversation";
32+
33+
var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).conversationId(conversationId).build();
34+
35+
var chatClient = ChatClient.builder(openAiChatModel).defaultAdvisors(memoryAdvisor).build();
36+
37+
// First interaction
38+
ChatResponse response1 = chatClient.prompt().user("Hello, my name is John.").call().chatResponse();
39+
40+
assertThat(response1).isNotNull();
41+
String assistantReply1 = response1.getResult().getOutput().getText();
42+
System.out.println("Assistant reply 1: " + assistantReply1);
43+
44+
// Second interaction - Verify memory recall
45+
ChatResponse response2 = chatClient.prompt().user("What is my name?").call().chatResponse();
46+
47+
assertThat(response2).isNotNull();
48+
String assistantReply2 = response2.getResult().getOutput().getText();
49+
System.out.println("Assistant reply 2: " + assistantReply2);
50+
51+
assertThat(assistantReply2.toLowerCase()).contains("john");
52+
}
53+
54+
@Test
55+
void separateConversationsDoNotMixMemory() {
56+
var chatMemory = new InMemoryChatMemory();
57+
58+
var memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
59+
60+
var chatClient = ChatClient.builder(openAiChatModel).defaultAdvisors(memoryAdvisor).build();
61+
62+
// First conversation
63+
chatClient.prompt()
64+
.user("Remember my secret code is blue.")
65+
.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "conv-1"))
66+
.call();
67+
68+
// Second conversation
69+
ChatResponse response = chatClient.prompt()
70+
.user("Do you remember my secret code?")
71+
.advisors(advisors -> advisors.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "conv-2"))
72+
.call()
73+
.chatResponse();
74+
75+
assertThat(response).isNotNull();
76+
String assistantReply = response.getResult().getOutput().getText();
77+
System.out.println("Assistant reply: " + assistantReply);
78+
79+
assertThat(assistantReply.toLowerCase()).doesNotContain("blue");
80+
}
81+
82+
}

0 commit comments

Comments
 (0)