diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index c6bd6c2676e..628b349de3f 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -113,6 +113,12 @@ public class OllamaChatModel implements ChatModel { private static final String METADATA_EVAL_DURATION = "eval-duration"; + private static final String METADATA_SOURCE = "source"; + + private static final String SOURCE_MODEL = "model"; + + private static final String SOURCE_TOOL = "tool"; + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); @@ -264,6 +270,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) { generationMetadata = ChatGenerationMetadata.builder() .finishReason(ollamaResponse.doneReason()) + .metadata(METADATA_SOURCE, SOURCE_MODEL) .build(); } @@ -340,7 +347,10 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL; if (chunk.promptEvalCount() != null && chunk.evalCount() != null) { - generationMetadata = ChatGenerationMetadata.builder().finishReason(chunk.doneReason()).build(); + generationMetadata = ChatGenerationMetadata.builder() + .finishReason(chunk.doneReason()) + .metadata(METADATA_SOURCE, SOURCE_MODEL) + .build(); } var generator = new Generation(assistantMessage, generationMetadata); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java index ae79f67350e..5af2269ecf2 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java @@ -341,6 +341,23 @@ void chatMemoryWithTools() { assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8"); } + @Test + void shouldProvideSourceMetadataInRealResponses() { + String userMessageContent = "Hello, can you respond with a simple greeting?"; + ChatResponse response = this.chatModel.call(new Prompt(userMessageContent)); + + assertThat(response).isNotNull(); + assertThat(response.getResults()).hasSize(1); + + Generation generation = response.getResult(); + assertThat(generation).isNotNull(); + assertThat(generation.getOutput().getText()).isNotEmpty(); + + // Verify that source metadata is present and set to "model" + String source = (String) generation.getMetadata().get("source"); + assertThat(source).isEqualTo("model"); + } + static class MathTools { @Tool(description = "Multiply the two numbers") diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java index 1cb17781b0e..ab7571e9198 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java @@ -171,4 +171,32 @@ void buildChatResponseMetadataAggregationWithNonEmptyMetadataButEmptyEval() { } + @Test + void shouldAddSourceMetadataForModelResponses() { + // Create a mock OllamaApi.ChatResponse with eval counts to trigger metadata + // creation + Long evalDuration = 1000L; + Integer evalCount = 101; + Integer promptEvalCount = 808; + Long promptEvalDuration = 8L; + Long loadDuration = 100L; + Long totalDuration = 2000L; + + OllamaApi.ChatResponse response = new OllamaApi.ChatResponse("model", Instant.now(), + new OllamaApi.Message(OllamaApi.Message.Role.ASSISTANT, "Test response", null, null), "stop", true, + totalDuration, loadDuration, promptEvalCount, promptEvalDuration, evalCount, evalDuration); + + ChatResponseMetadata metadata = OllamaChatModel.from(response, null); + + // Verify that basic metadata fields are present + assertEquals(Duration.ofNanos(evalDuration), metadata.get("eval-duration")); + assertEquals(evalCount, metadata.get("eval-count")); + + // Test that source metadata is NOT added by the from() method (this is handled in + // Generation creation) + // The from() method only creates response-level metadata, not generation-level + // metadata + assertNull(metadata.get("source")); + } + } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionResult.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionResult.java index 1b2df957dd1..41fae0ee594 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionResult.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionResult.java @@ -39,6 +39,10 @@ public interface ToolExecutionResult { String METADATA_TOOL_NAME = "toolName"; + String METADATA_SOURCE = "source"; + + String SOURCE_TOOL = "tool"; + /** * The history of messages exchanged during the conversation, including the tool * execution result. @@ -75,6 +79,7 @@ static List buildGenerations(ToolExecutionResult toolExecutionResult ChatGenerationMetadata.builder() .metadata(METADATA_TOOL_ID, response.id()) .metadata(METADATA_TOOL_NAME, response.name()) + .metadata(METADATA_SOURCE, SOURCE_TOOL) .finishReason(FINISH_REASON) .build()); generations.add(generation); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java index acb3bfca0c5..ba0a5b5e219 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionResultTests.java @@ -80,4 +80,29 @@ void whenMultipleToolCallsThenMultipleGenerations() { assertThat(generations.get(1).getMetadata().getFinishReason()).isEqualTo(ToolExecutionResult.FINISH_REASON); } + @Test + void shouldAddSourceMetadataForToolResponses() { + var toolExecutionResult = ToolExecutionResult.builder() + .conversationHistory(List.of(new AssistantMessage("Hello, how can I help you?"), + new UserMessage("I would like to know the weather in London"), + new AssistantMessage("Call the weather tool"), + new ToolResponseMessage(List.of(new ToolResponseMessage.ToolResponse("42", "weather", + "The weather in London is 20 degrees Celsius"))))) + .build(); + + var generations = ToolExecutionResult.buildGenerations(toolExecutionResult); + + assertThat(generations).hasSize(1); + var generation = generations.get(0); + + // Verify existing metadata fields are still present + assertThat((String) generation.getMetadata().get(ToolExecutionResult.METADATA_TOOL_NAME)).isEqualTo("weather"); + assertThat((String) generation.getMetadata().get(ToolExecutionResult.METADATA_TOOL_ID)).isEqualTo("42"); + assertThat(generation.getMetadata().getFinishReason()).isEqualTo(ToolExecutionResult.FINISH_REASON); + + // Verify new source metadata is added + assertThat((String) generation.getMetadata().get(ToolExecutionResult.METADATA_SOURCE)) + .isEqualTo(ToolExecutionResult.SOURCE_TOOL); + } + }