Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
import org.springframework.ai.ollama.api.OllamaModelPuller;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -92,6 +93,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

private final OllamaModelPuller modelPuller;

public OllamaChatModel(OllamaApi ollamaApi) {
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
}
Expand Down Expand Up @@ -120,10 +123,12 @@ public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
this.chatApi = chatApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
this.modelPuller = new OllamaModelPuller(chatApi);
}

@Override
public ChatResponse call(Prompt prompt) {

OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Expand Down Expand Up @@ -319,6 +324,11 @@ else if (message instanceof ToolResponseMessage toolMessage) {
}
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
}

// Override the model.
if (!StringUtils.hasText(mergedOptions.getModel())) {
throw new IllegalArgumentException("Model is not set!");
Expand All @@ -343,6 +353,10 @@ else if (message instanceof ToolResponseMessage toolMessage) {
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
}

if (mergedOptions.isPullMissingModel()) {
this.modelPuller.pullModel(mergedOptions.getModel(), true);
}

return requestBuilder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
import org.springframework.ai.ollama.api.OllamaModelPuller;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -75,6 +76,8 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
*/
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

private final OllamaModelPuller modelPuller;

public OllamaEmbeddingModel(OllamaApi ollamaApi) {
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
}
Expand All @@ -92,6 +95,7 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
this.ollamaApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
this.modelPuller = new OllamaModelPuller(ollamaApi);
}

@Override
Expand Down Expand Up @@ -149,12 +153,21 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em

OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
}

// Override the model.
if (!StringUtils.hasText(mergedOptions.getModel())) {
throw new IllegalArgumentException("Model is not set!");
}
String model = mergedOptions.getModel();

if (mergedOptions.isPullMissingModel()) {
this.modelPuller.pullModel(model, true);
}

return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright 2024 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.ollama.api;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
import org.springframework.ai.ollama.api.OllamaApi.ListModelResponse;
import org.springframework.ai.ollama.api.OllamaApi.PullModelRequest;
import org.springframework.http.HttpStatus;
import org.springframework.util.CollectionUtils;

