From 7ff94619da35217030a063f9c80e67a0e22b2cc0 Mon Sep 17 00:00:00 2001 From: YunKui Lu Date: Fri, 20 Jun 2025 21:06:38 +0800 Subject: [PATCH] feat: Enhance the SafeGuardAdvisor - Allow getting matched sensitive words through `context`. - The matching of sensitive words is case-insensitive. - Added unit tests. Fixed #3586 Signed-off-by: YunKui Lu --- .../chat/client/advisor/SafeGuardAdvisor.java | 34 +++- .../client/advisor/SafeGuardAdvisorTests.java | 148 ++++++++++++++++++ 2 files changed, 174 insertions(+), 8 deletions(-) create mode 100644 spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisorTests.java diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java index a319f9dd4f3..f2370bca251 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java @@ -16,6 +16,7 @@ package org.springframework.ai.chat.client.advisor; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,6 +45,8 @@ */ public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor { + public static final String CONTAINS_SENSITIVE_WORDS = "safe_guard_contains_sensitive_words"; + private static final String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. 
Could we rephrase or discuss something else?"; private static final int DEFAULT_ORDER = 0; @@ -70,15 +73,21 @@ public static Builder builder() { return new Builder(); } + @Override public String getName() { return this.getClass().getSimpleName(); } @Override public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { - if (!CollectionUtils.isEmpty(this.sensitiveWords) - && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { - return createFailureResponse(chatClientRequest); + if (!CollectionUtils.isEmpty(this.sensitiveWords)) { + String lowerCaseContents = chatClientRequest.prompt().getContents().toLowerCase(); + List hitSensitiveWords = this.sensitiveWords.stream() + .filter(w -> lowerCaseContents.contains(w.toLowerCase())) + .toList(); + if (!CollectionUtils.isEmpty(hitSensitiveWords)) { + return createFailureResponse(chatClientRequest, hitSensitiveWords); + } } return callAdvisorChain.nextCall(chatClientRequest); @@ -87,20 +96,29 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd @Override public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { - if (!CollectionUtils.isEmpty(this.sensitiveWords) - && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { - return Flux.just(createFailureResponse(chatClientRequest)); + if (!CollectionUtils.isEmpty(this.sensitiveWords)) { + String lowerCaseContents = chatClientRequest.prompt().getContents().toLowerCase(); + List hitSensitiveWords = this.sensitiveWords.stream() + .filter(w -> lowerCaseContents.contains(w.toLowerCase())) + .toList(); + if (!CollectionUtils.isEmpty(hitSensitiveWords)) { + return Flux.just(createFailureResponse(chatClientRequest, hitSensitiveWords)); + } } return streamAdvisorChain.nextStream(chatClientRequest); } - private ChatClientResponse createFailureResponse(ChatClientRequest 
chatClientRequest) { + private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest, + List hitSensitiveWords) { + Map context = new HashMap<>(chatClientRequest.context()); + context.put(CONTAINS_SENSITIVE_WORDS, hitSensitiveWords); + return ChatClientResponse.builder() .chatResponse(ChatResponse.builder() .generations(List.of(new Generation(new AssistantMessage(this.failureResponse)))) .build()) - .context(Map.copyOf(chatClientRequest.context())) + .context(context) .build(); } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisorTests.java new file mode 100644 index 00000000000..6d644abb29e --- /dev/null +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisorTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.chat.client.advisor; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * @author YunKui Lu + */ +@ExtendWith(MockitoExtension.class) +class SafeGuardAdvisorTests { + + @Mock + ChatModel chatModel; + + @Test + void whenSensitiveWordsIsNullThenThrow() { + assertThatThrownBy(() -> SafeGuardAdvisor.builder().sensitiveWords(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sensitive words must not be null!"); + } + + @Test + void whenFailureResponseIsNullThenThrow() { + assertThatThrownBy(() -> SafeGuardAdvisor.builder().sensitiveWords(List.of()).failureResponse(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Failure response must not be null!"); + } + + @Test + void testBuilderMethodChaining() { + // Test builder method chaining with methods from AbstractBuilder and + // SafeGuardAdvisor.Builder + List sensitiveWords = List.of("word1", "word2"); + int customOrder = 42; + String failureResponse = "That topic may be too sensitive to discuss. 
Can we talk about something else instead?"; + + SafeGuardAdvisor advisor = SafeGuardAdvisor.builder() + .sensitiveWords(sensitiveWords) + .failureResponse(failureResponse) + .order(customOrder) + .build(); + + // Verify the advisor was built with the correct properties + assertThat(advisor).isNotNull(); + assertThat(advisor.getOrder()).isEqualTo(customOrder); + } + + @Test + void testDefaultValues() { + // Build the advisor without setting an order so that the default + // order is used + List sensitiveWords = List.of("word1", "word2"); + String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?"; + + SafeGuardAdvisor advisor = SafeGuardAdvisor.builder() + .sensitiveWords(sensitiveWords) + .failureResponse(failureResponse) + .build(); + + // Verify default values + assertThat(advisor).isNotNull(); + assertThat(advisor.getOrder()).isZero(); + } + + @Test + void callAdvisorsContextPropagation() { + List sensitiveWords = List.of("word1", "word2"); + String failureResponse = "That topic may be too sensitive to discuss. 
Can we talk about something else instead?"; + + SafeGuardAdvisor advisor = SafeGuardAdvisor.builder() + .sensitiveWords(sensitiveWords) + .failureResponse(failureResponse) + .build(); + + var chatClient = ChatClient.builder(this.chatModel) + .defaultSystem("Default system text.") + .defaultAdvisors(advisor) + .build(); + + var chatClientResponse = chatClient.prompt() + // should be case-insensitive + .user("do you like Word1?") + .advisors(advisor) + .call() + .chatClientResponse(); + + assertThat(chatClientResponse.chatResponse()).isNotNull(); + assertThat(chatClientResponse.chatResponse().getResult().getOutput().getText()).isEqualTo(failureResponse); + assertThat((List) chatClientResponse.context().get(SafeGuardAdvisor.CONTAINS_SENSITIVE_WORDS)) + .containsExactly("word1"); + } + + @Test + void streamAdvisorsContextPropagation() { + List sensitiveWords = List.of("Word1", "Word2"); + String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?"; + + SafeGuardAdvisor advisor = SafeGuardAdvisor.builder() + .sensitiveWords(sensitiveWords) + .failureResponse(failureResponse) + .build(); + + var chatClient = ChatClient.builder(this.chatModel) + .defaultSystem("Default system text.") + .defaultAdvisors(advisor) + .build(); + + var chatClientResponse = chatClient.prompt() + // should be case-insensitive + .user("do you like word2?") + .advisors(advisor) + .stream() + .chatClientResponse() + .blockFirst(); + + assertThat(chatClientResponse.chatResponse()).isNotNull(); + assertThat(chatClientResponse.chatResponse().getResult().getOutput().getText()).isEqualTo(failureResponse); + assertThat((List) chatClientResponse.context().get(SafeGuardAdvisor.CONTAINS_SENSITIVE_WORDS)) + .containsExactly("Word2"); + } + +}