Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
import org.springframework.ai.ollama.api.OllamaModelPuller;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
Expand All @@ -59,10 +61,9 @@
/**
* {@link ChatModel} implementation for {@literal Ollama}. Ollama allows developers to run
* large language models and generate embeddings locally. It supports open-source models
* available on [Ollama AI Library](<a href="https://ollama.ai/library">...</a>). - Llama
* 2 (7B parameters, 3.8GB size) - Mistral (7B parameters, 4.1GB size) Please refer to the
* <a href="https://ollama.ai/">official Ollama website</a> for the most up-to-date
* information on available models.
* available on the <a href="https://ollama.ai/library">Ollama AI Library</a> and on
* Hugging Face. Please refer to the <a href="https://ollama.ai/">official Ollama
* website</a> for the most up-to-date information on available models.
*
* @author Christian Tzolov
* @author luocongqiu
Expand All @@ -73,57 +74,33 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode

private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

/**
* Low-level Ollama API library.
*/
private final OllamaApi chatApi;

/**
* Default options to be used for all chat requests.
*/
private final OllamaOptions defaultOptions;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

private final OllamaModelPuller modelPuller;

public OllamaChatModel(OllamaApi ollamaApi) {
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
}

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
this(ollamaApi, defaultOptions, null);
}
private final OllamaModelManager modelManager;

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
FunctionCallbackContext functionCallbackContext) {
this(ollamaApi, defaultOptions, functionCallbackContext, List.of());
}
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The argument list grew so much that I didn't want to add even more overloaded constructors. Instead, I introduced a Builder to help make this whole initialisation code more readable.

FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks) {
this(ollamaApi, defaultOptions, functionCallbackContext, toolFunctionCallbacks, ObservationRegistry.NOOP);
}

public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks,
ObservationRegistry observationRegistry) {
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
super(functionCallbackContext, defaultOptions, toolFunctionCallbacks);
Assert.notNull(chatApi, "ollamaApi must not be null");
Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");
Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
this.chatApi = chatApi;
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
this.chatApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
this.modelPuller = new OllamaModelPuller(chatApi);
this.modelManager = new OllamaModelManager(chatApi, modelManagementOptions);
initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}

/**
 * Creates a new {@link Builder} for assembling an {@link OllamaChatModel}.
 * @return a fresh builder instance
 */
public static Builder builder() {
return new Builder();
}

@Override
Expand Down Expand Up @@ -324,9 +301,9 @@ else if (message instanceof ToolResponseMessage toolMessage) {
}
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
}

// Override the model.
Expand All @@ -353,9 +330,7 @@ else if (message instanceof ToolResponseMessage toolMessage) {
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
}

if (mergedOptions.isPullMissingModel()) {
this.modelPuller.pullModel(mergedOptions.getModel(), true);
}
initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());

return requestBuilder.build();
}
Expand Down Expand Up @@ -400,6 +375,15 @@ public ChatOptions getDefaultOptions() {
return OllamaOptions.fromOptions(this.defaultOptions);
}

/**
 * Asks the model manager to pull {@code model} using the given strategy, unless the
 * strategy is {@link PullModelStrategy#NEVER}, in which case nothing is done.
 * @param model name of the model to pull
 * @param pullModelStrategy strategy forwarded to the model manager
 */
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
	boolean pullingDisabled = PullModelStrategy.NEVER.equals(pullModelStrategy);
	if (pullingDisabled) {
		return;
	}
	this.modelManager.pullModel(model, pullModelStrategy);
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
Expand All @@ -409,4 +393,58 @@ public void setObservationConvention(ChatModelObservationConvention observationC
this.observationConvention = observationConvention;
}

/**
 * Fluent builder for {@link OllamaChatModel}. Every setting except the
 * {@link OllamaApi} starts with a default value; the API client has no default and
 * must be supplied before calling {@link #build()}.
 */
public static class Builder {

	/** Ollama API client; no default. */
	private OllamaApi ollamaApi;

	/** Default chat options; initially configured with {@link OllamaOptions#DEFAULT_MODEL}. */
	private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);

	/** Function callback context; unset unless provided. */
	private FunctionCallbackContext functionCallbackContext;

	/** Tool function callbacks; initially an empty list. */
	private List<FunctionCallback> toolFunctionCallbacks = List.of();

	/** Observation registry; initially {@link ObservationRegistry#NOOP}. */
	private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

	/** Model management options; initially {@link ModelManagementOptions#defaults()}. */
	private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

	private Builder() {
	}

	/**
	 * Sets the Ollama API client.
	 * @param ollamaApi the API client to use
	 * @return this builder
	 */
	public Builder withOllamaApi(OllamaApi ollamaApi) {
		this.ollamaApi = ollamaApi;
		return this;
	}

	/**
	 * Replaces the default chat options.
	 * @param defaultOptions the options to use
	 * @return this builder
	 */
	public Builder withDefaultOptions(OllamaOptions defaultOptions) {
		this.defaultOptions = defaultOptions;
		return this;
	}

	/**
	 * Sets the function callback context.
	 * @param functionCallbackContext the context to use
	 * @return this builder
	 */
	public Builder withFunctionCallbackContext(FunctionCallbackContext functionCallbackContext) {
		this.functionCallbackContext = functionCallbackContext;
		return this;
	}

	/**
	 * Replaces the list of tool function callbacks.
	 * @param toolFunctionCallbacks the callbacks to use
	 * @return this builder
	 */
	public Builder withToolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
		this.toolFunctionCallbacks = toolFunctionCallbacks;
		return this;
	}

	/**
	 * Replaces the observation registry.
	 * @param observationRegistry the registry to use
	 * @return this builder
	 */
	public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
		this.observationRegistry = observationRegistry;
		return this;
	}

	/**
	 * Replaces the model management options.
	 * @param modelManagementOptions the options to use
	 * @return this builder
	 */
	public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
		this.modelManagementOptions = modelManagementOptions;
		return this;
	}

	/**
	 * Creates the chat model from the values collected so far.
	 * @return a new {@link OllamaChatModel}
	 */
	public OllamaChatModel build() {
		return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackContext,
				this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions);
	}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.regex.Pattern;