/**
* Helper class that allow to check if a model is available locally and pull it if not.
*
* @author Christian Tzolov
* @since 1.0.0
*/
public class OllamaModelPuller {

private final Logger logger = LoggerFactory.getLogger(OllamaModelPuller.class);

private OllamaApi ollamaApi;

public OllamaModelPuller(OllamaApi ollamaApi) {
this.ollamaApi = ollamaApi;
}

public boolean isModelAvailable(String modelName) {
ListModelResponse modelsResponse = ollamaApi.listModels();
if (!CollectionUtils.isEmpty(modelsResponse.models())) {
return modelsResponse.models().stream().anyMatch(m -> m.name().equals(modelName));
}
return false;
}

public boolean deleteModel(String modelName) {
logger.info("Delete model: {}", modelName);
if (!isModelAvailable(modelName)) {
logger.info("Model: {} not found!", modelName);
return false;
}
return this.ollamaApi.deleteModel(new DeleteModelRequest(modelName)).getStatusCode().equals(HttpStatus.OK);
}

public String pullModel(String modelName, boolean reTry) {
String status = "";
do {
logger.info("Start Pulling model: {}", modelName);
var progress = this.ollamaApi.pullModel(new PullModelRequest(modelName));
status = progress.status();
logger.info("Pulling model: {} - Status: {}", modelName, status);
try {
Thread.sleep(5000);
}
catch (InterruptedException e) {
e.printStackTrace();
}
}
while (reTry && !status.equals("success"));
return status;
}

public static void main(String[] args) {

var utils = new OllamaModelPuller(new OllamaApi());

System.out.println(utils.isModelAvailable("orca-mini:latest"));

String model = "hf.co/bartowski/Llama-3.2-3B-Instruct-GGUF:Q8_0";

if (!utils.isModelAvailable(model)) {
utils.pullModel(model, true);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
@JsonIgnore
private Map<String, Object> toolContext;

@JsonIgnore
private boolean pullMissingModel;

public static OllamaOptions builder() {
return new OllamaOptions();
}
Expand Down Expand Up @@ -516,6 +519,11 @@ public OllamaOptions withToolContext(Map<String, Object> toolContext) {
return this;
}

public OllamaOptions withPullMissingModel(boolean pullMissingModel) {
this.pullMissingModel = pullMissingModel;
return this;
}

// -------------------
// Getters and Setters
// -------------------
Expand Down Expand Up @@ -856,6 +864,14 @@ public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}

public Boolean isPullMissingModel() {
return this.pullMissingModel;
}

public void setPullMissingModel(boolean pullMissingModel) {
this.pullMissingModel = pullMissingModel;
}

/**
* Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs.
* @return The {@link Map} of key/value pairs.
Expand Down Expand Up @@ -926,7 +942,8 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
.withFunctions(fromOptions.getFunctions())
.withProxyToolCalls(fromOptions.getProxyToolCalls())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withToolContext(fromOptions.getToolContext());
.withToolContext(fromOptions.getToolContext())
.withPullMissingModel(fromOptions.isPullMissingModel());
}
// @formatter:on

Expand Down Expand Up @@ -956,7 +973,8 @@ public boolean equals(Object o) {
&& Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop)
&& Objects.equals(functionCallbacks, that.functionCallbacks)
&& Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions)
&& Objects.equals(toolContext, that.toolContext);
&& Objects.equals(toolContext, that.toolContext)
&& Objects.equals(pullMissingModel, that.pullMissingModel);
}

@Override
Expand All @@ -967,7 +985,7 @@ public int hashCode() {
this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls,
this.toolContext);
this.toolContext, this.pullMissingModel);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class BaseOllamaIT {
private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);

// Toggle for running tests locally on native Ollama for a faster feedback loop.
private static final boolean useTestcontainers = true;
private static final boolean useTestcontainers = false;

public static final OllamaContainer ollamaContainer;

Expand All @@ -30,7 +30,7 @@ public class BaseOllamaIT {
* to the file ".testcontainers.properties" located in your home directory
*/
public static boolean isDisabled() {
return true;
return false;
}

public static OllamaApi buildOllamaApiWithModel(String model) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledIf;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
Expand All @@ -32,6 +33,7 @@
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaModelPuller;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
Expand All @@ -56,6 +58,26 @@ class OllamaChatModelIT extends BaseOllamaIT {
@Autowired
private OllamaChatModel chatModel;

@Autowired
private OllamaApi ollamaApi;

@Test
void autoPullModelTest() {
var puller = new OllamaModelPuller(ollamaApi);
puller.deleteModel("tinyllama");

assertThat(puller.isModelAvailable("tinyllama")).isFalse();

String joke = ChatClient.create(chatModel)
.prompt("Tell me a joke")
.options(OllamaOptions.builder().withModel("tinyllama").withPullMissingModel(true).build())
.call()
.content();

assertThat(joke).isNotEmpty();
assertThat(puller.isModelAvailable("tinyllamaf")).isFalse();
}

@Test
void roleTest() {
Message systemMessage = new SystemPromptTemplate("""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaModelPuller;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
Expand All @@ -42,6 +44,9 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT {
@Autowired
private OllamaEmbeddingModel embeddingModel;

@Autowired
private OllamaApi ollamaApi;

@Test
void embeddings() {
assertThat(embeddingModel).isNotNull();
Expand All @@ -59,6 +64,37 @@ void embeddings() {
assertThat(embeddingModel.dimensions()).isEqualTo(768);
}

@Test
void autoPullModel() {
assertThat(embeddingModel).isNotNull();

var puller = new OllamaModelPuller(ollamaApi);
puller.deleteModel("all-minilm:latest");

assertThat(puller.isModelAvailable("all-minilm")).isFalse();

EmbeddingResponse embeddingResponse = embeddingModel
.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
OllamaOptions.builder()
.withModel("all-minilm:latest")
.withPullMissingModel(true)
.withTruncate(false)
.build()));

assertThat(puller.isModelAvailable("all-minilm:latest")).isTrue();

assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("all-minilm:latest");
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);

assertThat(embeddingModel.dimensions()).isEqualTo(768);
}

@SpringBootConfiguration
public static class TestConfiguration {

Expand Down
Loading
Loading