@@ -37,6 +37,7 @@
import org.springframework.ai.ollama.api.OllamaOptions;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.BDDMockito.given;

/**
@@ -115,4 +116,143 @@ public void options() {

}

@Test
public void singleInputEmbedding() {
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("TEST_MODEL", List.of(new float[] { 0.1f, 0.2f, 0.3f }), 10L, 5L, 1));

var embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model("TEST_MODEL").build())
.build();

EmbeddingResponse response = embeddingModel
.call(new EmbeddingRequest(List.of("Single input text"), EmbeddingOptionsBuilder.builder().build()));

assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 0.1f, 0.2f, 0.3f });
assertThat(response.getMetadata().getModel()).isEqualTo("TEST_MODEL");

assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(List.of("Single input text"));
assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("TEST_MODEL");
}

@Test
public void embeddingWithNullOptions() {
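// A null EmbeddingOptions argument should fall back to the default options and send an empty options map to the API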
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("NULL_OPTIONS_MODEL", List.of(new float[] { 0.5f }), 5L, 2L, 1));

var embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model("NULL_OPTIONS_MODEL").build())
.build();

EmbeddingResponse response = embeddingModel.call(new EmbeddingRequest(List.of("Null options test"), null));

assertThat(response.getResults()).hasSize(1);
assertThat(response.getMetadata().getModel()).isEqualTo("NULL_OPTIONS_MODEL");

assertThat(this.embeddingsRequestCaptor.getValue().model()).isEqualTo("NULL_OPTIONS_MODEL");
assertThat(this.embeddingsRequestCaptor.getValue().options()).isEqualTo(Map.of());
}

@Test
public void embeddingWithMultipleLargeInputs() {
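// All batch inputs should be forwarded in a single request, and each input should yield its own embedding result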
List<String> largeInputs = List.of(
"This is a very long text input that might be used for document embedding scenarios",
"Another substantial piece of text content that could represent a paragraph or section",
"A third lengthy input to test batch processing capabilities of the embedding model");

given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse(
"BATCH_MODEL", List.of(new float[] { 0.1f, 0.2f, 0.3f, 0.4f },
new float[] { 0.5f, 0.6f, 0.7f, 0.8f }, new float[] { 0.9f, 1.0f, 1.1f, 1.2f }),
150L, 75L, 3));

var embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model("BATCH_MODEL").build())
.build();

EmbeddingResponse response = embeddingModel
.call(new EmbeddingRequest(largeInputs, EmbeddingOptionsBuilder.builder().build()));

assertThat(response.getResults()).hasSize(3);
assertThat(response.getResults().get(0).getOutput()).hasSize(4);
assertThat(response.getResults().get(1).getOutput()).hasSize(4);
assertThat(response.getResults().get(2).getOutput()).hasSize(4);

assertThat(this.embeddingsRequestCaptor.getValue().input()).isEqualTo(largeInputs);
}

@Test
public void embeddingWithCustomKeepAliveFormats() {
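// String keep-alive values such as "300s" and "2h" should be converted to Duration on the outgoing request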
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("KEEPALIVE_MODEL", List.of(new float[] { 1.0f }), 5L, 2L, 1));

var embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model("KEEPALIVE_MODEL").build())
.build();

// Test with seconds format
var secondsOptions = OllamaOptions.builder().model("KEEPALIVE_MODEL").keepAlive("300s").build();

embeddingModel.call(new EmbeddingRequest(List.of("Keep alive seconds"), secondsOptions));
assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofSeconds(300));

// Test with hours format
var hoursOptions = OllamaOptions.builder().model("KEEPALIVE_MODEL").keepAlive("2h").build();

embeddingModel.call(new EmbeddingRequest(List.of("Keep alive hours"), hoursOptions));
assertThat(this.embeddingsRequestCaptor.getValue().keepAlive()).isEqualTo(Duration.ofHours(2));
}

@Test
public void embeddingResponseMetadata() {
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("METADATA_MODEL", List.of(new float[] { 0.1f, 0.2f }), 100L, 50L, 25));

var embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model("METADATA_MODEL").build())
.build();

EmbeddingResponse response = embeddingModel
.call(new EmbeddingRequest(List.of("Metadata test"), EmbeddingOptionsBuilder.builder().build()));

assertThat(response.getMetadata().getModel()).isEqualTo("METADATA_MODEL");
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
}

@Test
public void embeddingWithZeroLengthVectors() {
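// An empty embedding vector from the API should still surface as a single result with an empty output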
given(this.ollamaApi.embed(this.embeddingsRequestCaptor.capture()))
.willReturn(new EmbeddingsResponse("ZERO_MODEL", List.of(new float[] {}), 0L, 0L, 1));

var embeddingModel = OllamaEmbeddingModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model("ZERO_MODEL").build())
.build();

EmbeddingResponse response = embeddingModel
.call(new EmbeddingRequest(List.of("Zero length test"), EmbeddingOptionsBuilder.builder().build()));

assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput()).isEmpty();
}

@Test
public void builderValidation() {
// Test that builder requires ollamaApi
assertThatThrownBy(() -> OllamaEmbeddingModel.builder().build()).isInstanceOf(IllegalArgumentException.class);

// Test successful builder with minimal required parameters
var model = OllamaEmbeddingModel.builder().ollamaApi(this.ollamaApi).build();

assertThat(model).isNotNull();
}

}
@@ -1,26 +1,34 @@
package org.springframework.ai.ollama;

import java.time.Instant;
import java.util.List;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.retry.NonTransientAiException;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.retry.TransientAiException;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResourceAccessException;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
@@ -75,6 +83,101 @@ void ollamaChatTransientError() {
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
}

@Test
void ollamaChatSuccessOnFirstAttempt() {
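// A call that succeeds on the first attempt should leave both retry counters at zero and hit the API exactly once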
String promptText = "Simple question";
var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Quick response").build(), null,
true, null, null, null, null, null, null);

when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class))).thenReturn(expectedChatResponse);

var result = this.chatModel.call(new Prompt(promptText));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getText()).isEqualTo("Quick response");
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0);
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(0);
verify(this.ollamaApi, times(1)).chat(isA(OllamaApi.ChatRequest.class));
}

@Test
void ollamaChatNonTransientErrorShouldNotRetry() {
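// Non-transient errors should propagate immediately without a retry, so the API is invoked only once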
String promptText = "Invalid request";

when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
.thenThrow(new NonTransientAiException("Model not found"));

assertThatThrownBy(() -> this.chatModel.call(new Prompt(promptText)))
.isInstanceOf(NonTransientAiException.class)
.hasMessage("Model not found");

assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0);
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(1);
verify(this.ollamaApi, times(1)).chat(isA(OllamaApi.ChatRequest.class));
}

@Test
void ollamaChatWithMultipleMessages() {
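// A transient failure on the first attempt should be retried, and the multi-message prompt should still complete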
List<Message> messages = List.of(new UserMessage("What is AI?"), new UserMessage("Explain machine learning"));
Prompt prompt = new Prompt(messages);

var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
.content("AI is artificial intelligence...")
.build(),
null, true, null, null, null, null, null, null);

when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
.thenThrow(new TransientAiException("Temporary overload"))
.thenReturn(expectedChatResponse);

var result = this.chatModel.call(prompt);

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getText()).isEqualTo("AI is artificial intelligence...");
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1);
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(1);
}

@Test
void ollamaChatWithCustomOptions() {
String promptText = "Custom temperature request";
OllamaOptions customOptions = OllamaOptions.builder().model(MODEL).temperature(0.1).topP(0.9).build();

var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("Deterministic response").build(),
null, true, null, null, null, null, null, null);

when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
.thenThrow(new ResourceAccessException("Connection timeout"))
.thenReturn(expectedChatResponse);

var result = this.chatModel.call(new Prompt(promptText, customOptions));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getText()).isEqualTo("Deterministic response");
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1);
}

@Test
void ollamaChatWithEmptyResponse() {
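// An empty assistant message is a valid response and should come back as empty text after a successful retry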
String promptText = "Edge case request";
var expectedChatResponse = new OllamaApi.ChatResponse("CHAT_COMPLETION_ID", Instant.now(),
OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT).content("").build(), null, true, null, null,
null, null, null, null);

when(this.ollamaApi.chat(isA(OllamaApi.ChatRequest.class)))
.thenThrow(new TransientAiException("Rate limit exceeded"))
.thenReturn(expectedChatResponse);

var result = this.chatModel.call(new Prompt(promptText));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getText()).isEmpty();
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(1);
}

private static class TestRetryListener implements RetryListener {

int onErrorRetryCount = 0;
@@ -171,6 +171,19 @@ void dynamicApiKeyRestClient() throws InterruptedException {
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
}

@Test
void testBuilderMethodsReturnNewInstances() {
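// Each fluent builder call should return a usable builder, and the final build() should produce a non-null client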
OpenAiModerationApi.Builder builder1 = OpenAiModerationApi.builder();
OpenAiModerationApi.Builder builder2 = builder1.apiKey(TEST_API_KEY);
OpenAiModerationApi.Builder builder3 = builder2.baseUrl(TEST_BASE_URL);

assertThat(builder2).isNotNull();
assertThat(builder3).isNotNull();

OpenAiModerationApi api = builder3.build();
assertThat(api).isNotNull();
}

}

}
@@ -107,4 +107,51 @@ public void defaultOptionsTools() {
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
}

@Test
public void promptOptionsOverrideDefaultOptions() {
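// Options supplied on the prompt should override the matching default options, while unspecified fields keep their defaults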
var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(),
ZhiPuAiChatOptions.builder().model("DEFAULT_MODEL").temperature(10.0).build());

var request = client.createRequest(new Prompt("Test", ZhiPuAiChatOptions.builder().temperature(90.0).build()),
false);

assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
assertThat(request.temperature()).isEqualTo(90.0);
}

@Test
public void defaultOptionsToolsWithAssertion() {
final String TOOL_FUNCTION_NAME = "CurrentWeather";

var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(),
ZhiPuAiChatOptions.builder()
.model("DEFAULT_MODEL")
.toolCallbacks(List.of(FunctionToolCallback.builder(TOOL_FUNCTION_NAME, new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
.build()))
.build());

var prompt = client.buildRequestPrompt(new Prompt("Test message content"));
var request = client.createRequest(prompt, false);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isFalse();
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
assertThat(request.tools()).hasSize(1);
assertThat(request.tools().get(0).getFunction().getName()).isEqualTo(TOOL_FUNCTION_NAME);
}

@Test
public void createRequestWithStreamingEnabled() {
var client = new ZhiPuAiChatModel(ZhiPuAiApi.builder().apiKey("TEST").build(),
ZhiPuAiChatOptions.builder().model("DEFAULT_MODEL").build());

var prompt = client.buildRequestPrompt(new Prompt("Test streaming"));
var request = client.createRequest(prompt, true);

assertThat(request.stream()).isTrue();
assertThat(request.messages()).hasSize(1);
}

}