diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java
index c68227bb50f..c784336083b 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilter.java
@@ -22,26 +22,28 @@ import org.springframework.ai.observation.tracing.TracingHelper;

 /**
- * An {@link ObservationFilter} to include the chat completion content in the observation.
+ * {@link ObservationFilter} used to include chat completion content in the
+ * {@link Observation}.
  *
  * @author Thomas Vitale
+ * @author John Blum
  * @since 1.0.0
  */
 public class ChatModelCompletionObservationFilter implements ObservationFilter {

     @Override
     public Observation.Context map(Observation.Context context) {
-        if (!(context instanceof ChatModelObservationContext chatModelObservationContext)) {
-            return context;
-        }
-        var completions = ChatModelObservationContentProcessor.completion(chatModelObservationContext);
+        if (context instanceof ChatModelObservationContext chatModelObservationContext) {
+
+            var completions = ChatModelObservationContentProcessor.completion(chatModelObservationContext);

-        chatModelObservationContext
-            .addHighCardinalityKeyValue(ChatModelObservationDocumentation.HighCardinalityKeyNames.COMPLETION
-                .withValue(TracingHelper.concatenateStrings(completions)));
+            context = chatModelObservationContext
+                .addHighCardinalityKeyValue(ChatModelObservationDocumentation.HighCardinalityKeyNames.COMPLETION
+                    .withValue(TracingHelper.concatenateStrings(completions)));
+        }

-        return chatModelObservationContext;
+        return context;
     }

 }
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java
index 9b19d4199dc..b544f9bb805 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandler.java
@@ -16,16 +16,25 @@ package org.springframework.ai.chat.observation;

+import java.util.Optional;
+
 import io.micrometer.core.instrument.MeterRegistry;
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationHandler;

+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.model.observation.ModelUsageMetricsGenerator;
+import org.springframework.lang.Nullable;

 /**
- * Handler for generating metrics from chat model observations.
+ * {@link ObservationHandler} used to generate metrics from chat model observations.
  *
  * @author Thomas Vitale
+ * @author John Blum
+ * @see ChatModelObservationContext
+ * @see ObservationHandler
  * @since 1.0.0
  */
 public class ChatModelMeterObservationHandler implements ObservationHandler<ChatModelObservationContext> {
@@ -38,11 +47,16 @@ public ChatModelMeterObservationHandler(MeterRegistry meterRegistry) {

     @Override
     public void onStop(ChatModelObservationContext context) {
-        if (context.getResponse() != null && context.getResponse().getMetadata() != null
-                && context.getResponse().getMetadata().getUsage() != null) {
-            ModelUsageMetricsGenerator.generate(context.getResponse().getMetadata().getUsage(), context,
-                    this.meterRegistry);
-        }
+        resolveUsage(context)
+            .ifPresent(usage -> ModelUsageMetricsGenerator.generate(usage, context, this.meterRegistry));
+    }
+
+    private Optional<Usage> resolveUsage(@Nullable ChatModelObservationContext context) {
+
+        return Optional.ofNullable(context)
+            .map(ChatModelObservationContext::getResponse)
+            .map(ChatResponse::getMetadata)
+            .map(ChatResponseMetadata::getUsage);
     }

     @Override
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessor.java
index 3de4a321532..874e2dc7d5c 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessor.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessor.java
@@ -16,9 +16,15 @@ package org.springframework.ai.chat.observation;

+import java.util.Collections;
 import java.util.List;
+import java.util.Optional;

+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.model.Content;
+import org.springframework.lang.Nullable;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;

@@ -26,36 +32,29 @@
  * Utilities to process the prompt and completion content in observations for chat models.
  *
  * @author Thomas Vitale
+ * @author John Blum
+ * @since 1.0.0
  */
-public final class ChatModelObservationContentProcessor {
-
-    private ChatModelObservationContentProcessor() {
-    }
+public abstract class ChatModelObservationContentProcessor {

     public static List<String> prompt(ChatModelObservationContext context) {
-        if (CollectionUtils.isEmpty(context.getRequest().getInstructions())) {
-            return List.of();
-        }
-        return context.getRequest().getInstructions().stream().map(Content::getContent).toList();
-    }
+        List<Message> instructions = context.getRequest().getInstructions();

-    public static List<String> completion(ChatModelObservationContext context) {
-        if (context == null || context.getResponse() == null || context.getResponse().getResults() == null
-                || CollectionUtils.isEmpty(context.getResponse().getResults())) {
-            return List.of();
-        }
+        return CollectionUtils.isEmpty(instructions) ? Collections.emptyList()
+                : instructions.stream().map(Content::getContent).toList();
+    }

-        if (!StringUtils.hasText(context.getResponse().getResult().getOutput().getContent())) {
-            return List.of();
-        }
+    public static List<String> completion(@Nullable ChatModelObservationContext context) {

-        return context.getResponse()
-            .getResults()
+        return Optional.ofNullable(context)
+            .map(ChatModelObservationContext::getResponse)
+            .map(ChatResponse::getResults)
+            .orElseGet(Collections::emptyList)
             .stream()
-            .filter(generation -> generation.getOutput() != null
-                    && StringUtils.hasText(generation.getOutput().getContent()))
-            .map(generation -> generation.getOutput().getContent())
+            .map(Generation::getOutput)
+            .map(Message::getContent)
+            .filter(StringUtils::hasText)
             .toList();
     }
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilter.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilter.java
index e320c9ce8bf..c8ce89356f4 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilter.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilter.java
@@ -25,23 +25,24 @@
  * An {@link ObservationFilter} to include the chat prompt content in the observation.
  *
  * @author Thomas Vitale
+ * @author John Blum
  * @since 1.0.0
  */
 public class ChatModelPromptContentObservationFilter implements ObservationFilter {

     @Override
     public Observation.Context map(Observation.Context context) {
-        if (!(context instanceof ChatModelObservationContext chatModelObservationContext)) {
-            return context;
-        }
-        var prompts = ChatModelObservationContentProcessor.prompt(chatModelObservationContext);
+        if (context instanceof ChatModelObservationContext chatModelObservationContext) {
+
+            var prompts = ChatModelObservationContentProcessor.prompt(chatModelObservationContext);

-        chatModelObservationContext
-            .addHighCardinalityKeyValue(ChatModelObservationDocumentation.HighCardinalityKeyNames.PROMPT
-                .withValue(TracingHelper.concatenateStrings(prompts)));
+            context = chatModelObservationContext
+                .addHighCardinalityKeyValue(ChatModelObservationDocumentation.HighCardinalityKeyNames.PROMPT
+                    .withValue(TracingHelper.concatenateStrings(prompts)));
+        }

-        return chatModelObservationContext;
+        return context;
     }

 }
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java
index a9d4bf0504a..c6b55062cfe 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java
@@ -16,12 +16,19 @@ package org.springframework.ai.chat.observation;

-import java.util.Objects;
+import java.util.List;
+import java.util.Optional;
 import java.util.StringJoiner;

 import io.micrometer.common.KeyValue;
 import io.micrometer.common.KeyValues;
+import io.micrometer.common.docs.KeyName;

+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.lang.Nullable;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;

@@ -29,17 +36,59 @@
  * Default conventions to populate observations for chat model operations.
  *
  * @author Thomas Vitale
+ * @author John Blum
  * @since 1.0.0
  */
 public class DefaultChatModelObservationConvention implements ChatModelObservationConvention {

     public static final String DEFAULT_NAME = "gen_ai.client.operation";

-    private static final KeyValue REQUEST_MODEL_NONE = KeyValue
-        .of(ChatModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, KeyValue.NONE_VALUE);
+    protected static final String CONTEXTUAL_NAME_TEMPLATE = "%s %s";

-    private static final KeyValue RESPONSE_MODEL_NONE = KeyValue
-        .of(ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL, KeyValue.NONE_VALUE);
+    // @formatter:off
+    private static final ChatModelObservationDocumentation.LowCardinalityKeyNames AI_OPERATION_TYPE_KEY_NAME =
+        ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE;
+
+    private static final ChatModelObservationDocumentation.LowCardinalityKeyNames AI_PROVIDER_KEY_NAME =
+        ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames REQUEST_FREQUENCY_PENALTY_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames REQUEST_MAX_TOKENS_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_MAX_TOKENS;
+
+    private static final ChatModelObservationDocumentation.LowCardinalityKeyNames REQUEST_MODEL_KEY_NAME =
+        ChatModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames REQUEST_PRESENCE_PENALTY_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames REQUEST_STOP_SEQUENCES_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames REQUEST_TEMPERATURE_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TEMPERATURE;
+
+    private static final ChatModelObservationDocumentation.LowCardinalityKeyNames RESPONSE_MODEL_KEY_NAME =
+        ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames TOP_K_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_K;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames TOP_P_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_P;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames USAGE_TOTAL_TOKENS_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_TOTAL_TOKENS;
+
+    private static final ChatModelObservationDocumentation.HighCardinalityKeyNames USAGE_OUTPUT_TOKENS_KEY_NAME =
+        ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS;
+
+    private static final KeyValue REQUEST_MODEL_NONE = keyValueOf(REQUEST_MODEL_KEY_NAME);
+
+    private static final KeyValue RESPONSE_MODEL_NONE = keyValueOf(RESPONSE_MODEL_KEY_NAME);
+    // @formatter:on

     @Override
     public String getName() {
@@ -48,11 +97,13 @@ public String getName() {

     @Override
     public String getContextualName(ChatModelObservationContext context) {
-        if (StringUtils.hasText(context.getRequestOptions().getModel())) {
-            return "%s %s".formatted(context.getOperationMetadata().operationType(),
-                    context.getRequestOptions().getModel());
-        }
-        return context.getOperationMetadata().operationType();
+
+        return resolveRequestModelName(context).map(modelName -> getContextualName(context, modelName))
+            .orElseGet(() -> context.getOperationMetadata().operationType());
+    }
+
+    private String getContextualName(ChatModelObservationContext context, String modelName) {
+        return CONTEXTUAL_NAME_TEMPLATE.formatted(context.getOperationMetadata().operationType(), modelName);
     }

     @Override
@@ -62,35 +113,47 @@ public KeyValues getLowCardinalityKeyValues(ChatModelObservationContext context)
     }

     protected KeyValue aiOperationType(ChatModelObservationContext context) {
-        return KeyValue.of(ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_OPERATION_TYPE,
-                context.getOperationMetadata().operationType());
+        return KeyValue.of(AI_OPERATION_TYPE_KEY_NAME, context.getOperationMetadata().operationType());
     }

     protected KeyValue aiProvider(ChatModelObservationContext context) {
-        return KeyValue.of(ChatModelObservationDocumentation.LowCardinalityKeyNames.AI_PROVIDER,
-                context.getOperationMetadata().provider());
+        return KeyValue.of(AI_PROVIDER_KEY_NAME, context.getOperationMetadata().provider());
     }

     protected KeyValue requestModel(ChatModelObservationContext context) {
-        if (StringUtils.hasText(context.getRequestOptions().getModel())) {
-            return KeyValue.of(ChatModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL,
-                    context.getRequestOptions().getModel());
-        }
-        return REQUEST_MODEL_NONE;
+
+        return resolveRequestModelName(context).map(modelName -> KeyValue.of(REQUEST_MODEL_KEY_NAME, modelName))
+            .orElse(REQUEST_MODEL_NONE);
     }

     protected KeyValue responseModel(ChatModelObservationContext context) {
-        if (context.getResponse() != null && context.getResponse().getMetadata() != null
-                && StringUtils.hasText(context.getResponse().getMetadata().getModel())) {
-            return KeyValue.of(ChatModelObservationDocumentation.LowCardinalityKeyNames.RESPONSE_MODEL,
-                    context.getResponse().getMetadata().getModel());
-        }
-        return RESPONSE_MODEL_NONE;
+
+        return resolveResponseModelName(context).map(modelName -> KeyValue.of(RESPONSE_MODEL_KEY_NAME, modelName))
+            .orElse(RESPONSE_MODEL_NONE);
+    }
+
+    private Optional<String> resolveRequestModelName(@Nullable ChatModelObservationContext context) {
+
+        return Optional.ofNullable(context)
+            .map(ChatModelObservationContext::getRequestOptions)
+            .map(ChatOptions::getModel)
+            .filter(StringUtils::hasText);
+    }
+
+    private Optional<String> resolveResponseModelName(@Nullable ChatModelObservationContext context) {
+
+        return Optional.ofNullable(context)
+            .map(ChatModelObservationContext::getResponse)
+            .map(ChatResponse::getMetadata)
+            .map(ChatResponseMetadata::getModel)
+            .filter(StringUtils::hasText);
     }

     @Override
     public KeyValues getHighCardinalityKeyValues(ChatModelObservationContext context) {
+
         var keyValues = KeyValues.empty();
+
         // Request
         keyValues = requestFrequencyPenalty(keyValues, context);
         keyValues = requestMaxTokens(keyValues, context);
@@ -99,82 +162,74 @@ public KeyValues getHighCardinalityKeyValues(ChatModelObservationContext context
         keyValues = requestTemperature(keyValues, context);
         keyValues = requestTopK(keyValues, context);
         keyValues = requestTopP(keyValues, context);
+
         // Response
         keyValues = responseFinishReasons(keyValues, context);
         keyValues = responseId(keyValues, context);
         keyValues = usageInputTokens(keyValues, context);
         keyValues = usageOutputTokens(keyValues, context);
         keyValues = usageTotalTokens(keyValues, context);
+
         return keyValues;
     }

     // Request

     protected KeyValues requestFrequencyPenalty(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getRequestOptions().getFrequencyPenalty() != null) {
-            return keyValues.and(
-                    ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(),
-                    String.valueOf(context.getRequestOptions().getFrequencyPenalty()));
-        }
-        return keyValues;
+
+        Double frequencyPenalty = context.getRequestOptions().getFrequencyPenalty();
+
+        return frequencyPenalty != null
+                ? keyValues.and(keyValueOf(REQUEST_FREQUENCY_PENALTY_KEY_NAME, frequencyPenalty)) : keyValues;
     }

     protected KeyValues requestMaxTokens(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getRequestOptions().getMaxTokens() != null) {
-            return keyValues.and(
-                    ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(),
-                    String.valueOf(context.getRequestOptions().getMaxTokens()));
-        }
-        return keyValues;
+
+        Integer maxTokens = context.getRequestOptions().getMaxTokens();
+
+        return maxTokens != null ? keyValues.and(keyValueOf(REQUEST_MAX_TOKENS_KEY_NAME, maxTokens)) : keyValues;
     }

     protected KeyValues requestPresencePenalty(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getRequestOptions().getPresencePenalty() != null) {
-            return keyValues.and(
-                    ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(),
-                    String.valueOf(context.getRequestOptions().getPresencePenalty()));
-        }
-        return keyValues;
+
+        Double presencePenalty = context.getRequestOptions().getPresencePenalty();
+
+        return presencePenalty != null ? keyValues.and(keyValueOf(REQUEST_PRESENCE_PENALTY_KEY_NAME, presencePenalty))
+                : keyValues;
     }

     protected KeyValues requestStopSequences(KeyValues keyValues, ChatModelObservationContext context) {
-        if (!CollectionUtils.isEmpty(context.getRequestOptions().getStopSequences())) {
+
+        List<String> stopSequences = context.getRequestOptions().getStopSequences();
+
+        if (!CollectionUtils.isEmpty(stopSequences)) {
             StringJoiner stopSequencesJoiner = new StringJoiner(", ", "[", "]");
-            context.getRequestOptions()
-                .getStopSequences()
-                .forEach(value -> stopSequencesJoiner.add("\"" + value + "\""));
-            KeyValue.of(ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES,
-                    context.getRequestOptions().getStopSequences(), Objects::nonNull);
-            return keyValues.and(
-                    ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
-                    stopSequencesJoiner.toString());
+            stopSequences.forEach(value -> stopSequencesJoiner.add("\"" + value + "\""));
+            keyValues = keyValues.and(keyValueOf(REQUEST_STOP_SEQUENCES_KEY_NAME, stopSequencesJoiner));
         }
+
         return keyValues;
     }

     protected KeyValues requestTemperature(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getRequestOptions().getTemperature() != null) {
-            return keyValues.and(
-                    ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(),
-                    String.valueOf(context.getRequestOptions().getTemperature()));
-        }
-        return keyValues;
+
+        Double temperature = context.getRequestOptions().getTemperature();
+
+        return temperature != null ? keyValues.and(keyValueOf(REQUEST_TEMPERATURE_KEY_NAME, temperature)) : keyValues;
     }

     protected KeyValues requestTopK(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getRequestOptions().getTopK() != null) {
-            return keyValues.and(ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_K.asString(),
-                    String.valueOf(context.getRequestOptions().getTopK()));
-        }
-        return keyValues;
+
+        Integer topK = context.getRequestOptions().getTopK();
+
+        return topK != null ? keyValues.and(keyValueOf(TOP_K_KEY_NAME, topK)) : keyValues;
     }

     protected KeyValues requestTopP(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getRequestOptions().getTopP() != null) {
-            return keyValues.and(ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_P.asString(),
-                    String.valueOf(context.getRequestOptions().getTopP()));
-        }
-        return keyValues;
+
+        Double topP = context.getRequestOptions().getTopP();
+
+        return topP != null ? keyValues.and(keyValueOf(TOP_P_KEY_NAME, topP)) : keyValues;
     }

     // Response
@@ -200,45 +255,54 @@ protected KeyValues responseFinishReasons(KeyValues keyValues, ChatModelObservat
     }

     protected KeyValues responseId(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getResponse() != null && context.getResponse().getMetadata() != null
-                && StringUtils.hasText(context.getResponse().getMetadata().getId())) {
-            return keyValues.and(ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_ID.asString(),
-                    context.getResponse().getMetadata().getId());
-        }
-        return keyValues;
+
+        return resolveMetadata(context).map(ChatResponseMetadata::getId)
+            .filter(StringUtils::hasText)
+            .map(id -> keyValues.and(ChatModelObservationDocumentation.HighCardinalityKeyNames.RESPONSE_ID.asString(),
+                    id))
+            .orElse(keyValues);
     }

     protected KeyValues usageInputTokens(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getResponse() != null && context.getResponse().getMetadata() != null
-                && context.getResponse().getMetadata().getUsage() != null
-                && context.getResponse().getMetadata().getUsage().getPromptTokens() != null) {
-            return keyValues.and(
+
+        return resolveUsage(context).map(Usage::getPromptTokens)
+            .map(promptTokens -> keyValues.and(
                     ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
-                    String.valueOf(context.getResponse().getMetadata().getUsage().getPromptTokens()));
-        }
-        return keyValues;
+                    String.valueOf(promptTokens)))
+            .orElse(keyValues);
     }

     protected KeyValues usageOutputTokens(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getResponse() != null && context.getResponse().getMetadata() != null
-                && context.getResponse().getMetadata().getUsage() != null
-                && context.getResponse().getMetadata().getUsage().getGenerationTokens() != null) {
-            return keyValues.and(
-                    ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
-                    String.valueOf(context.getResponse().getMetadata().getUsage().getGenerationTokens()));
-        }
-        return keyValues;
+
+        return resolveUsage(context).map(Usage::getGenerationTokens)
+            .map(generatedTokens -> keyValues.and(keyValueOf(USAGE_OUTPUT_TOKENS_KEY_NAME, generatedTokens)))
+            .orElse(keyValues);
     }

     protected KeyValues usageTotalTokens(KeyValues keyValues, ChatModelObservationContext context) {
-        if (context.getResponse() != null && context.getResponse().getMetadata() != null
-                && context.getResponse().getMetadata().getUsage() != null
-                && context.getResponse().getMetadata().getUsage().getTotalTokens() != null) {
-            return keyValues.and(
-                    ChatModelObservationDocumentation.HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
-                    String.valueOf(context.getResponse().getMetadata().getUsage().getTotalTokens()));
-        }
-        return keyValues;
+
+        return resolveUsage(context).map(Usage::getTotalTokens)
+            .map(totalTokens -> keyValues.and(keyValueOf(USAGE_TOTAL_TOKENS_KEY_NAME, totalTokens)))
+            .orElse(keyValues);
+    }
+
+    private static KeyValue keyValueOf(KeyName keyName) {
+        return keyValueOf(keyName, KeyValue.NONE_VALUE);
+    }
+
+    private static KeyValue keyValueOf(KeyName keyName, Object value) {
+        return KeyValue.of(keyName.asString(), String.valueOf(value));
+    }
+
+    private Optional<ChatResponseMetadata> resolveMetadata(@Nullable ChatModelObservationContext context) {
+
+        return Optional.ofNullable(context)
+            .map(ChatModelObservationContext::getResponse)
+            .map(ChatResponse::getMetadata);
+    }
+
+    private Optional<Usage> resolveUsage(@Nullable ChatModelObservationContext context) {
+        return resolveMetadata(context).map(ChatResponseMetadata::getUsage);
     }

 }
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java
index ff9e0a738e1..fd51e21e954 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/model/observation/ErrorLoggingObservationHandler.java
@@ -31,6 +31,7 @@

 /**
  * @author Christian Tzolov
+ * @author John Blum
  * @since 1.0.0
  */
 @SuppressWarnings({ "rawtypes", "null" })
@@ -62,18 +63,22 @@ public ErrorLoggingObservationHandler(Tracer tracer,
     }

     @Override
+    @SuppressWarnings("all")
     public boolean supportsContext(Context context) {
-        return (context == null) ? false : this.supportedContextTypes.stream().anyMatch(clz -> clz.isInstance(context));
+        return context != null && this.supportedContextTypes.stream().anyMatch(clz -> clz.isInstance(context));
     }

     @Override
+    @SuppressWarnings("unused")
     public void onError(Context context) {
-        if (context != null) {
-            TracingContext tracingContext = context.get(TracingContext.class);
-            if (tracingContext != null) {
-                try (var val = this.tracer.withSpan(tracingContext.getSpan())) {
-                    this.errorConsumer.accept(context);
-                }
+
+        Assert.notNull(context, "Context must not be null");
+
+        TracingContext tracingContext = context.get(TracingContext.class);
+
+        if (tracingContext != null) {
+            try (var val = this.tracer.withSpan(tracingContext.getSpan())) {
+                this.errorConsumer.accept(context);
             }
         }
     }
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessorTests.java
new file mode 100644
index 00000000000..207b404c857
--- /dev/null
+++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContentProcessorTests.java
@@ -0,0 +1,145 @@
+/*
+ * Copyright 2023-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.chat.observation;
+
+import java.util.Collections;
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.prompt.ChatOptions;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.observation.conventions.AiProvider;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Unit Tests for {@link ChatModelObservationContentProcessor}.
+ *
+ * @author John Blum
+ */
+@ExtendWith(MockitoExtension.class)
+public class ChatModelObservationContentProcessorTests {
+
+    @Mock
+    private ChatOptions mockChatOptions;
+
+    @Mock
+    private Prompt mockPrompt;
+
+    @Test
+    void promptReturnsListOfMessageContent() {
+
+        Prompt prompt = spy(new Prompt(List.of(new UserMessage("user"), new SystemMessage("system"))));
+
+        ChatModelObservationContext context = ChatModelObservationContext.builder()
+            .requestOptions(this.mockChatOptions)
+            .provider(AiProvider.OPENAI.value())
+            .prompt(prompt)
+            .build();
+
+        List<String> content = ChatModelObservationContentProcessor.prompt(context);
+
+        assertThat(content).isNotNull().hasSize(2).containsExactly("user", "system");
+
+        verify(prompt, times(1)).getInstructions();
+    }
+
+    @Test
+    void promptWithNoMessagesReturnsEmptyList() {
+
+        ChatModelObservationContext context = ChatModelObservationContext.builder()
+            .requestOptions(this.mockChatOptions)
+            .provider(AiProvider.OPENAI.value())
+            .prompt(this.mockPrompt)
+            .build();
+
+        List<String> content = ChatModelObservationContentProcessor.prompt(context);
+
+        assertThat(content).isNotNull().isEmpty();
+
+        verify(this.mockPrompt, times(1)).getInstructions();
+    }
+
+    @Test
+    void completionIsNullSafe() {
+
+        List<String> completions = ChatModelObservationContentProcessor.completion(null);
+
+        assertThat(completions).isNotNull().isEmpty();
+    }
+
+    @Test
+    @SuppressWarnings("all")
+    void completionsReturnsGeneratedResponse() {
+
+        List<Generation> generations = List.of(generation(""), generation("one"), generation(" "), generation("two"),
+                generation(null));
+
+        ChatResponse response = ChatResponse.builder().withGenerations(generations).build();
+
+        ChatModelObservationContext context = ChatModelObservationContext.builder()
+            .requestOptions(this.mockChatOptions)
+            .provider(AiProvider.OPENAI.value())
+            .prompt(this.mockPrompt)
+            .build();
+
+        context.setResponse(response);
+
+        List<String> completions = ChatModelObservationContentProcessor.completion(context);
+
+        assertThat(completions).isNotNull().hasSize(2).containsExactly("one", "two");
+    }
+
+    @Test
+    void completionsReturnsNoResponse() {
+
+        ChatResponse response = ChatResponse.builder().withGenerations(Collections.emptyList()).build();
+
+        ChatModelObservationContext context = ChatModelObservationContext.builder()
+            .requestOptions(this.mockChatOptions)
+            .provider(AiProvider.OPENAI.value())
+            .prompt(this.mockPrompt)
+            .build();
+
+        context.setResponse(response);
+
+        List<String> completions = ChatModelObservationContentProcessor.completion(context);
+
+        assertThat(completions).isNotNull().isEmpty();
+    }
+
+    private AssistantMessage assistantMessage(String content) {
+        return new AssistantMessage(content);
+    }
+
+    private Generation generation(String generatedContent) {
+        return new Generation(assistantMessage(generatedContent));
+    }
+
+}
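For context only, here is a minimal wiring sketch, not part of the patch above, showing how the refactored handler and filters are typically registered with Micrometer. The ObservationWiringSketch class name, the SimpleMeterRegistry, and the locally created ObservationRegistry are assumptions made for illustration; only the three Spring AI classes come from this change set.

// Illustrative sketch only; not part of this diff. The class name and the
// registry instances below are assumptions for the example.
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import io.micrometer.observation.ObservationRegistry;

import org.springframework.ai.chat.observation.ChatModelCompletionObservationFilter;
import org.springframework.ai.chat.observation.ChatModelMeterObservationHandler;
import org.springframework.ai.chat.observation.ChatModelPromptContentObservationFilter;

public class ObservationWiringSketch {

    public static ObservationRegistry chatObservationRegistry() {

        SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry();
        ObservationRegistry observationRegistry = ObservationRegistry.create();

        observationRegistry.observationConfig()
            // emits token usage metrics when a ChatModelObservationContext stops
            .observationHandler(new ChatModelMeterObservationHandler(meterRegistry))
            // opt-in filters that copy prompt and completion content into
            // high-cardinality key values on the observation
            .observationFilter(new ChatModelPromptContentObservationFilter())
            .observationFilter(new ChatModelCompletionObservationFilter());

        return observationRegistry;
    }

}

With the null-tolerant Optional chains introduced in this patch, the handler and filters simply leave an observation untouched when the response, its metadata, or its usage is absent, so this wiring needs no additional guards.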