diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java index a319f9dd4f3..df68a5012d2 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SafeGuardAdvisor.java @@ -18,7 +18,11 @@ import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClientRequest; @@ -44,6 +48,8 @@ */ public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor { + private static final Logger log = LoggerFactory.getLogger(SafeGuardAdvisor.class); + private static final String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?"; private static final int DEFAULT_ORDER = 0; @@ -76,22 +82,28 @@ public String getName() { @Override public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { - if (!CollectionUtils.isEmpty(this.sensitiveWords) - && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { - return createFailureResponse(chatClientRequest); + if (!CollectionUtils.isEmpty(this.sensitiveWords)) { + String content = chatClientRequest.prompt().getContents(); + Set hitWords = this.sensitiveWords.stream().filter(content::contains).collect(Collectors.toSet()); + if (!hitWords.isEmpty()) { + log.debug("sensitive words found: {}", hitWords); + return createFailureResponse(chatClientRequest); + } } - return callAdvisorChain.nextCall(chatClientRequest); } @Override public Flux adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain) { - if (!CollectionUtils.isEmpty(this.sensitiveWords) - && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { - return Flux.just(createFailureResponse(chatClientRequest)); + if (!CollectionUtils.isEmpty(this.sensitiveWords)) { + String contents = chatClientRequest.prompt().getContents(); + Set hitWords = this.sensitiveWords.stream().filter(contents::contains).collect(Collectors.toSet()); + if (!hitWords.isEmpty()) { + log.debug("sensitive words found: {}", hitWords); + return Flux.just(createFailureResponse(chatClientRequest)); + } } - return streamAdvisorChain.nextStream(chatClientRequest); }