Skip to content

Commit f31ace1

Browse files
committed
Add Support of Mistral AI Moderation API
Signed-off-by: Ricken Bazolo <[email protected]>
1 parent 4874374 commit f31ace1

File tree

7 files changed

+526
-4
lines changed

7 files changed

+526
-4
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package org.springframework.ai.mistralai.api;
2+
3+
import com.fasterxml.jackson.annotation.JsonInclude;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import org.springframework.ai.retry.RetryUtils;
6+
import org.springframework.http.HttpHeaders;
7+
import org.springframework.http.MediaType;
8+
import org.springframework.http.ResponseEntity;
9+
import org.springframework.util.Assert;
10+
import org.springframework.web.client.ResponseErrorHandler;
11+
import org.springframework.web.client.RestClient;
12+
13+
import java.util.function.Consumer;
14+
15+
/**
16+
* MistralAI Moderation API.
17+
*
18+
* @author Ricken Bazolo
19+
* @see <a href= "https://docs.mistral.ai/capabilities/guardrailing/</a>
20+
*/
21+
public class MistralAiModerationApi {
22+
23+
private static final String DEFAULT_BASE_URL = "https://api.mistral.ai";
24+
25+
private final RestClient restClient;
26+
27+
public MistralAiModerationApi(String mistralAiApiKey) {
28+
this(DEFAULT_BASE_URL, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
29+
}
30+
31+
public MistralAiModerationApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder,
32+
ResponseErrorHandler responseErrorHandler) {
33+
34+
Consumer<HttpHeaders> jsonContentHeaders = headers -> {
35+
headers.setBearerAuth(mistralAiApiKey);
36+
headers.setContentType(MediaType.APPLICATION_JSON);
37+
};
38+
39+
this.restClient = restClientBuilder.baseUrl(baseUrl)
40+
.defaultHeaders(jsonContentHeaders)
41+
.defaultStatusHandler(responseErrorHandler)
42+
.build();
43+
}
44+
45+
public ResponseEntity<MistralAiModerationResponse> moderate(MistralAiModerationRequest mistralAiModerationRequest) {
46+
Assert.notNull(mistralAiModerationRequest, "Moderation request cannot be null.");
47+
Assert.hasLength(mistralAiModerationRequest.prompt(), "Prompt cannot be empty.");
48+
Assert.notNull(mistralAiModerationRequest.model(), "Model cannot be null.");
49+
50+
return this.restClient.post()
51+
.uri("v1/moderations")
52+
.body(mistralAiModerationRequest)
53+
.retrieve()
54+
.toEntity(MistralAiModerationResponse.class);
55+
}
56+
57+
public enum Model {
58+
59+
// @formatter:off
60+
MISTRAL_MODERATION("mistral-moderation-latest");
61+
// @formatter:on
62+
63+
private final String value;
64+
65+
Model(String value) {
66+
this.value = value;
67+
}
68+
69+
public String getValue() {
70+
return this.value;
71+
}
72+
73+
}
74+
75+
// @formatter:off
76+
@JsonInclude(JsonInclude.Include.NON_NULL)
77+
public record MistralAiModerationRequest(
78+
@JsonProperty("input") String prompt,
79+
@JsonProperty("model") String model
80+
) {
81+
82+
public MistralAiModerationRequest(String prompt) {
83+
this(prompt, null);
84+
}
85+
}
86+
87+
@JsonInclude(JsonInclude.Include.NON_NULL)
88+
public record MistralAiModerationResponse(
89+
@JsonProperty("id") String id,
90+
@JsonProperty("model") String model,
91+
@JsonProperty("results") MistralAiModerationResult[] results) {
92+
93+
}
94+
95+
@JsonInclude(JsonInclude.Include.NON_NULL)
96+
public record MistralAiModerationResult(
97+
@JsonProperty("categories") Categories categories,
98+
@JsonProperty("category_scores") CategoryScores categoryScores) {
99+
100+
public boolean flagged() {
101+
return categories != null && (categories.sexual() || categories.hateAndDiscrimination() || categories.violenceAndThreats()
102+
|| categories.selfHarm() || categories.dangerousAndCriminalContent() || categories.health()
103+
|| categories.financial() || categories.law() || categories.pii());
104+
}
105+
106+
}
107+
108+
@JsonInclude(JsonInclude.Include.NON_NULL)
109+
public record Categories(
110+
@JsonProperty("sexual") boolean sexual,
111+
@JsonProperty("hate_and_discrimination") boolean hateAndDiscrimination,
112+
@JsonProperty("violence_and_threats") boolean violenceAndThreats,
113+
@JsonProperty("selfharm") boolean selfHarm,
114+
@JsonProperty("dangerous_and_criminal_content") boolean dangerousAndCriminalContent,
115+
@JsonProperty("health") boolean health,
116+
@JsonProperty("financial") boolean financial,
117+
@JsonProperty("law") boolean law,
118+
@JsonProperty("pii") boolean pii) {
119+
120+
}
121+
122+
@JsonInclude(JsonInclude.Include.NON_NULL)
123+
public record CategoryScores(
124+
@JsonProperty("sexual") double sexual,
125+
@JsonProperty("hate_and_discrimination") double hateAndDiscrimination,
126+
@JsonProperty("violence_and_threats") double violenceAndThreats,
127+
@JsonProperty("selfharm") double selfHarm,
128+
@JsonProperty("dangerous_and_criminal_content") double dangerousAndCriminalContent,
129+
@JsonProperty("health") double health,
130+
@JsonProperty("financial") double financial,
131+
@JsonProperty("law") double law,
132+
@JsonProperty("pii") double pii) {
133+
134+
}
135+
// @formatter:onn
136+
137+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package org.springframework.ai.mistralai.moderation;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
6+
import org.springframework.ai.model.ModelOptionsUtils;
7+
import org.springframework.ai.moderation.*;
8+
import org.springframework.ai.retry.RetryUtils;
9+
import org.springframework.http.ResponseEntity;
10+
import org.springframework.retry.support.RetryTemplate;
11+
import org.springframework.util.Assert;
12+
13+
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest;
14+
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResponse;
15+
import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResult;
16+
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
20+
/**
21+
* @author Ricken Bazolo
22+
*/
23+
public class MistralAiModerationModel implements ModerationModel {
24+
25+
private final Logger logger = LoggerFactory.getLogger(getClass());
26+
27+
private final MistralAiModerationApi mistralAiModerationApi;
28+
29+
private final RetryTemplate retryTemplate;
30+
31+
private final MistralAiModerationOptions defaultOptions;
32+
33+
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi) {
34+
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE,
35+
MistralAiModerationOptions.builder()
36+
.model(MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue())
37+
.build());
38+
}
39+
40+
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, MistralAiModerationOptions options) {
41+
this(mistralAiModerationApi, RetryUtils.DEFAULT_RETRY_TEMPLATE, options);
42+
}
43+
44+
public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, RetryTemplate retryTemplate,
45+
MistralAiModerationOptions options) {
46+
Assert.notNull(mistralAiModerationApi, "mistralAiModerationApi must not be null");
47+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
48+
Assert.notNull(options, "options must not be null");
49+
this.mistralAiModerationApi = mistralAiModerationApi;
50+
this.retryTemplate = retryTemplate;
51+
this.defaultOptions = options;
52+
}
53+
54+
@Override
55+
public ModerationResponse call(ModerationPrompt moderationPrompt) {
56+
return this.retryTemplate.execute(ctx -> {
57+
58+
var instructions = moderationPrompt.getInstructions().getText();
59+
60+
var moderationRequest = new MistralAiModerationRequest(instructions);
61+
62+
if (this.defaultOptions != null) {
63+
moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest,
64+
MistralAiModerationRequest.class);
65+
}
66+
else {
67+
// moderationPrompt.getOptions() never null but model can be empty, cause
68+
// by ModerationPrompt constructor
69+
moderationRequest = ModelOptionsUtils.merge(toMistralAiModerationOptions(moderationPrompt.getOptions()),
70+
moderationRequest, MistralAiModerationRequest.class);
71+
}
72+
73+
var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest);
74+
75+
return convertResponse(moderationResponseEntity, moderationRequest);
76+
});
77+
}
78+
79+
private ModerationResponse convertResponse(ResponseEntity<MistralAiModerationResponse> moderationResponseEntity,
80+
MistralAiModerationRequest openAiModerationRequest) {
81+
var moderationApiResponse = moderationResponseEntity.getBody();
82+
if (moderationApiResponse == null) {
83+
logger.warn("No moderation response returned for request: {}", openAiModerationRequest);
84+
return new ModerationResponse(new Generation());
85+
}
86+
87+
List<ModerationResult> moderationResults = new ArrayList<>();
88+
if (moderationApiResponse.results() != null) {
89+
90+
for (MistralAiModerationResult result : moderationApiResponse.results()) {
91+
Categories categories = null;
92+
CategoryScores categoryScores = null;
93+
if (result.categories() != null) {
94+
categories = Categories.builder()
95+
.sexual(result.categories().sexual())
96+
.pii(result.categories().pii())
97+
.law(result.categories().law())
98+
.financial(result.categories().financial())
99+
.health(result.categories().health())
100+
.dangerousAndCriminalContent(result.categories().dangerousAndCriminalContent())
101+
.violence(result.categories().violenceAndThreats())
102+
.hate(result.categories().hateAndDiscrimination())
103+
.selfHarm(result.categories().selfHarm())
104+
.build();
105+
}
106+
if (result.categoryScores() != null) {
107+
categoryScores = CategoryScores.builder()
108+
.sexual(result.categoryScores().sexual())
109+
.pii(result.categoryScores().pii())
110+
.law(result.categoryScores().law())
111+
.financial(result.categoryScores().financial())
112+
.health(result.categoryScores().health())
113+
.dangerousAndCriminalContent(result.categoryScores().dangerousAndCriminalContent())
114+
.violence(result.categoryScores().violenceAndThreats())
115+
.hate(result.categoryScores().hateAndDiscrimination())
116+
.selfHarm(result.categoryScores().selfHarm())
117+
.build();
118+
}
119+
var moderationResult = ModerationResult.builder()
120+
.categories(categories)
121+
.categoryScores(categoryScores)
122+
.flagged(result.flagged())
123+
.build();
124+
moderationResults.add(moderationResult);
125+
}
126+
127+
}
128+
129+
var moderation = Moderation.builder()
130+
.id(moderationApiResponse.id())
131+
.model(moderationApiResponse.model())
132+
.results(moderationResults)
133+
.build();
134+
135+
return new ModerationResponse(new Generation(moderation));
136+
}
137+
138+
private MistralAiModerationOptions toMistralAiModerationOptions(ModerationOptions runtimeModerationOptions) {
139+
var mistralAiModerationOptionsBuilder = MistralAiModerationOptions.builder();
140+
if (runtimeModerationOptions != null && runtimeModerationOptions.getModel() != null) {
141+
mistralAiModerationOptionsBuilder.model(runtimeModerationOptions.getModel());
142+
}
143+
return mistralAiModerationOptionsBuilder.build();
144+
}
145+
146+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package org.springframework.ai.mistralai.moderation;
2+
3+
import com.fasterxml.jackson.annotation.JsonInclude;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import org.springframework.ai.mistralai.api.MistralAiModerationApi;
6+
import org.springframework.ai.moderation.ModerationOptions;
7+
8+
/**
9+
* @author Ricken Bazolo
10+
*/
11+
@JsonInclude(JsonInclude.Include.NON_NULL)
12+
public class MistralAiModerationOptions implements ModerationOptions {
13+
14+
private static final String DEFAULT_MODEL = MistralAiModerationApi.Model.MISTRAL_MODERATION.getValue();
15+
16+
/**
17+
* The model to use for moderation generation.
18+
*/
19+
@JsonProperty("model")
20+
private String model = DEFAULT_MODEL;
21+
22+
public static Builder builder() {
23+
return new Builder();
24+
}
25+
26+
@Override
27+
public String getModel() {
28+
return this.model;
29+
}
30+
31+
public void setModel(String model) {
32+
this.model = model;
33+
}
34+
35+
public static final class Builder {
36+
37+
private final MistralAiModerationOptions options;
38+
39+
private Builder() {
40+
this.options = new MistralAiModerationOptions();
41+
}
42+
43+
public Builder model(String model) {
44+
this.options.setModel(model);
45+
return this;
46+
}
47+
48+
public MistralAiModerationOptions build() {
49+
return this.options;
50+
}
51+
52+
}
53+
54+
}

0 commit comments

Comments
 (0)