From 7a72a72e35249b812a7394df29ef269c38beccca Mon Sep 17 00:00:00 2001 From: jitokim Date: Fri, 8 Nov 2024 02:52:06 +0900 Subject: [PATCH 1/2] Fix duplicate keys in ChatResponseMetadata for OllamaApi.ChatResponse metadata Fix assertion in contructor of OllamaChatModel Signed-off-by: jitokim --- .../org/springframework/ai/ollama/OllamaChatModel.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index f4fcd722f15..eae56f6206c 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -74,6 +74,7 @@ * @author Christian Tzolov * @author luocongqiu * @author Thomas Vitale + * @author Jihoon Kim * @since 1.0.0 */ public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel { @@ -97,7 +98,7 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, Assert.notNull(ollamaApi, "ollamaApi must not be null"); Assert.notNull(defaultOptions, "defaultOptions must not be null"); Assert.notNull(observationRegistry, "observationRegistry must not be null"); - Assert.notNull(observationRegistry, "modelManagementOptions must not be null"); + Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null"); this.chatApi = ollamaApi; this.defaultOptions = defaultOptions; this.observationRegistry = observationRegistry; @@ -118,8 +119,8 @@ public static ChatResponseMetadata from(OllamaApi.ChatResponse response) { .withKeyValue("eval-duration", response.evalDuration()) .withKeyValue("eval-count", response.evalCount()) .withKeyValue("load-duration", response.loadDuration()) - .withKeyValue("eval-duration", response.promptEvalDuration()) - .withKeyValue("eval-count", response.promptEvalCount()) + .withKeyValue("prompt-eval-duration", response.promptEvalDuration()) + .withKeyValue("prompt-eval-count", response.promptEvalCount()) .withKeyValue("total-duration", response.totalDuration()) .withKeyValue("done", response.done()) .build(); From 954e1f20e5804fd99d0c5131b84118f6e5ce7f17 Mon Sep 17 00:00:00 2001 From: jitokim Date: Thu, 14 Nov 2024 22:25:24 +0900 Subject: [PATCH 2/2] add OllamaChatModelTests Signed-off-by: jitokim --- .../ai/ollama/OllamaChatModelTests.java | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java new file mode 100644 index 00000000000..706a9d4af1b --- /dev/null +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.ollama; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.BDDMockito.given; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.ollama.api.OllamaApi; +import org.springframework.ai.ollama.api.OllamaModel; +import org.springframework.ai.ollama.api.OllamaOptions; + +/** + * @author Jihoon Kim + * @since 1.0.0 + */ +@ExtendWith(MockitoExtension.class) +public class OllamaChatModelTests { + + @Mock + OllamaApi ollamaApi; + + @Mock + OllamaApi.ChatResponse response; + + @Test + public void buildOllamaChatModel() { + Exception exception = assertThrows(IllegalArgumentException.class, + () -> OllamaChatModel.builder() + .withOllamaApi(ollamaApi) + .withDefaultOptions(OllamaOptions.create().withModel(OllamaModel.LLAMA2)) + .withModelManagementOptions(null) + .build()); + assertEquals("modelManagementOptions must not be null", exception.getMessage()); + } + + @Test + public void buildChatResponseMetadata() { + Duration evalDuration = Duration.ofSeconds(1); + Integer evalCount = 101; + + Duration promptEvalDuration = Duration.ofSeconds(8); + Integer promptEvalCount = 808; + + given(response.evalDuration()).willReturn(evalDuration); + given(response.evalCount()).willReturn(evalCount); + given(response.promptEvalDuration()).willReturn(promptEvalDuration); + given(response.promptEvalCount()).willReturn(promptEvalCount); + + ChatResponseMetadata metadata = OllamaChatModel.from(response); + + assertEquals(evalDuration, metadata.get("eval-duration")); + assertEquals(evalCount, metadata.get("eval-count")); + assertEquals(promptEvalDuration, metadata.get("prompt-eval-duration")); + assertEquals(promptEvalCount, metadata.get("prompt-eval-count")); + } + +}