Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

Expand All @@ -38,6 +40,7 @@
import org.springframework.ai.retry.RetryUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.*;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No wildcard imports.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it was before. Fixed.


/**
Expand Down Expand Up @@ -171,4 +174,153 @@ void buildChatResponseMetadataAggregationWithNonEmptyMetadataButEmptyEval() {

}

@Test
void buildOllamaChatModelWithNullOllamaApi() {
	// The builder must reject a missing OllamaApi with a descriptive message.
	var builder = OllamaChatModel.builder().ollamaApi(null);

	assertThatThrownBy(builder::build).isInstanceOf(IllegalArgumentException.class)
		.hasMessageContaining("ollamaApi must not be null");
}

@Test
void buildOllamaChatModelWithAllBuilderOptions() {
	// Exercise every builder knob at once to verify the fully-configured model still builds.
	var defaultOptions = OllamaOptions.builder().model(OllamaModel.CODELLAMA).temperature(0.7).topK(50).build();
	var callingManager = ToolCallingManager.builder().build();
	var modelManagement = ModelManagementOptions.builder().build();

	ChatModel chatModel = OllamaChatModel.builder()
		.ollamaApi(this.ollamaApi)
		.defaultOptions(defaultOptions)
		.toolCallingManager(callingManager)
		.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
		.observationRegistry(ObservationRegistry.NOOP)
		.modelManagementOptions(modelManagement)
		.build();

	assertThat(chatModel).isNotNull().isInstanceOf(OllamaChatModel.class);
}

@Test
void buildChatResponseMetadataWithLargeValues() {
	// Boundary test: maximum counter/duration values must survive the metadata
	// mapping without overflow or truncation.
	Long evalDuration = Long.MAX_VALUE;
	Integer evalCount = Integer.MAX_VALUE;
	Integer promptEvalCount = Integer.MAX_VALUE;
	Long promptEvalDuration = Long.MAX_VALUE;

	OllamaApi.ChatResponse response = new OllamaApi.ChatResponse("model", Instant.now(), null, null, null,
			Long.MAX_VALUE, Long.MAX_VALUE, promptEvalCount, promptEvalDuration, evalCount, evalDuration);

	ChatResponseMetadata metadata = OllamaChatModel.from(response, null);

	// Use AssertJ consistently instead of the JUnit wildcard-imported assertEquals
	// (the wildcard import was flagged in review).
	Duration actualEvalDuration = metadata.get("eval-duration");
	Integer actualEvalCount = metadata.get("eval-count");
	Duration actualPromptEvalDuration = metadata.get("prompt-eval-duration");
	Integer actualPromptEvalCount = metadata.get("prompt-eval-count");

	assertThat(actualEvalDuration).isEqualTo(Duration.ofNanos(evalDuration));
	assertThat(actualEvalCount).isEqualTo(evalCount);
	assertThat(actualPromptEvalDuration).isEqualTo(Duration.ofNanos(promptEvalDuration));
	assertThat(actualPromptEvalCount).isEqualTo(promptEvalCount);
}

@Test
void buildChatResponseMetadataAggregationWithNullPrevious() {
	// With no previous response to aggregate, the metadata must reflect exactly
	// the counts/durations of this single response.
	Long evalDuration = 1000L;
	Integer evalCount = 101;
	Integer promptEvalCount = 808;
	Long promptEvalDuration = 8L;

	OllamaApi.ChatResponse response = new OllamaApi.ChatResponse("model", Instant.now(), null, null, null, 2000L,
			100L, promptEvalCount, promptEvalDuration, evalCount, evalDuration);

	ChatResponseMetadata metadata = OllamaChatModel.from(response, null);

	// Use AssertJ consistently instead of the JUnit wildcard-imported assertEquals
	// (the wildcard import was flagged in review).
	Duration actualEvalDuration = metadata.get("eval-duration");
	Integer actualEvalCount = metadata.get("eval-count");
	Duration actualPromptEvalDuration = metadata.get("prompt-eval-duration");
	Integer actualPromptEvalCount = metadata.get("prompt-eval-count");

	assertThat(metadata.getUsage()).isEqualTo(new DefaultUsage(promptEvalCount, evalCount));
	assertThat(actualEvalDuration).isEqualTo(Duration.ofNanos(evalDuration));
	assertThat(actualEvalCount).isEqualTo(evalCount);
	assertThat(actualPromptEvalDuration).isEqualTo(Duration.ofNanos(promptEvalDuration));
	assertThat(actualPromptEvalCount).isEqualTo(promptEvalCount);
}

@ParameterizedTest
@ValueSource(strings = { "LLAMA2", "MISTRAL", "CODELLAMA", "LLAMA3", "GEMMA" })
void buildOllamaChatModelWithDifferentModels(String modelName) {
	// Every listed enum constant must be usable as the default model.
	var defaultOptions = OllamaOptions.builder().model(OllamaModel.valueOf(modelName)).build();

	ChatModel chatModel = OllamaChatModel.builder()
		.ollamaApi(this.ollamaApi)
		.defaultOptions(defaultOptions)
		.build();

	assertThat(chatModel).isNotNull().isInstanceOf(OllamaChatModel.class);
}

@Test
void buildOllamaChatModelWithCustomObservationRegistry() {
	// A user-supplied (non-NOOP) registry must be accepted by the builder.
	ChatModel chatModel = OllamaChatModel.builder()
		.ollamaApi(this.ollamaApi)
		.observationRegistry(ObservationRegistry.create())
		.build();

	assertThat(chatModel).isNotNull();
}

@Test
void buildChatResponseMetadataPreservesModelName() {
	// The original test only checked isNotNull and admitted in comments that it
	// verified nothing about the model name; assert it explicitly instead.
	String modelName = "custom-model-name";
	OllamaApi.ChatResponse response = new OllamaApi.ChatResponse(modelName, Instant.now(), null, null, null, 1000L,
			100L, 10, 50L, 20, 200L);

	ChatResponseMetadata metadata = OllamaChatModel.from(response, null);

	assertThat(metadata).isNotNull();
	// ChatResponseMetadata exposes the model name directly.
	assertThat(metadata.getModel()).isEqualTo(modelName);
	// Token counts from the response must also be carried over (prompt=10, generation=20).
	assertThat(metadata.getUsage()).isEqualTo(new DefaultUsage(10, 20));
}

@Test
void buildChatResponseMetadataWithInstantTime() {
	// The original test asserted only isNotNull; strengthen it by checking the
	// duration/count keys this file already verifies elsewhere, so a response
	// built with a real Instant still maps all numeric fields correctly.
	Instant createdAt = Instant.now();
	OllamaApi.ChatResponse response = new OllamaApi.ChatResponse("model", createdAt, null, null, null, 1000L, 100L,
			10, 50L, 20, 200L);

	ChatResponseMetadata metadata = OllamaChatModel.from(response, null);

	assertThat(metadata).isNotNull();
	Duration actualEvalDuration = metadata.get("eval-duration");
	Duration actualPromptEvalDuration = metadata.get("prompt-eval-duration");
	Integer actualEvalCount = metadata.get("eval-count");
	Integer actualPromptEvalCount = metadata.get("prompt-eval-count");

	assertThat(actualEvalDuration).isEqualTo(Duration.ofNanos(200L));
	assertThat(actualPromptEvalDuration).isEqualTo(Duration.ofNanos(50L));
	assertThat(actualEvalCount).isEqualTo(20);
	assertThat(actualPromptEvalCount).isEqualTo(10);
}

@Test
void buildChatResponseMetadataAggregationOverflowHandling() {
	// Aggregating MAX_VALUE counts/durations with a non-empty previous response
	// must not throw, even if the sums would overflow.
	OllamaApi.ChatResponse response = new OllamaApi.ChatResponse("model", Instant.now(), null, null, null, 1000L,
			100L, Integer.MAX_VALUE, Long.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE);

	var previousMetadata = ChatResponseMetadata.builder()
		.usage(new DefaultUsage(1, 1))
		.keyValue("eval-duration", Duration.ofNanos(1L))
		.keyValue("prompt-eval-duration", Duration.ofNanos(1L))
		.build();
	ChatResponse previousChatResponse = ChatResponse.builder()
		.generations(List.of())
		.metadata(previousMetadata)
		.build();

	assertThat(OllamaChatModel.from(response, previousChatResponse)).isNotNull();
}

@Test
void buildOllamaChatModelImmutability() {
	// Two build() calls on identically-configured builders must yield two
	// distinct instances, not a shared/cached one.
	var sharedOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL).temperature(0.5).build();

	ChatModel first = OllamaChatModel.builder().ollamaApi(this.ollamaApi).defaultOptions(sharedOptions).build();
	ChatModel second = OllamaChatModel.builder().ollamaApi(this.ollamaApi).defaultOptions(sharedOptions).build();

	assertThat(first).isNotNull();
	assertThat(second).isNotNull();
	assertThat(first).isNotSameAs(second);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package org.springframework.ai.ollama;

import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;

import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.ollama.api.OllamaApi;
Expand All @@ -34,10 +37,15 @@
*/
public class OllamaEmbeddingRequestTests {

OllamaEmbeddingModel embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(OllamaApi.builder().build())
.defaultOptions(OllamaOptions.builder().model("DEFAULT_MODEL").mainGPU(11).useMMap(true).numGPU(1).build())
.build();
// Recreated before each test so no test can observe state left by another.
private OllamaEmbeddingModel embeddingModel;

@BeforeEach
public void setUp() {
	var defaultOptions = OllamaOptions.builder()
		.model("DEFAULT_MODEL")
		.mainGPU(11)
		.useMMap(true)
		.numGPU(1)
		.build();
	this.embeddingModel = OllamaEmbeddingModel.builder()
		.ollamaApi(OllamaApi.builder().build())
		.defaultOptions(defaultOptions)
		.build();
}

@Test
public void ollamaEmbeddingRequestDefaultOptions() {
Expand Down Expand Up @@ -82,4 +90,99 @@ public void ollamaEmbeddingRequestWithNegativeKeepAlive() {
assertThat(ollamaRequest.keepAlive()).isEqualTo(Duration.ofMinutes(-1));
}

@Test
public void ollamaEmbeddingRequestWithEmptyInput() {
	// An empty input list must produce an empty request that still carries the default model.
	var request = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(List.of(), null));

	var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(request);

	assertThat(ollamaRequest.input()).isEmpty();
	assertThat(ollamaRequest.model()).isEqualTo("DEFAULT_MODEL");
}

@Test
public void ollamaEmbeddingRequestWithMultipleInputs() {
	// All input strings must be forwarded, none dropped or duplicated.
	var texts = List.of("Hello", "World", "How are you?");

	var request = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(texts, null));
	var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(request);

