Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.ai.chat.client.advisor;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -44,6 +45,8 @@
*/
public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor {

public static final String CONTAINS_SENSITIVE_WORDS = "safe_guard_contains_sensitive_words";

private static final String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?";

private static final int DEFAULT_ORDER = 0;
Expand All @@ -70,15 +73,21 @@ public static Builder builder() {
return new Builder();
}

@Override
public String getName() {
return this.getClass().getSimpleName();
}

@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
if (!CollectionUtils.isEmpty(this.sensitiveWords)
&& this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) {
return createFailureResponse(chatClientRequest);
if (!CollectionUtils.isEmpty(this.sensitiveWords)) {
String lowerCaseContents = chatClientRequest.prompt().getContents().toLowerCase();
List<String> hitSensitiveWords = this.sensitiveWords.stream()
.filter(w -> lowerCaseContents.contains(w.toLowerCase()))
.toList();
if (!CollectionUtils.isEmpty(hitSensitiveWords)) {
return createFailureResponse(chatClientRequest, hitSensitiveWords);
}
}

return callAdvisorChain.nextCall(chatClientRequest);
Expand All @@ -87,20 +96,29 @@ public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAd
@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
StreamAdvisorChain streamAdvisorChain) {
if (!CollectionUtils.isEmpty(this.sensitiveWords)
&& this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) {
return Flux.just(createFailureResponse(chatClientRequest));
if (!CollectionUtils.isEmpty(this.sensitiveWords)) {
String lowerCaseContents = chatClientRequest.prompt().getContents().toLowerCase();
List<String> hitSensitiveWords = this.sensitiveWords.stream()
.filter(w -> lowerCaseContents.contains(w.toLowerCase()))
.toList();
if (!CollectionUtils.isEmpty(hitSensitiveWords)) {
return Flux.just(createFailureResponse(chatClientRequest, hitSensitiveWords));
}
}

return streamAdvisorChain.nextStream(chatClientRequest);
}

private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest) {
private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest,
List<String> hitSensitiveWords) {
Map<String, Object> context = new HashMap<>(chatClientRequest.context());
context.put(CONTAINS_SENSITIVE_WORDS, hitSensitiveWords);

return ChatClientResponse.builder()
.chatResponse(ChatResponse.builder()
.generations(List.of(new Generation(new AssistantMessage(this.failureResponse))))
.build())
.context(Map.copyOf(chatClientRequest.context()))
.context(context)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.client.advisor;

import java.util.List;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/**
* @author YunKui Lu
*/
@ExtendWith(MockitoExtension.class)
class SafeGuardAdvisorTests {

@Mock
ChatModel chatModel;

@Test
void whenSensitiveWordsIsNullThenThrow() {
assertThatThrownBy(() -> SafeGuardAdvisor.builder().sensitiveWords(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Sensitive words must not be null!");
}

@Test
void whenFailureResponseIsNullThenThrow() {
assertThatThrownBy(() -> SafeGuardAdvisor.builder().sensitiveWords(List.of()).failureResponse(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Failure response must not be null!");
}

@Test
void testBuilderMethodChaining() {
// Test builder method chaining with methods from AbstractBuilder and
// SafeGuardAdvisor.Builder
List<String> sensitiveWords = List.of("word1", "word2");
int customOrder = 42;
String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?";

SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
.sensitiveWords(sensitiveWords)
.failureResponse(failureResponse)
.order(customOrder)
.build();

// Verify the advisor was built with the correct properties
assertThat(advisor).isNotNull();
assertThat(advisor.getOrder()).isEqualTo(customOrder);
}

@Test
void testDefaultValues() {
// Test builder method chaining with methods from AbstractBuilder and
// SafeGuardAdvisor.Builder
List<String> sensitiveWords = List.of("word1", "word2");
String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?";

SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
.sensitiveWords(sensitiveWords)
.failureResponse(failureResponse)
.build();

// Verify default values
assertThat(advisor).isNotNull();
assertThat(advisor.getOrder()).isZero();
}

@Test
void callAdvisorsContextPropagation() {
List<String> sensitiveWords = List.of("word1", "word2");
String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?";

SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
.sensitiveWords(sensitiveWords)
.failureResponse(failureResponse)
.build();

var chatClient = ChatClient.builder(this.chatModel)
.defaultSystem("Default system text.")
.defaultAdvisors(advisor)
.build();

var chatClientResponse = chatClient.prompt()
// should be case-insensitive
.user("do you like Word1?")
.advisors(advisor)
.call()
.chatClientResponse();

assertThat(chatClientResponse.chatResponse()).isNotNull();
assertThat(chatClientResponse.chatResponse().getResult().getOutput().getText()).isEqualTo(failureResponse);
assertThat((List<String>) chatClientResponse.context().get(SafeGuardAdvisor.CONTAINS_SENSITIVE_WORDS))
.containsExactly("word1");
}

@Test
void streamAdvisorsContextPropagation() {
List<String> sensitiveWords = List.of("Word1", "Word2");
String failureResponse = "That topic may be too sensitive to discuss. Can we talk about something else instead?";

SafeGuardAdvisor advisor = SafeGuardAdvisor.builder()
.sensitiveWords(sensitiveWords)
.failureResponse(failureResponse)
.build();

var chatClient = ChatClient.builder(this.chatModel)
.defaultSystem("Default system text.")
.defaultAdvisors(advisor)
.build();

var chatClientResponse = chatClient.prompt()
// should be case-insensitive
.user("do you like word2?")
.advisors(advisor)
.stream()
.chatClientResponse()
.blockFirst();

assertThat(chatClientResponse.chatResponse()).isNotNull();
assertThat(chatClientResponse.chatResponse().getResult().getOutput().getText()).isEqualTo(failureResponse);
assertThat((List<String>) chatClientResponse.context().get(SafeGuardAdvisor.CONTAINS_SENSITIVE_WORDS))
.containsExactly("Word2");
}

}