diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java index d7d1a0d852b..3df1766ec08 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java @@ -18,10 +18,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.springframework.ai.mistralai.moderation.MistralAiModerationModel; +import org.springframework.ai.moderation.CategoryScores; import org.springframework.ai.moderation.Moderation; import org.springframework.ai.moderation.ModerationPrompt; import org.springframework.ai.moderation.ModerationResult; @@ -32,13 +31,12 @@ /** * @author Ricken Bazolo + * @author Jonghoon Park */ @SpringBootTest(classes = MistralAiTestConfiguration.class) @EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+") public class MistralAiModerationModelIT { - private static final Logger logger = LoggerFactory.getLogger(MistralAiModerationModelIT.class); - @Autowired private MistralAiModerationModel mistralAiModerationModel; @@ -58,14 +56,23 @@ void moderationAsPositiveTest() { assertThat(moderation.getId()).isNotEmpty(); assertThat(moderation.getResults()).isNotNull(); assertThat(moderation.getResults().size()).isNotZero(); - logger.info(moderation.getResults().toString()); assertThat(moderation.getId()).isNotNull(); assertThat(moderation.getModel()).isNotNull(); ModerationResult result = moderation.getResults().get(0); assertThat(result.isFlagged()).isTrue(); - assertThat(result.getCategories().isViolence()).isTrue(); + + CategoryScores scores = result.getCategoryScores(); + assertThat(scores.getSexual()).isNotNull(); + assertThat(scores.getHate()).isNotNull(); + assertThat(scores.getViolence()).isNotNull(); + assertThat(scores.getDangerousAndCriminalContent()).isNotNull(); + assertThat(scores.getSelfHarm()).isNotNull(); + assertThat(scores.getHealth()).isNotNull(); + assertThat(scores.getFinancial()).isNotNull(); + assertThat(scores.getLaw()).isNotNull(); + assertThat(scores.getPii()).isNotNull(); } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/moderation/CategoryScores.java b/spring-ai-model/src/main/java/org/springframework/ai/moderation/CategoryScores.java index 4a9c8d1a2d2..dd458712ac5 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/moderation/CategoryScores.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/moderation/CategoryScores.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ * @author Ahmed Yousri * @author Ilayaperumal Gopinathan * @author Ricken Bazolo + * @author Jonghoon Park * @since 1.0.0 */ public final class CategoryScores { @@ -129,6 +130,26 @@ public double getViolence() { return this.violence; } + public double getDangerousAndCriminalContent() { + return dangerousAndCriminalContent; + } + + public double getHealth() { + return health; + } + + public double getFinancial() { + return financial; + } + + public double getLaw() { + return law; + } + + public double getPii() { + return pii; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -147,14 +168,18 @@ public boolean equals(Object o) { && Double.compare(that.selfHarmIntent, this.selfHarmIntent) == 0 && Double.compare(that.selfHarmInstructions, this.selfHarmInstructions) == 0 && Double.compare(that.harassmentThreatening, this.harassmentThreatening) == 0 - && Double.compare(that.violence, this.violence) == 0; + && Double.compare(that.violence, this.violence) == 0 + && Double.compare(that.dangerousAndCriminalContent, this.dangerousAndCriminalContent) == 0 + && Double.compare(that.health, this.health) == 0 && Double.compare(that.financial, this.financial) == 0 + && Double.compare(that.law, this.law) == 0 && Double.compare(that.pii, this.pii) == 0; } @Override public int hashCode() { return Objects.hash(this.sexual, this.hate, this.harassment, this.selfHarm, this.sexualMinors, this.hateThreatening, this.violenceGraphic, this.selfHarmIntent, this.selfHarmInstructions, - this.harassmentThreatening, this.violence); + this.harassmentThreatening, this.violence, this.dangerousAndCriminalContent, this.health, + this.financial, this.law, this.pii); } @Override @@ -163,7 +188,9 @@ public String toString() { + ", selfHarm=" + this.selfHarm + ", sexualMinors=" + this.sexualMinors + ", hateThreatening=" + this.hateThreatening + ", violenceGraphic=" + this.violenceGraphic + ", selfHarmIntent=" + this.selfHarmIntent + ", selfHarmInstructions=" + this.selfHarmInstructions - + ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + '}'; + + ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + + ", dangerousAndCriminalContent=" + dangerousAndCriminalContent + ", health=" + health + ", financial=" + + financial + ", law=" + law + ", pii=" + pii + '}'; } public static class Builder {