diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
index a75a274a797..8d22df6ddcc 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java
@@ -84,6 +84,7 @@
  * @author Jihoon Kim
  * @author Alexandros Pappas
  * @author Ilayaperumal Gopinathan
+ * @author Sun Yuhan
  * @since 1.0.0
  */
 public class OllamaChatModel implements ChatModel {
@@ -170,18 +171,21 @@ static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse p
 		Duration totalDuration = response.getTotalDuration();
 
 		if (previousChatResponse != null && previousChatResponse.getMetadata() != null) {
-			if (previousChatResponse.getMetadata().get(METADATA_EVAL_DURATION) != null) {
-				evalDuration = evalDuration.plus(previousChatResponse.getMetadata().get(METADATA_EVAL_DURATION));
+			Object metadataEvalDuration = previousChatResponse.getMetadata().get(METADATA_EVAL_DURATION);
+			if (metadataEvalDuration != null && evalDuration != null) {
+				evalDuration = evalDuration.plus((Duration) metadataEvalDuration);
 			}
-			if (previousChatResponse.getMetadata().get(METADATA_PROMPT_EVAL_DURATION) != null) {
-				promptEvalDuration = promptEvalDuration
-					.plus(previousChatResponse.getMetadata().get(METADATA_PROMPT_EVAL_DURATION));
+			Object metadataPromptEvalDuration = previousChatResponse.getMetadata().get(METADATA_PROMPT_EVAL_DURATION);
+			if (metadataPromptEvalDuration != null && promptEvalDuration != null) {
+				promptEvalDuration = promptEvalDuration.plus((Duration) metadataPromptEvalDuration);
 			}
-			if (previousChatResponse.getMetadata().get(METADATA_LOAD_DURATION) != null) {
-				loadDuration = loadDuration.plus(previousChatResponse.getMetadata().get(METADATA_LOAD_DURATION));
+			Object metadataLoadDuration = previousChatResponse.getMetadata().get(METADATA_LOAD_DURATION);
+			if (metadataLoadDuration != null && loadDuration != null) {
+				loadDuration = loadDuration.plus((Duration) metadataLoadDuration);
 			}
-			if (previousChatResponse.getMetadata().get(METADATA_TOTAL_DURATION) != null) {
-				totalDuration = totalDuration.plus(previousChatResponse.getMetadata().get(METADATA_TOTAL_DURATION));
+			Object metadataTotalDuration = previousChatResponse.getMetadata().get(METADATA_TOTAL_DURATION);
+			if (metadataTotalDuration != null && totalDuration != null) {
+				totalDuration = totalDuration.plus((Duration) metadataTotalDuration);
 			}
 			if (previousChatResponse.getMetadata().getUsage() != null) {
 				promptTokens += previousChatResponse.getMetadata().getUsage().getPromptTokens();
diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java
index bbeb101fc9a..fdb3c43cb68 100644
--- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java
+++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java
@@ -37,8 +37,7 @@
 import org.springframework.ai.ollama.management.ModelManagementOptions;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.*;
 
 /**
  * @author Jihoon Kim
@@ -146,4 +145,28 @@ void buildChatResponseMetadataAggregationWithNonEmptyMetadata() {
 		assertEquals(promptEvalCount + 66, (Integer) metadata.get("prompt-eval-count"));
 	}
 
+	@Test
+	void buildChatResponseMetadataAggregationWithNonEmptyMetadataButEmptyEval() {
+
+		OllamaApi.ChatResponse response = new OllamaApi.ChatResponse("model", Instant.now(), null, null, null, null,
+				null, null, null, null, null);
+
+		ChatResponse previousChatResponse = ChatResponse.builder()
+			.generations(List.of())
+			.metadata(ChatResponseMetadata.builder()
+				.usage(new DefaultUsage(66, 99))
+				.keyValue("eval-duration", Duration.ofSeconds(2))
+				.keyValue("prompt-eval-duration", Duration.ofSeconds(2))
+				.build())
+			.build();
+
+		ChatResponseMetadata metadata = OllamaChatModel.from(response, previousChatResponse);
+
+		assertNull(metadata.get("eval-duration"));
+		assertNull(metadata.get("prompt-eval-duration"));
+		assertEquals(Integer.valueOf(99), metadata.get("eval-count"));
+		assertEquals(Integer.valueOf(66), metadata.get("prompt-eval-count"));
+
+	}
+
 }
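For context, the core of the fix is guarding each Duration.plus(...) call against a null accumulator from the current response as well as a null value in the previous response's metadata. A minimal standalone sketch of that pattern follows; the accumulate helper and DurationAggregationSketch class are hypothetical names used only for illustration and are not part of this patch or of Spring AI.

import java.time.Duration;

// Illustrative sketch only, not part of the patch: demonstrates the null-safe
// accumulation pattern the diff applies to each duration field.
final class DurationAggregationSketch {

	// Adds `previous` to `current` only when both values are present; otherwise
	// returns `current` unchanged (possibly null), mirroring the
	// `metadataEvalDuration != null && evalDuration != null` guard in the diff.
	static Duration accumulate(Duration current, Object previous) {
		if (current != null && previous instanceof Duration) {
			return current.plus((Duration) previous);
		}
		return current;
	}

	public static void main(String[] args) {
		System.out.println(accumulate(Duration.ofSeconds(3), Duration.ofSeconds(2))); // PT5S
		System.out.println(accumulate(null, Duration.ofSeconds(2))); // null, no NullPointerException
	}
}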