diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 6660dbebfc2..c197e2877de 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -565,6 +565,7 @@ private ChatGenerationMetadata generateChoiceMetadata(ChatChoice choice) { return ChatGenerationMetadata.builder() .finishReason(String.valueOf(choice.getFinishReason())) .metadata("contentFilterResults", choice.getContentFilterResults()) + .metadata("logprobs", choice.getLogprobs()) .build(); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java index 97e95bebb70..d10bd355bf0 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatModelMetadataTests.java @@ -18,10 +18,7 @@ import java.nio.charset.StandardCharsets; -import com.azure.ai.openai.models.ContentFilterResult; -import com.azure.ai.openai.models.ContentFilterResultDetailsForPrompt; -import com.azure.ai.openai.models.ContentFilterResultsForChoice; -import com.azure.ai.openai.models.ContentFilterSeverity; +import com.azure.ai.openai.models.*; import org.junit.jupiter.api.Test; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; @@ -129,6 +126,58 @@ private void assertChoiceMetadata(Generation generation) { assertThat(chatGenerationMetadata).isNotNull(); assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("stop"); assertContentFilterResults(chatGenerationMetadata.get("contentFilterResults")); + assertLogprobs(chatGenerationMetadata.get("logprobs")); + } + + private static void assertLogprobs(ChatChoiceLogProbabilityInfo logprobsInfo) { + assertThat(logprobsInfo.getContent()).hasSize(9); + assertLogprobResult(logprobsInfo.getContent().get(0), -0.0009114635, "Hello", 72, 101, 108, 108, 111); + assertThat(logprobsInfo.getContent().get(0).getTopLogprobs()).hasSize(3); + + assertLogprobResult(logprobsInfo.getContent().get(1), -0.0000019816675, "!", 33); + assertThat(logprobsInfo.getContent().get(1).getTopLogprobs()).hasSize(3); + + assertLogprobResult(logprobsInfo.getContent().get(2), -3.1281633e-7, " How", 32, 72, 111, 119); + assertThat(logprobsInfo.getContent().get(2).getTopLogprobs()).hasSize(3); + + assertLogprobResult(logprobsInfo.getContent().get(3), -0.0000079418505, " can", 32, 99, 97, 110); + assertThat(logprobsInfo.getContent().get(3).getTopLogprobs()).hasSize(3); + + assertLogprobResult(logprobsInfo.getContent().get(4), 0, " I", 32, 73); + assertThat(logprobsInfo.getContent().get(4).getTopLogprobs()).hasSize(3); + + assertLogprobResult(logprobsInfo.getContent().get(5), -0.0010328111, " assist", 32, 97, 115, 115, 105, 115, + 116); + assertThat(logprobsInfo.getContent().get(5).getTopLogprobs()).hasSize(3); + + assertLogprobResult(logprobsInfo.getContent().get(6), 0, " you", 32, 121, 111, 117); + assertThat(logprobsInfo.getContent().get(6).getTopLogprobs()).hasSize(3); + + assertLogprobResult(logprobsInfo.getContent().get(7), 0, " today", 32, 116, 111, 100, 97, 121); + assertThat(logprobsInfo.getContent().get(7).getTopLogprobs()).hasSize(3); + + assertLogprobResult(logprobsInfo.getContent().get(8), -0.0000023392786, "?", 63); + assertThat(logprobsInfo.getContent().get(8).getTopLogprobs()).hasSize(3); + + assertLogprobInfo(logprobsInfo.getContent().get(0).getTopLogprobs().get(0), -0.0009114635, "Hello", 72, 101, + 108, 108, 111); + assertLogprobInfo(logprobsInfo.getContent().get(0).getTopLogprobs().get(1), -7.000911, "Hi", 72, 105); + assertLogprobInfo(logprobsInfo.getContent().get(0).getTopLogprobs().get(2), -19.875912, "Hey", 72, 101, 121); + + } + + private static void assertLogprobResult(ChatTokenLogProbabilityResult actual, double expectedLogprob, + String expectedToken, Integer... expectedBytes) { + assertThat(actual.getLogprob()).isEqualTo(expectedLogprob); + assertThat(actual.getBytes()).contains(expectedBytes); + assertThat(actual.getToken()).isEqualTo(expectedToken); + } + + private static void assertLogprobInfo(ChatTokenLogProbabilityInfo actual, double expectedLogprob, + String expectedToken, Integer... expectedBytes) { + assertThat(actual.getLogprob()).isEqualTo(expectedLogprob); + assertThat(actual.getBytes()).contains(expectedBytes); + assertThat(actual.getToken()).isEqualTo(expectedToken); } private void assertContentFilterResultsForPrompt(ContentFilterResultDetailsForPrompt contentFilterResultForPrompt, @@ -231,6 +280,384 @@ private String getJson() { } }, "finish_reason": "stop", + "index": 0, + "logprobs": { + "content": [ + { + "bytes": [ + 72, + 101, + 108, + 108, + 111 + ], + "logprob": -0.0009114635, + "token": "Hello", + "top_logprobs": [ + { + "bytes": [ + 72, + 101, + 108, + 108, + 111 + ], + "logprob": -0.0009114635, + "token": "Hello" + }, + { + "bytes": [ + 72, + 105 + ], + "logprob": -7.000911, + "token": "Hi" + }, + { + "bytes": [ + 72, + 101, + 121 + ], + "logprob": -19.875912, + "token": "Hey" + } + ] + }, + { + "bytes": [ + 33 + ], + "logprob": -0.0000019816675, + "token": "!", + "top_logprobs": [ + { + "bytes": [ + 33 + ], + "logprob": -0.0000019816675, + "token": "!" + }, + { + "bytes": [ + 32, + 116, + 104, + 101, + 114, + 101 + ], + "logprob": -13.187502, + "token": " there" + }, + { + "bytes": [ + 46 + ], + "logprob": -20.687502, + "token": "." + } + ] + }, + { + "bytes": [ + 32, + 72, + 111, + 119 + ], + "logprob": -3.1281633e-7, + "token": " How", + "top_logprobs": [ + { + "bytes": [ + 32, + 72, + 111, + 119 + ], + "logprob": -3.1281633e-7, + "token": " How" + }, + { + "bytes": [ + 32, + 87, + 104, + 97, + 116 + ], + "logprob": -15.125, + "token": " What" + }, + { + "bytes": [ + 32, + 104, + 111, + 119 + ], + "logprob": -20.75, + "token": " how" + } + ] + }, + { + "bytes": [ + 32, + 99, + 97, + 110 + ], + "logprob": -0.0000079418505, + "token": " can", + "top_logprobs": [ + { + "bytes": [ + 32, + 99, + 97, + 110 + ], + "logprob": -0.0000079418505, + "token": " can" + }, + { + "bytes": [ + 32, + 109, + 97, + 121 + ], + "logprob": -11.750008, + "token": " may" + }, + { + "bytes": [ + 32, + 109, + 105, + 103, + 104, + 116 + ], + "logprob": -21.250008, + "token": " might" + } + ] + }, + { + "bytes": [ + 32, + 73 + ], + "logprob": 0, + "token": " I", + "top_logprobs": [ + { + "bytes": [ + 32, + 73 + ], + "logprob": 0, + "token": " I" + }, + { + "bytes": [ + 32, + 97, + 115, + 115, + 105, + 115, + 116 + ], + "logprob": -24.75, + "token": " assist" + }, + { + "bytes": [ + 73 + ], + "logprob": -25.875, + "token": "I" + } + ] + }, + { + "bytes": [ + 32, + 97, + 115, + 115, + 105, + 115, + 116 + ], + "logprob": -0.0010328111, + "token": " assist", + "top_logprobs": [ + { + "bytes": [ + 32, + 97, + 115, + 115, + 105, + 115, + 116 + ], + "logprob": -0.0010328111, + "token": " assist" + }, + { + "bytes": [ + 32, + 104, + 101, + 108, + 112 + ], + "logprob": -6.876033, + "token": " help" + }, + { + "bytes": [ + 97, + 115, + 115, + 105, + 115, + 116 + ], + "logprob": -18.251032, + "token": "assist" + } + ] + }, + { + "bytes": [ + 32, + 121, + 111, + 117 + ], + "logprob": 0, + "token": " you", + "top_logprobs": [ + { + "bytes": [ + 32, + 121, + 111, + 117 + ], + "logprob": 0, + "token": " you" + }, + { + "bytes": [ + 32, + 118, + 111, + 99, + 195, + 170 + ], + "logprob": -26.625, + "token": " vocĂȘ" + }, + { + "bytes": [ + 121, + 111, + 117 + ], + "logprob": -26.75, + "token": "you" + } + ] + }, + { + "bytes": [ + 32, + 116, + 111, + 100, + 97, + 121 + ], + "logprob": 0, + "token": " today", + "top_logprobs": [ + { + "bytes": [ + 32, + 116, + 111, + 100, + 97, + 121 + ], + "logprob": 0, + "token": " today" + }, + { + "bytes": [ + 63 + ], + "logprob": -21.375, + "token": "?" + }, + { + "bytes": [ + 32, + 116, + 111, + 100, + 97 + ], + "logprob": -25.25, + "token": " toda" + } + ] + }, + { + "bytes": [ + 63 + ], + "logprob": -0.0000023392786, + "token": "?", + "top_logprobs": [ + { + "bytes": [ + 63 + ], + "logprob": -0.0000023392786, + "token": "?" + }, + { + "bytes": [ + 63, + 10 + ], + "logprob": -13.000002, + "token": "?\\n" + }, + { + "bytes": [ + 63, + 10, + 10 + ], + "logprob": -16.750002, + "token": "?\\n\\n" + } + ] + } + ], + "refusal": null + }, "message":{ "role": "user", "content": "No! You will actually land with a resounding thud. This is the way!"