import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
Expand All @@ -32,24 +31,20 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
import org.springframework.ai.ollama.api.OllamaModelPuller;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* {@link EmbeddingModel} implementation for {@literal Ollama}.
*
* Ollama allows developers to run large language models and generate embeddings locally.
* It supports open-source models available on [Ollama AI
* Library](https://ollama.ai/library).
*
* Examples of models supported: - Llama 2 (7B parameters, 3.8GB size) - Mistral (7B
* parameters, 4.1GB size)
*
* Please refer to the <a href="https://ollama.ai/">official Ollama website</a> for the
* most up-to-date information on available models.
* {@link EmbeddingModel} implementation for {@literal Ollama}. Ollama allows developers
* to run large language models and generate embeddings locally. It supports open-source
* models available on the <a href="https://ollama.ai/library">Ollama AI Library</a>
* and on Hugging Face. Please refer to the <a href="https://ollama.ai/">official Ollama
* website</a> for the most up-to-date information on available models.
*
* @author Christian Tzolov
* @author Thomas Vitale
Expand All @@ -61,41 +56,31 @@ public class OllamaEmbeddingModel extends AbstractEmbeddingModel {

private final OllamaApi ollamaApi;

/**
* Default options to be used for all chat requests.
*/
private final OllamaOptions defaultOptions;

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

private final OllamaModelPuller modelPuller;
private final OllamaModelManager modelManager;

public OllamaEmbeddingModel(OllamaApi ollamaApi) {
this(ollamaApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
}

public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
this(ollamaApi, defaultOptions, ObservationRegistry.NOOP);
}
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here, I introduced a Builder

ObservationRegistry observationRegistry) {
Assert.notNull(ollamaApi, "openAiApi must not be null");
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "options must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");

this.ollamaApi = ollamaApi;
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
this.modelPuller = new OllamaModelPuller(ollamaApi);
this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);

initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}

/**
 * Creates a new {@link Builder} for assembling an {@link OllamaEmbeddingModel}.
 * @return a fresh builder instance
 */
public static Builder builder() {
return new Builder();
}

@Override
Expand Down Expand Up @@ -153,9 +138,9 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em

OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

mergedOptions.setPullMissingModel(this.defaultOptions.isPullMissingModel());
if (runtimeOptions != null && runtimeOptions.isPullMissingModel() != null) {
mergedOptions.setPullMissingModel(runtimeOptions.isPullMissingModel());
mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
}

// Override the model.
Expand All @@ -164,9 +149,7 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em
}
String model = mergedOptions.getModel();

if (mergedOptions.isPullMissingModel()) {
this.modelPuller.pullModel(model, true);
}
initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());

return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
Expand All @@ -176,6 +159,15 @@ private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request
return EmbeddingOptionsBuilder.builder().withModel(request.model()).build();
}

/**
 * Asks the model manager to pull {@code model} using the given strategy, unless the
 * strategy is {@link PullModelStrategy#NEVER}, in which case nothing is done.
 * @param model name of the model to pull
 * @param pullModelStrategy strategy forwarded to the model manager
 */
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
	boolean pullingDisabled = PullModelStrategy.NEVER.equals(pullModelStrategy);
	if (pullingDisabled) {
		return;
	}
	this.modelManager.pullModel(model, pullModelStrategy);
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
Expand Down Expand Up @@ -216,4 +208,43 @@ public static Duration parse(String input) {

}

/**
 * Fluent builder for {@link OllamaEmbeddingModel}. Every setting except the
 * {@link OllamaApi} starts with a default value; the API client has no default and
 * must be supplied before calling {@link #build()}.
 */
public static class Builder {

	/** Ollama API client; no default. */
	private OllamaApi ollamaApi;

	/** Default options; initially configured with {@link OllamaOptions#DEFAULT_MODEL}. */
	private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);

	/** Observation registry; initially {@link ObservationRegistry#NOOP}. */
	private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

	/** Model management options; initially {@link ModelManagementOptions#defaults()}. */
	private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

	private Builder() {
	}

	/**
	 * Sets the Ollama API client.
	 * @param ollamaApi the API client to use
	 * @return this builder
	 */
	public Builder withOllamaApi(OllamaApi ollamaApi) {
		this.ollamaApi = ollamaApi;
		return this;
	}

	/**
	 * Replaces the default options.
	 * @param defaultOptions the options to use
	 * @return this builder
	 */
	public Builder withDefaultOptions(OllamaOptions defaultOptions) {
		this.defaultOptions = defaultOptions;
		return this;
	}

	/**
	 * Replaces the observation registry.
	 * @param observationRegistry the registry to use
	 * @return this builder
	 */
	public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
		this.observationRegistry = observationRegistry;
		return this;
	}

	/**
	 * Replaces the model management options.
	 * @param modelManagementOptions the options to use
	 * @return this builder
	 */
	public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
		this.modelManagementOptions = modelManagementOptions;
		return this;
	}

	/**
	 * Creates the embedding model from the values collected so far.
	 * @return a new {@link OllamaEmbeddingModel}
	 */
	public OllamaEmbeddingModel build() {
		return new OllamaEmbeddingModel(this.ollamaApi, this.defaultOptions, this.observationRegistry,
				this.modelManagementOptions);
	}

}

}
Loading