Skip to content

Commit 36c654a

Browse files
committed
test: Add unit tests for the SafeGuardAdvisor
Signed-off-by: Sun Yuhan <[email protected]>
1 parent 23ea127 commit 36c654a

File tree

2 files changed

+230
-2
lines changed

2 files changed

+230
-2
lines changed

spring-ai-client-chat/pom.xml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,14 @@
116116
<version>${mockk-jvm.version}</version>
117117
<scope>test</scope>
118118
</dependency>
119-
120-
</dependencies>
119+
120+
<dependency>
121+
<groupId>io.projectreactor</groupId>
122+
<artifactId>reactor-test</artifactId>
123+
<scope>test</scope>
124+
</dependency>
125+
126+
</dependencies>
121127

122128
<profiles>
123129
<!-- ANTLR profile moved to spring-ai-vector-store -->
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.client.advisor;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
22+
import io.micrometer.observation.ObservationRegistry;
23+
import org.junit.jupiter.api.BeforeEach;
24+
import org.junit.jupiter.api.Test;
25+
import reactor.core.publisher.Flux;
26+
import reactor.test.StepVerifier;
27+
28+
import org.springframework.ai.chat.client.ChatClientRequest;
29+
import org.springframework.ai.chat.client.ChatClientResponse;
30+
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
31+
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
32+
import org.springframework.ai.chat.memory.ChatMemory;
33+
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
34+
import org.springframework.ai.chat.messages.AssistantMessage;
35+
import org.springframework.ai.chat.messages.Message;
36+
import org.springframework.ai.chat.model.ChatResponse;
37+
import org.springframework.ai.chat.model.Generation;
38+
import org.springframework.ai.chat.prompt.Prompt;
39+
import org.springframework.core.Ordered;
40+
41+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
42+
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
43+
import static org.mockito.ArgumentMatchers.any;
44+
import static org.mockito.Mockito.mock;
45+
import static org.mockito.Mockito.never;
46+
import static org.mockito.Mockito.verify;
47+
import static org.mockito.Mockito.when;
48+
49+
/**
50+
* Unit tests for {@link SafeGuardAdvisor}.
51+
*
52+
* @author Sun Yuhan
53+
*/
54+
class SafeGuardAdvisorTests {
55+
56+
private ChatClientRequest safeRequest;
57+
58+
private ChatClientRequest unsafeRequest;
59+
60+
@BeforeEach
61+
void setUp() {
62+
safeRequest = new ChatClientRequest(Prompt.builder().content("hello world").build(), Map.of());
63+
unsafeRequest = new ChatClientRequest(Prompt.builder().content("this contains secret").build(), Map.of());
64+
}
65+
66+
@Test
67+
void constructorThrowsExceptionWhenSensitiveWordsIsNull() {
68+
assertThatThrownBy(() -> new SafeGuardAdvisor(null)).isInstanceOf(IllegalArgumentException.class)
69+
.hasMessageContaining("Sensitive words must not be null");
70+
}
71+
72+
@Test
73+
void constructorThrowsExceptionWhenFailureResponseIsNull() {
74+
assertThatThrownBy(() -> new SafeGuardAdvisor(List.of("s"), null, 1))
75+
.isInstanceOf(IllegalArgumentException.class)
76+
.hasMessageContaining("Failure response must not be null");
77+
}
78+
79+
@Test
80+
void adviseCallInterceptsWhenContainsSensitiveWord() {
81+
SafeGuardAdvisor advisor = new SafeGuardAdvisor(List.of("secret"));
82+
83+
CallAdvisorChain mockChain = mock(CallAdvisorChain.class);
84+
85+
ChatClientResponse response = advisor.adviseCall(unsafeRequest, mockChain);
86+
87+
assertThat(response.chatResponse().getResult().getOutput().getText()).contains("I'm unable to respond");
88+
verify(mockChain, never()).nextCall(any());
89+
}
90+
91+
@Test
92+
void adviseCallPassesThroughWhenNoSensitiveWord() {
93+
SafeGuardAdvisor advisor = new SafeGuardAdvisor(List.of("secret"));
94+
95+
CallAdvisorChain mockChain = mock(CallAdvisorChain.class);
96+
ChatClientResponse expected = ChatClientResponse.builder()
97+
.chatResponse(ChatResponse.builder().generations(List.of()).build())
98+
.context(Map.of())
99+
.build();
100+
101+
when(mockChain.nextCall(safeRequest)).thenReturn(expected);
102+
103+
ChatClientResponse response = advisor.adviseCall(safeRequest, mockChain);
104+
105+
assertThat(response).isSameAs(expected);
106+
verify(mockChain).nextCall(safeRequest);
107+
}
108+
109+
@Test
110+
void adviseStreamInterceptsWhenContainsSensitiveWord() {
111+
SafeGuardAdvisor advisor = new SafeGuardAdvisor(List.of("secret"));
112+
113+
StreamAdvisorChain mockChain = mock(StreamAdvisorChain.class);
114+
115+
Flux<ChatClientResponse> flux = advisor.adviseStream(unsafeRequest, mockChain);
116+
117+
StepVerifier.create(flux)
118+
.assertNext(r -> assertThat(r.chatResponse().getResult().getOutput().getText())
119+
.contains("I'm unable to respond"))
120+
.verifyComplete();
121+
122+
verify(mockChain, never()).nextStream(any());
123+
}
124+
125+
@Test
126+
void adviseStreamPassesThroughWhenNoSensitiveWord() {
127+
SafeGuardAdvisor advisor = new SafeGuardAdvisor(List.of("secret"));
128+
129+
StreamAdvisorChain mockChain = mock(StreamAdvisorChain.class);
130+
ChatClientResponse expected = ChatClientResponse.builder()
131+
.chatResponse(ChatResponse.builder().generations(List.of()).build())
132+
.context(Map.of())
133+
.build();
134+
135+
when(mockChain.nextStream(safeRequest)).thenReturn(Flux.just(expected));
136+
137+
Flux<ChatClientResponse> flux = advisor.adviseStream(safeRequest, mockChain);
138+
139+
StepVerifier.create(flux).expectNext(expected).verifyComplete();
140+
141+
verify(mockChain).nextStream(safeRequest);
142+
}
143+
144+
@Test
145+
void builderRespectsCustomValues() {
146+
SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
147+
.sensitiveWords(List.of("xxx"))
148+
.failureResponse("custom block")
149+
.order(42)
150+
.build();
151+
152+
assertThat(advisor.getOrder()).isEqualTo(42);
153+
154+
CallAdvisorChain mockChain = mock(CallAdvisorChain.class);
155+
ChatClientRequest req = new ChatClientRequest(Prompt.builder().content("xxx is here").build(), Map.of());
156+
157+
ChatClientResponse response = advisor.adviseCall(req, mockChain);
158+
assertThat(response.chatResponse().getResult().getOutput().getText()).isEqualTo("custom block");
159+
}
160+
161+
@Test
162+
void safeGuardAdvisorShouldAllowSubsequentValidMessagesAfterBlockingSensitiveContent() {
163+
ChatMemory chatMemory = MessageWindowChatMemory.builder().build();
164+
165+
// Create three advisors
166+
SafeGuardAdvisor safeGuardAdvisor = new SafeGuardAdvisor(List.of("secret"));
167+
MessageChatMemoryAdvisor memoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
168+
ChatModelCallAdvisor chatModelCallAdvisor = mock(ChatModelCallAdvisor.class);
169+
170+
when(chatModelCallAdvisor.adviseCall(any(), any())).thenReturn(ChatClientResponse.builder()
171+
.chatResponse(ChatResponse.builder()
172+
.generations(List.of(new Generation(new AssistantMessage("Hello, how can I help?"))))
173+
.build())
174+
.build());
175+
when(chatModelCallAdvisor.getName()).thenReturn(ChatModelCallAdvisor.class.getSimpleName());
176+
when(chatModelCallAdvisor.getOrder()).thenReturn(Ordered.LOWEST_PRECEDENCE);
177+
178+
var advisors = List.of(safeGuardAdvisor, memoryAdvisor, chatModelCallAdvisor);
179+
180+
// Verify that SafeGuardAdvisor's order is higher than MessageChatMemoryAdvisor
181+
assertThat(safeGuardAdvisor.getOrder()).isLessThan(memoryAdvisor.getOrder());
182+
183+
// Create a request containing sensitive words
184+
ChatClientRequest sensitiveRequest = new ChatClientRequest(
185+
Prompt.builder().content("this contains secret").build(), Map.of());
186+
187+
// Create a normal request
188+
ChatClientRequest normalRequest = new ChatClientRequest(Prompt.builder().content("hello world").build(),
189+
Map.of());
190+
191+
// Send a message with sensitive words - should be intercepted by SafeGuardAdvisor
192+
DefaultAroundAdvisorChain aroundAdvisorChain1 = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
193+
.pushAll(advisors)
194+
.build();
195+
196+
ChatClientResponse response1 = aroundAdvisorChain1.nextCall(sensitiveRequest);
197+
assertThat(response1.chatResponse().getResult().getOutput().getText()).contains("I'm unable to respond");
198+
199+
// Verify that sensitive messages are not added to chat memory (intercepted by
200+
// SafeGuardAdvisor)
201+
List<Message> memoryMessagesAfterSensitive = chatMemory.get(ChatMemory.DEFAULT_CONVERSATION_ID);
202+
assertThat(memoryMessagesAfterSensitive.size()).isEqualTo(0);
203+
204+
// Send a normal message - should be processed normally
205+
DefaultAroundAdvisorChain aroundAdvisorChain2 = DefaultAroundAdvisorChain.builder(ObservationRegistry.NOOP)
206+
.pushAll(advisors)
207+
.build();
208+
209+
ChatClientResponse response2 = aroundAdvisorChain2.nextCall(normalRequest);
210+
assertThat(response2.chatResponse().getResult().getOutput().getText()).isEqualTo("Hello, how can I help?");
211+
212+
// Verify that chat memory contains only normal messages
213+
List<Message> memoryMessagesNormalRequest = chatMemory.get(ChatMemory.DEFAULT_CONVERSATION_ID);
214+
215+
assertThat(memoryMessagesNormalRequest.size() == 2).isTrue();
216+
List<String> messageTexts = memoryMessagesNormalRequest.stream().map(Message::getText).toList();
217+
218+
assertThat(messageTexts.contains("hello world")).isTrue();
219+
assertThat(messageTexts.contains("Hello, how can I help?")).isTrue();
220+
}
221+
222+
}

0 commit comments

Comments
 (0)