|
18 | 18 |
|
19 | 19 | import java.util.List; |
20 | 20 | import java.util.Map; |
| 21 | +import java.util.Set; |
| 22 | +import java.util.stream.Collectors; |
21 | 23 |
|
| 24 | +import org.slf4j.Logger; |
| 25 | +import org.slf4j.LoggerFactory; |
22 | 26 | import reactor.core.publisher.Flux; |
23 | 27 |
|
24 | 28 | import org.springframework.ai.chat.client.ChatClientRequest; |
|
44 | 48 | */ |
45 | 49 | public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor { |
46 | 50 |
|
| 51 | + private static final Logger log = LoggerFactory.getLogger(SafeGuardAdvisor.class); |
| 52 | + |
47 | 53 | private static final String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?"; |
48 | 54 |
|
49 | 55 | private static final int DEFAULT_ORDER = 0; |
@@ -76,22 +82,28 @@ public String getName() { |
76 | 82 |
|
77 | 83 | @Override |
78 | 84 | public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) { |
79 | | - if (!CollectionUtils.isEmpty(this.sensitiveWords) |
80 | | - && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { |
81 | | - return createFailureResponse(chatClientRequest); |
| 85 | + if (!CollectionUtils.isEmpty(this.sensitiveWords)) { |
| 86 | + String content = chatClientRequest.prompt().getContents(); |
| 87 | + Set<String> hitWords = this.sensitiveWords.stream().filter(content::contains).collect(Collectors.toSet()); |
| 88 | + if (!hitWords.isEmpty()) { |
| 89 | + log.debug("sensitive words found: {}", hitWords); |
| 90 | + return createFailureResponse(chatClientRequest); |
| 91 | + } |
82 | 92 | } |
83 | | - |
84 | 93 | return callAdvisorChain.nextCall(chatClientRequest); |
85 | 94 | } |
86 | 95 |
|
87 | 96 | @Override |
88 | 97 | public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, |
89 | 98 | StreamAdvisorChain streamAdvisorChain) { |
90 | | - if (!CollectionUtils.isEmpty(this.sensitiveWords) |
91 | | - && this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) { |
92 | | - return Flux.just(createFailureResponse(chatClientRequest)); |
| 99 | + if (!CollectionUtils.isEmpty(this.sensitiveWords)) { |
| 100 | + String contents = chatClientRequest.prompt().getContents(); |
| 101 | + Set<String> hitWords = this.sensitiveWords.stream().filter(contents::contains).collect(Collectors.toSet()); |
| 102 | + if (!hitWords.isEmpty()) { |
| 103 | + log.debug("sensitive words found: {}", hitWords); |
| 104 | + return Flux.just(createFailureResponse(chatClientRequest)); |
| 105 | + } |
93 | 106 | } |
94 | | - |
95 | 107 | return streamAdvisorChain.nextStream(chatClientRequest); |
96 | 108 | } |
97 | 109 |
|
|
0 commit comments