Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
Expand Down Expand Up @@ -105,13 +104,16 @@ public float[] embed(Document document) {
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");

OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(request.getInstructions(),
request.getOptions());
// Before moving any further, build the final request EmbeddingRequest,
// merging runtime and default options.
EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);

OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(embeddingRequest);

var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.provider(OllamaApi.PROVIDER_NAME)
.requestOptions(buildRequestOptions(ollamaEmbeddingRequest))
.requestOptions(embeddingRequest.getOptions())
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
Expand Down Expand Up @@ -142,31 +144,34 @@ private DefaultUsage getDefaultUsage(OllamaApi.EmbeddingsResponse response) {
return new DefaultUsage(Optional.ofNullable(response.promptEvalCount()).orElse(0), 0);
}

/**
* Package access for testing.
*/
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, EmbeddingOptions options) {

// runtime options
EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
// Process runtime options
OllamaOptions runtimeOptions = null;
if (options != null && options instanceof OllamaOptions ollamaOptions) {
runtimeOptions = ollamaOptions;
if (embeddingRequest.getOptions() != null) {
runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class,
OllamaOptions.class);
}

OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
// Define request options by merging runtime options and default options
OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
OllamaOptions.class);

// Override the model.
if (!StringUtils.hasText(mergedOptions.getModel())) {
throw new IllegalArgumentException("Model is not set!");
// Validate request options
if (!StringUtils.hasText(requestOptions.getModel())) {
throw new IllegalArgumentException("model cannot be null or empty");
}
String model = mergedOptions.getModel();

return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
}

private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request) {
return EmbeddingOptionsBuilder.builder().withModel(request.model()).build();
/**
* Package access for testing.
*/
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(EmbeddingRequest embeddingRequest) {
OllamaOptions requestOptions = (OllamaOptions) embeddingRequest.getOptions();
return new OllamaApi.EmbeddingsRequest(requestOptions.getModel(), embeddingRequest.getInstructions(),
DurationParser.parse(requestOptions.getKeepAlive()),
OllamaOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.junit.jupiter.api.Test;

import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;

Expand All @@ -40,43 +41,45 @@ public class OllamaEmbeddingRequestTests {

@Test
public void ollamaEmbeddingRequestDefaultOptions() {

var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), null);

assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
assertThat(request.options().get("num_gpu")).isEqualTo(1);
assertThat(request.options().get("main_gpu")).isEqualTo(11);
assertThat(request.options().get("use_mmap")).isEqualTo(true);
assertThat(request.input()).isEqualTo(List.of("Hello"));
var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), null));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);

assertThat(ollamaRequest.model()).isEqualTo("DEFAULT_MODEL");
assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(1);
assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(11);
assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(true);
assertThat(ollamaRequest.input()).isEqualTo(List.of("Hello"));
}

@Test
public void ollamaEmbeddingRequestRequestOptions() {

var promptOptions = OllamaOptions.builder()//
.model("PROMPT_MODEL")//
.mainGPU(22)//
.useMMap(true)//
.numGPU(2)
.build();

var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions);
var embeddingRequest = this.embeddingModel
.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), promptOptions));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);

assertThat(request.model()).isEqualTo("PROMPT_MODEL");
assertThat(request.options().get("num_gpu")).isEqualTo(2);
assertThat(request.options().get("main_gpu")).isEqualTo(22);
assertThat(request.options().get("use_mmap")).isEqualTo(true);
assertThat(request.input()).isEqualTo(List.of("Hello"));
assertThat(ollamaRequest.model()).isEqualTo("PROMPT_MODEL");
assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(2);
assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(22);
assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(true);
assertThat(ollamaRequest.input()).isEqualTo(List.of("Hello"));
}

@Test
public void ollamaEmbeddingRequestWithNegativeKeepAlive() {

var promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").keepAlive("-1m").build();

var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions);
var embeddingRequest = this.embeddingModel
.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), promptOptions));
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);

assertThat(request.keepAlive()).isEqualTo(Duration.ofMinutes(-1));
assertThat(ollamaRequest.keepAlive()).isEqualTo(Duration.ofMinutes(-1));
}

}