Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import reactor.core.publisher.Flux;
Expand All @@ -37,10 +40,26 @@
*/
public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

private final static String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?";

private final static int DEFAULT_ORDER = 0;

private final String failureResponse;

private final List<String> sensitiveWords;

private final int order;

public SafeGuardAdvisor(List<String> sensitiveWords) {
this(sensitiveWords, DEFAULT_FAILURE_RESPONSE, DEFAULT_ORDER);
}

public SafeGuardAdvisor(List<String> sensitiveWords, String failureResponse, int order) {
Assert.notNull(sensitiveWords, "Sensitive words must not be null!");
Assert.notNull(failureResponse, "Failure response must not be null!");
this.sensitiveWords = sensitiveWords;
this.failureResponse = failureResponse;
this.order = order;
}

public String getName() {
Expand All @@ -51,9 +70,9 @@ public String getName() {
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {

if (!CollectionUtils.isEmpty(this.sensitiveWords)
&& sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) {
return new AdvisedResponse(ChatResponse.builder().withGenerations(List.of()).build(),
advisedRequest.adviseContext());
&& this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) {

return createFailureResponse(advisedRequest);
}

return chain.nextAroundCall(advisedRequest);
Expand All @@ -64,16 +83,57 @@ public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamA

if (!CollectionUtils.isEmpty(this.sensitiveWords)
&& sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) {
return Flux.empty();
return Flux.just(createFailureResponse(advisedRequest));
}

return chain.nextAroundStream(advisedRequest);
}

private AdvisedResponse createFailureResponse(AdvisedRequest advisedRequest) {
return new AdvisedResponse(ChatResponse.builder()
.withGenerations(List.of(new Generation(new AssistantMessage(this.failureResponse))))
.build(), advisedRequest.adviseContext());
}

@Override
public int getOrder() {
return 0;
return this.order;
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private List<String> sensitiveWords;

private String failureResponse = DEFAULT_FAILURE_RESPONSE;

private int order = DEFAULT_ORDER;

private Builder() {
}

public Builder withSensitiveWords(List<String> sensitiveWords) {
this.sensitiveWords = sensitiveWords;
return this;
}

public Builder withFailureResponse(String failureResponse) {
this.failureResponse = failureResponse;
return this;
}

public Builder withOrder(int order) {
this.order = order;
return this;
}

public SafeGuardAdvisor build() {
return new SafeGuardAdvisor(this.sensitiveWords, this.failureResponse, this.order);
}

}

}