Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
Expand Down Expand Up @@ -807,5 +809,159 @@ public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) {
.body(EmbeddingResponse.class);
}

// --------------------------------------------------------------------------
// Models
// --------------------------------------------------------------------------

/**
 * A model available locally on the Ollama server, as returned by the /api/tags
 * endpoint.
 *
 * @param name the model name
 * @param model the model identifier
 * @param modifiedAt when the model was last modified
 * @param size the size of the model (presumably bytes — confirm against the Ollama API docs)
 * @param digest the model's digest
 * @param details further details about the model
 */
@JsonInclude(Include.NON_NULL)
public record Model(
@JsonProperty("name") String name,
@JsonProperty("model") String model,
@JsonProperty("modified_at") Instant modifiedAt,
@JsonProperty("size") Long size,
@JsonProperty("digest") String digest,
@JsonProperty("details") Details details
) {
/**
 * Lower-level details about a model.
 *
 * @param parentModel the base model this model derives from
 * @param format the model format
 * @param family the primary model family
 * @param families all families the model belongs to
 * @param parameterSize the parameter size reported by Ollama (e.g. a human-readable string)
 * @param quantizationLevel the quantization level of the model
 */
@JsonInclude(Include.NON_NULL)
public record Details(
@JsonProperty("parent_model") String parentModel,
@JsonProperty("format") String format,
@JsonProperty("family") String family,
@JsonProperty("families") List<String> families,
@JsonProperty("parameter_size") String parameterSize,
@JsonProperty("quantization_level") String quantizationLevel
) {}
}

/**
 * Response to a list-models request, wrapping the models available locally.
 *
 * @param models the locally available models
 */
@JsonInclude(Include.NON_NULL)
public record ListModelResponse(
@JsonProperty("models") List<Model> models
) {}

/**
 * List models that are available locally on the machine where Ollama is running.
 * @return the response wrapping the locally available models
 */
public ListModelResponse listModels() {
	var responseSpec = this.restClient.get()
		.uri("/api/tags")
		.retrieve()
		.onStatus(this.responseErrorHandler);
	return responseSpec.body(ListModelResponse.class);
}

/**
 * Request to show information about a model.
 *
 * @param model the name of the model to show
 * @param system optional system prompt override sent with the request
 * @param verbose whether to return verbose information
 * @param options additional options for the request
 */
@JsonInclude(Include.NON_NULL)
public record ShowModelRequest(
@JsonProperty("model") String model,
@JsonProperty("system") String system,
@JsonProperty("verbose") Boolean verbose,
@JsonProperty("options") Map<String, Object> options
) {
// Convenience constructor: only the model name is required; all other
// fields are omitted from the JSON payload thanks to Include.NON_NULL.
public ShowModelRequest(String model) {
this(model, null, null, null);
}
}

/**
 * Response to a show-model request, carrying the model's definition and metadata.
 *
 * @param license the model's license text
 * @param modelfile the contents of the Modelfile that defines the model
 * @param parameters the model parameters, as a string
 * @param template the prompt template used by the model
 * @param system the system prompt baked into the model
 * @param details lower-level details about the model
 * @param messages messages associated with the model
 * @param modelInfo additional model metadata, keyed by property name
 * @param projectorInfo projector metadata, keyed by property name
 * @param modifiedAt when the model was last modified
 */
@JsonInclude(Include.NON_NULL)
public record ShowModelResponse(
@JsonProperty("license") String license,
@JsonProperty("modelfile") String modelfile,
@JsonProperty("parameters") String parameters,
@JsonProperty("template") String template,
@JsonProperty("system") String system,
@JsonProperty("details") Model.Details details,
@JsonProperty("messages") List<Message> messages,
@JsonProperty("model_info") Map<String, Object> modelInfo,
@JsonProperty("projector_info") Map<String, Object> projectorInfo,
@JsonProperty("modified_at") Instant modifiedAt
) {}

/**
 * Show information about a model available locally on the machine where Ollama is
 * running.
 * @param showModelRequest the request identifying the model to show; must not be null
 * @return the model information reported by the Ollama server
 */
