Skip to content

Commit 7ff9461

Browse files
committed
feat: Enhance the SafeGuardAdvisor
- Expose the matched sensitive words through the response `context`.
- Match sensitive words case-insensitively.
- Add unit tests.

Fixed #3586

Signed-off-by: YunKui Lu <[email protected]>
1 parent 8107612 commit 7ff9461

File tree

2 files changed

+174
-8
lines changed

2 files changed

+174
-8
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.chat.client.advisor;
1818

19+
import java.util.HashMap;
1920
import java.util.List;
2021
import java.util.Map;
2122

@@ -44,6 +45,8 @@
4445
*/
4546
public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor {
4647

48+
public static final String CONTAINS_SENSITIVE_WORDS = "safe_guard_contains_sensitive_words";
49+
4750
private static final String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?";
4851

4952
private static final int DEFAULT_ORDER = 0;
@@ -70,15 +73,21 @@ public static Builder builder() {
7073
return new Builder();
7174
}
7275

76+
@Override
7377
public String getName() {
7478
return this.getClass().getSimpleName();
7579
}
7680

7781
@Override
7882
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
79-
if (!CollectionUtils.isEmpty(this.sensitiveWords)
80-
&& this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) {
81-
return createFailureResponse(chatClientRequest);
83+
if (!CollectionUtils.isEmpty(this.sensitiveWords)) {
84+
String lowerCaseContents = chatClientRequest.prompt().getContents().toLowerCase();
85+
List<String> hitSensitiveWords = this.sensitiveWords.stream()
86+
.filter(w -> lowerCaseContents.contains(w.toLowerCase()))
87+
.toList();
88+
if (!CollectionUtils.isEmpty(hitSensitiveWords)) {
89+
return createFailureResponse(chatClientRequest, hitSensitiveWords);
90+
}
8291
}
8392

8493
return callAdvisorChain.nextCall(chatClientRequest);
@@ -87,20 +96,29 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
8796
@Override
8897
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
8998
StreamAdvisorChain streamAdvisorChain) {
90-
if (!CollectionUtils.isEmpty(this.sensitiveWords)
91-
&& this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) {
92-
return Flux.just(createFailureResponse(chatClientRequest));
99+
if (!CollectionUtils.isEmpty(this.sensitiveWords)) {
100+
String lowerCaseContents = chatClientRequest.prompt().getContents().toLowerCase();
101+
List<String> hitSensitiveWords = this.sensitiveWords.stream()
102+
.filter(w -> lowerCaseContents.contains(w.toLowerCase()))
103+
.toList();
104+
if (!CollectionUtils.isEmpty(hitSensitiveWords)) {
105+
return Flux.just(createFailureResponse(chatClientRequest, hitSensitiveWords));
106+
}
93107
}
94108

95109
return streamAdvisorChain.nextStream(chatClientRequest);
96110
}
97111

98-
private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest) {
112+
private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest,
113+
List<String> hitSensitiveWords) {
114+
Map<String, Object> context = new HashMap<>(chatClientRequest.context());
115+
context.put(CONTAINS_SENSITIVE_WORDS, hitSensitiveWords);
116+
99117
return ChatClientResponse.builder()
100118
.chatResponse(ChatResponse.builder()
101119
.generations(List.of(new Generation(new AssistantMessage(this.failureResponse))))
102120
.build())
103-
.context(Map.copyOf(chatClientRequest.context()))
121+
.context(context)
104122
.build();
105123
}
106124

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.client.advisor;
18+
19+
import java.util.List;
20+
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.extension.ExtendWith;
23+
import org.mockito.Mock;
24+
import org.mockito.junit.jupiter.MockitoExtension;
25+
26+
import org.springframework.ai.chat.client.ChatClient;
27+
import org.springframework.ai.chat.model.ChatModel;
28+
29+
import static org.assertj.core.api.Assertions.assertThat;
30+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
31+
32+
/**
33+
* @author YunKui Lu
34+
*/
35+
@ExtendWith(MockitoExtension.class)
36+
class SafeGuardAdvisorTests {
37+
38+
@Mock
39+
ChatModel chatModel;
40+
41+
@Test
42+
void whenSensitiveWordsIsNullThenThrow() {
43+
assertThatThrownBy(() -> SafeGuardAdvisor.builder().sensitiveWords(null).build())
44+
.isInstanceOf(IllegalArgumentException.class)
45+
.hasMessageContaining("Sensitive words must not be null!");
46+
}
47+
48+
@Test
49+
void whenFailureResponseIsNullThenThrow() {
50+
assertThatThrownBy(() -> SafeGuardAdvisor.builder().sensitiveWords(List.of()).failureResponse(null).build())
51+
.isInstanceOf(IllegalArgumentException.class)
52+
.hasMessageContaining("Failure response must not be null!");
53+
}
54+
55+
@Test
56+
void testBuilderMethodChaining() {
57+
// Test builder method chaining with methods from AbstractBuilder and
58+
// SafeGuardAdvisor.Builder
59+
List<String> sensitiveWords = List.of("word1", "word2");
60+
int customOrder = 42;
61+
String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?";
62+
63+
SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
64+
.sensitiveWords(sensitiveWords)
65+
.failureResponse(failureResponse)
66+
.order(customOrder)
67+
.build();
68+
69+
// Verify the advisor was built with the correct properties
70+
assertThat(advisor).isNotNull();
71+
assertThat(advisor.getOrder()).isEqualTo(customOrder);
72+
}
73+
74+
@Test
75+
void testDefaultValues() {
76+
// Build the advisor without calling order(...) so that the default order
77+
// is used
78+
List<String> sensitiveWords = List.of("word1", "word2");
79+
String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?";
80+
81+
SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
82+
.sensitiveWords(sensitiveWords)
83+
.failureResponse(failureResponse)
84+
.build();
85+
86+
// Verify default values
87+
assertThat(advisor).isNotNull();
88+
assertThat(advisor.getOrder()).isZero();
89+
}
90+
91+
@Test
92+
void callAdvisorsContextPropagation() {
93+
List<String> sensitiveWords = List.of("word1", "word2");
94+
String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?";
95+
96+
SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
97+
.sensitiveWords(sensitiveWords)
98+
.failureResponse(failureResponse)
99+
.build();
100+
101+
var chatClient = ChatClient.builder(this.chatModel)
102+
.defaultSystem("Default system text.")
103+
.defaultAdvisors(advisor)
104+
.build();
105+
106+
var chatClientResponse = chatClient.prompt()
107+
// should be case-insensitive
108+
.user("do you like Word1?")
109+
.advisors(advisor)
110+
.call()
111+
.chatClientResponse();
112+
113+
assertThat(chatClientResponse.chatResponse()).isNotNull();
114+
assertThat(chatClientResponse.chatResponse().getResult().getOutput().getText()).isEqualTo(failureResponse);
115+
assertThat((List<String>) chatClientResponse.context().get(SafeGuardAdvisor.CONTAINS_SENSITIVE_WORDS))
116+
.containsExactly("word1");
117+
}
118+
119+
@Test
120+
void streamAdvisorsContextPropagation() {
121+
List<String> sensitiveWords = List.of("Word1", "Word2");
122+
String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?";
123+
124+
SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
125+
.sensitiveWords(sensitiveWords)
126+
.failureResponse(failureResponse)
127+
.build();
128+
129+
var chatClient = ChatClient.builder(this.chatModel)
130+
.defaultSystem("Default system text.")
131+
.defaultAdvisors(advisor)
132+
.build();
133+
134+
var chatClientResponse = chatClient.prompt()
135+
// should be case-insensitive
136+
.user("do you like word2?")
137+
.advisors(advisor)
138+
.stream()
139+
.chatClientResponse()
140+
.blockFirst();
141+
142+
assertThat(chatClientResponse.chatResponse()).isNotNull();
143+
assertThat(chatClientResponse.chatResponse().getResult().getOutput().getText()).isEqualTo(failureResponse);
144+
assertThat((List<String>) chatClientResponse.context().get(SafeGuardAdvisor.CONTAINS_SENSITIVE_WORDS))
145+
.containsExactly("Word2");
146+
}
147+
148+
}

0 commit comments

Comments
 (0)