	assertThat(ollamaRequest.input()).hasSize(3).containsExactly("Hello", "World", "How are you?");
}

@Test
public void ollamaEmbeddingRequestOptionsOverrideDefaults() {
	// Per-request options must win over the model's defaults for every field they set.
	var perRequestOptions = OllamaOptions.builder()
		.model("OVERRIDE_MODEL")
		.mainGPU(99)
		.useMMap(false)
		.numGPU(8)
		.build();
	var request = this.embeddingModel
		.buildEmbeddingRequest(new EmbeddingRequest(List.of("Override test"), perRequestOptions));

	var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(request);

	assertThat(ollamaRequest.model()).isEqualTo("OVERRIDE_MODEL");
	assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(8);
	assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(99);
	assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(false);
}

@Test
public void ollamaEmbeddingRequestWithDifferentKeepAliveFormats() {
	// "30s" must parse as a seconds duration.
	var secondsOptions = OllamaOptions.builder().keepAlive("30s").build();
	var secondsRequest = this.embeddingModel
		.buildEmbeddingRequest(new EmbeddingRequest(List.of("Test"), secondsOptions));
	assertThat(this.embeddingModel.ollamaEmbeddingRequest(secondsRequest).keepAlive())
		.isEqualTo(Duration.ofSeconds(30));

	// "2h" must parse as an hours duration.
	var hoursOptions = OllamaOptions.builder().keepAlive("2h").build();
	var hoursRequest = this.embeddingModel
		.buildEmbeddingRequest(new EmbeddingRequest(List.of("Test"), hoursOptions));
	assertThat(this.embeddingModel.ollamaEmbeddingRequest(hoursRequest).keepAlive())
		.isEqualTo(Duration.ofHours(2));
}

