From 74ed669677ca81ac09e8f61574648a806d7ae7be Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Wed, 19 Mar 2025 11:02:37 -0700 Subject: [PATCH] ollama: Adopt new strategy for ObservationContext Relates to gh-2518 Signed-off-by: Thomas Vitale --- .../ai/ollama/OllamaEmbeddingModel.java | 47 ++++++++++--------- .../ollama/OllamaEmbeddingRequestTests.java | 39 ++++++++------- 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java index b3b64cd1339..da0408782e6 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java @@ -31,7 +31,6 @@ import org.springframework.ai.embedding.Embedding; import org.springframework.ai.embedding.EmbeddingModel; import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; @@ -105,13 +104,16 @@ public float[] embed(Document document) { public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); - OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(request.getInstructions(), - request.getOptions()); + // Before moving any further, build the final request EmbeddingRequest, + // merging runtime and default options. + EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); + + OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(embeddingRequest); var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(request) .provider(OllamaApi.PROVIDER_NAME) - .requestOptions(buildRequestOptions(ollamaEmbeddingRequest)) + .requestOptions(embeddingRequest.getOptions()) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION @@ -142,31 +144,34 @@ private DefaultUsage getDefaultUsage(OllamaApi.EmbeddingsResponse response) { return new DefaultUsage(Optional.ofNullable(response.promptEvalCount()).orElse(0), 0); } - /** - * Package access for testing. - */ - OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List inputContent, EmbeddingOptions options) { - - // runtime options + EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { + // Process runtime options OllamaOptions runtimeOptions = null; - if (options != null && options instanceof OllamaOptions ollamaOptions) { - runtimeOptions = ollamaOptions; + if (embeddingRequest.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, + OllamaOptions.class); } - OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class); + // Define request options by merging runtime options and default options + OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + OllamaOptions.class); - // Override the model. - if (!StringUtils.hasText(mergedOptions.getModel())) { - throw new IllegalArgumentException("Model is not set!"); + // Validate request options + if (!StringUtils.hasText(requestOptions.getModel())) { + throw new IllegalArgumentException("model cannot be null or empty"); } - String model = mergedOptions.getModel(); - return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()), - OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate()); + return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); } - private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request) { - return EmbeddingOptionsBuilder.builder().withModel(request.model()).build(); + /** + * Package access for testing. + */ + OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(EmbeddingRequest embeddingRequest) { + OllamaOptions requestOptions = (OllamaOptions) embeddingRequest.getOptions(); + return new OllamaApi.EmbeddingsRequest(requestOptions.getModel(), embeddingRequest.getInstructions(), + DurationParser.parse(requestOptions.getKeepAlive()), + OllamaOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate()); } /** diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java index 9cc260a72c0..c1a52989fd5 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; @@ -40,19 +41,18 @@ public class OllamaEmbeddingRequestTests { @Test public void ollamaEmbeddingRequestDefaultOptions() { - - var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), null); - - assertThat(request.model()).isEqualTo("DEFAULT_MODEL"); - assertThat(request.options().get("num_gpu")).isEqualTo(1); - assertThat(request.options().get("main_gpu")).isEqualTo(11); - assertThat(request.options().get("use_mmap")).isEqualTo(true); - assertThat(request.input()).isEqualTo(List.of("Hello")); + var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), null)); + var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest); + + assertThat(ollamaRequest.model()).isEqualTo("DEFAULT_MODEL"); + assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(1); + assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(11); + assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(true); + assertThat(ollamaRequest.input()).isEqualTo(List.of("Hello")); } @Test public void ollamaEmbeddingRequestRequestOptions() { - var promptOptions = OllamaOptions.builder()// .model("PROMPT_MODEL")// .mainGPU(22)// @@ -60,23 +60,26 @@ public void ollamaEmbeddingRequestRequestOptions() { .numGPU(2) .build(); - var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions); + var embeddingRequest = this.embeddingModel + .buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), promptOptions)); + var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest); - assertThat(request.model()).isEqualTo("PROMPT_MODEL"); - assertThat(request.options().get("num_gpu")).isEqualTo(2); - assertThat(request.options().get("main_gpu")).isEqualTo(22); - assertThat(request.options().get("use_mmap")).isEqualTo(true); - assertThat(request.input()).isEqualTo(List.of("Hello")); + assertThat(ollamaRequest.model()).isEqualTo("PROMPT_MODEL"); + assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(2); + assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(22); + assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(true); + assertThat(ollamaRequest.input()).isEqualTo(List.of("Hello")); } @Test public void ollamaEmbeddingRequestWithNegativeKeepAlive() { - var promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").keepAlive("-1m").build(); - var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions); + var embeddingRequest = this.embeddingModel + .buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), promptOptions)); + var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest); - assertThat(request.keepAlive()).isEqualTo(Duration.ofMinutes(-1)); + assertThat(ollamaRequest.keepAlive()).isEqualTo(Duration.ofMinutes(-1)); } }