public ShowModelResponse showModel(ShowModelRequest showModelRequest) {
	// Fail fast with a clear message instead of an NPE during serialization.
	Assert.notNull(showModelRequest, "showModelRequest must not be null");
	return this.restClient.post()
		.uri("/api/show")
		.body(showModelRequest)
		.retrieve()
		.onStatus(this.responseErrorHandler)
		.body(ShowModelResponse.class);
}

/**
 * Request to copy an existing model under a new name.
 *
 * @param source the name of the existing model to copy from
 * @param destination the name of the model to create
 */
@JsonInclude(Include.NON_NULL)
public record CopyModelRequest(
@JsonProperty("source") String source,
@JsonProperty("destination") String destination
) {}

/**
 * Copy a model. Creates a model with another name from an existing model.
 * @param copyModelRequest the request naming the source and destination models;
 * must not be null
 * @return a bodiless entity carrying the HTTP status of the operation
 */
public ResponseEntity<Void> copyModel(CopyModelRequest copyModelRequest) {
	// Fail fast with a clear message instead of an NPE during serialization.
	Assert.notNull(copyModelRequest, "copyModelRequest must not be null");
	return this.restClient.post()
		.uri("/api/copy")
		.body(copyModelRequest)
		.retrieve()
		.onStatus(this.responseErrorHandler)
		.toBodilessEntity();
}

/**
 * Request to delete a model.
 *
 * @param model the name of the model to delete
 */
@JsonInclude(Include.NON_NULL)
public record DeleteModelRequest(
@JsonProperty("model") String model
) {}

/**
 * Delete a model and its data.
 * @param deleteModelRequest the request naming the model to delete; must not be null
 * @return a bodiless entity carrying the HTTP status of the operation
 */
public ResponseEntity<Void> deleteModel(DeleteModelRequest deleteModelRequest) {
	// Fail fast with a clear message instead of an NPE during serialization.
	Assert.notNull(deleteModelRequest, "deleteModelRequest must not be null");
	// RestClient has no delete() variant that accepts a body, hence method(DELETE).
	return this.restClient.method(HttpMethod.DELETE)
		.uri("/api/delete")
		.body(deleteModelRequest)
		.retrieve()
		.onStatus(this.responseErrorHandler)
		.toBodilessEntity();
}

/**
 * Request to pull (download) a model from the Ollama library.
 *
 * @param model the name of the model to pull
 * @param insecure whether to allow an insecure connection to the library
 *  (presumably — confirm against the Ollama API docs)
 * @param username username for an access-protected library
 * @param password password for an access-protected library
 * @param stream requested streaming mode; always coerced to false (see below)
 */
@JsonInclude(Include.NON_NULL)
public record PullModelRequest(
@JsonProperty("model") String model,
@JsonProperty("insecure") Boolean insecure,
@JsonProperty("username") String username,
@JsonProperty("password") String password,
@JsonProperty("stream") Boolean stream
) {
// Compact constructor: streaming pulls are not supported by this client yet,
// so any requested value is overridden and "stream": false is always sent
// (a non-null value is required so Include.NON_NULL does not omit the field).
public PullModelRequest {
if (stream != null && stream) {
logger.warn("Streaming when pulling models is not supported yet");
}
stream = false;
}

/**
 * Shortcut for pulling a model with all other options left at their defaults.
 * @param model the name of the model to pull
 */
public PullModelRequest(String model) {
this(model, null, null, null, null);
}
}

/**
 * Progress information for a long-running model operation such as a pull.
 *
 * @param status the current status of the operation
 * @param digest the digest of the layer being processed, when applicable
 * @param total the total amount of work (presumably bytes — confirm against the Ollama API docs)
 * @param completed the amount of work completed so far, in the same unit as total
 */
@JsonInclude(Include.NON_NULL)
public record ProgressResponse(
@JsonProperty("status") String status,
@JsonProperty("digest") String digest,
@JsonProperty("total") Long total,
@JsonProperty("completed") Long completed
) {}

/**
 * Download a model from the Ollama library. Cancelled pulls are resumed from where
 * they left off, and multiple calls will share the same download progress.
 * @param pullModelRequest the request naming the model to pull; must not be null
 * @return the final progress status (streaming is disabled by PullModelRequest,
 * so a single response is returned)
 */