@Test
public void ollamaEmbeddingRequestWithMinimalDefaults() {
	// A model configured with only a model name must not emit GPU-related options.
	var minimalModel = OllamaEmbeddingModel.builder()
		.ollamaApi(OllamaApi.builder().build())
		.defaultOptions(OllamaOptions.builder().model("MINIMAL_MODEL").build())
		.build();
	var request = minimalModel.buildEmbeddingRequest(new EmbeddingRequest(List.of("Minimal test"), null));

	var ollamaRequest = minimalModel.ollamaEmbeddingRequest(request);

	assertThat(ollamaRequest.model()).isEqualTo("MINIMAL_MODEL");
	assertThat(ollamaRequest.input()).isEqualTo(List.of("Minimal test"));
	// Options that were never set must be absent, not defaulted.
	assertThat(ollamaRequest.options().get("num_gpu")).isNull();
	assertThat(ollamaRequest.options().get("main_gpu")).isNull();
	assertThat(ollamaRequest.options().get("use_mmap")).isNull();
}

@Test
public void ollamaEmbeddingRequestPreservesInputOrder() {
	// The request must keep inputs in the order the caller supplied them.
	var orderedTexts = List.of("First", "Second", "Third", "Fourth");

	var request = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(orderedTexts, null));
	var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(request);

	assertThat(ollamaRequest.input()).containsExactly("First", "Second", "Third", "Fourth");
}

@Test
public void ollamaEmbeddingRequestWithWhitespaceInputs() {
	// Empty and whitespace-only strings are legal inputs and must pass through unmodified.
	var whitespaceTexts = List.of("", " ", "\t\n", "normal text", " spaced ");

	var request = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(whitespaceTexts, null));
	var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(request);

	assertThat(ollamaRequest.input()).containsExactly("", " ", "\t\n", "normal text", " spaced ");
}

}