public ProgressResponse pullModel(PullModelRequest pullModelRequest) {
	// Fail fast with a clear message instead of an NPE during serialization.
	Assert.notNull(pullModelRequest, "pullModelRequest must not be null");
	return this.restClient.post()
		.uri("/api/pull")
		.body(pullModelRequest)
		.retrieve()
		.onStatus(this.responseErrorHandler)
		.body(ProgressResponse.class);
}

}
// @formatter:on
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package org.springframework.ai.ollama;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.ollama.api.OllamaApi;
import org.testcontainers.ollama.OllamaContainer;

public class BaseOllamaIT {

private static final Logger logger = LoggerFactory.getLogger(BaseOllamaIT.class);

// Toggle for running tests locally on native Ollama for a faster feedback loop.
private static final boolean useTestcontainers = false;

public static final OllamaContainer ollamaContainer;

static {
Expand All @@ -13,14 +21,34 @@ public class BaseOllamaIT {

/**
* Change the return value to false in order to run multiple Ollama IT tests locally
* reusing the same container image Also add the entry
* reusing the same container image.
*
* Also, add the entry
*
* testcontainers.reuse.enable=true
*
* to the file .testcontainers.properties located in your home directory
* to the file ".testcontainers.properties" located in your home directory
*/
/**
 * Gate used to conditionally disable the Ollama integration tests.
 * @return true when the tests should be skipped, false to let them run
 */
public static boolean isDisabled() {
	// NOTE(review): the scraped diff left both the old "return true;" and the new
	// "return false;" in place (unreachable code). Keeping the new behavior.
	return false;
}

public static OllamaApi buildOllamaApiWithModel(String model) {
var baseUrl = "http://localhost:11434";
if (useTestcontainers) {
baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
baseUrl = "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434);
baseUrl = ollamaContainer.getEndpoint();

}
var ollamaApi = new OllamaApi(baseUrl);

ensureModelIsPresent(ollamaApi, model);

return ollamaApi;
}

/**
 * Pulls the given model via the Ollama API so it is available before tests run.
 * Per the pullModel contract, repeated calls resume/share download progress, so
 * calling this for an already-present model is cheap.
 * @param ollamaApi the client to pull through
 * @param model the name of the model to pull
 */
public static void ensureModelIsPresent(OllamaApi ollamaApi, String model) {
logger.info("Start pulling the '{}' model. The operation can take several minutes...", model);
ollamaApi.pullModel(new OllamaApi.PullModelRequest(model));
logger.info("Completed pulling the '{}' model", model);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package org.springframework.ai.ollama;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledIf;
Expand All @@ -40,7 +39,6 @@
import org.testcontainers.junit.jupiter.Testcontainers;
import reactor.core.publisher.Flux;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
Expand All @@ -54,25 +52,13 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {

private static final Logger logger = LoggerFactory.getLogger(OllamaChatModelFunctionCallingIT.class);

// Model used by these integration tests; pulled on demand by the
// buildOllamaApiWithModel(...) helper in BaseOllamaIT.
// NOTE(review): the scraped diff also showed the removed MISTRAL constant, the
// static baseUrl field, and the removed @BeforeAll pull logic; only the new
// constant below is the current state.
private static final String MODEL = OllamaModel.LLAMA3_2.getName();

@Autowired
ChatModel chatModel;

@Test
void functionCallTest() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");

Expand All @@ -97,7 +83,6 @@ void functionCallTest() {
@Disabled("Ollama API does not support streaming function calls yet")
@Test
void streamFunctionCallTest() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return temperatures in Celsius.");

Expand Down Expand Up @@ -132,7 +117,7 @@ static class Config {

/**
 * Provides the OllamaApi client for the test context, pulling the required model
 * before the tests run.
 */
@Bean
public OllamaApi ollamaApi() {
	// NOTE(review): the scraped diff left both the old "new OllamaApi(baseUrl)"
	// return and this one in place (unreachable code). Keeping the new behavior.
	return buildOllamaApiWithModel(MODEL);
}

@Bean
Expand Down
Loading