diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
index 7f3b56c7e53..76dc6aa408d 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java
@@ -246,6 +246,17 @@ public Flux<ProgressResponse> pullModel(PullModelRequest pullModelRequest) {
 			.bodyToFlux(ProgressResponse.class);
 	}
 
+	public Flux<ProgressResponse> createModel(CreateModelRequest createModelRequest) {
+		Assert.notNull(createModelRequest, "createModelRequest must not be null");
+		Assert.isTrue(createModelRequest.stream(), "Request must set the stream property to true.");
+
+		return this.webClient.post()
+			.uri("/api/create")
+			.bodyValue(createModelRequest)
+			.retrieve()
+			.bodyToFlux(ProgressResponse.class);
+	}
+
 	/**
 	 * Chat message object.
 	 *
@@ -741,6 +752,32 @@ public PullModelRequest(String model) {
 		}
 	}
 
+	@JsonInclude(Include.NON_NULL)
+	public record CreateModelRequest(
+			@JsonProperty("model") String model,
+			@JsonProperty("from") String from,
+			@JsonProperty("files") Map<String, String> files,
+			@JsonProperty("adapters") Map<String, String> adapters,
+			@JsonProperty("template") String template,
+			@JsonProperty("license") List<String> license,
+			@JsonProperty("system") String system,
+			@JsonProperty("parameters") Map<String, Object> parameters,
+			@JsonProperty("messages") List<Message> messages,
+			@JsonProperty("stream") boolean stream,
+			@JsonProperty("quantize") String quantize
+	) {
+		public CreateModelRequest {
+			if (!stream) {
+				logger.warn("Enforcing streaming of the model creation request");
+			}
+			stream = true;
+		}
+
+		public CreateModelRequest(String model, String from) {
+			this(model, from, null, null, null, null, null, null, null, true, null);
+		}
+	}
+
 	@JsonInclude(Include.NON_NULL)
 	@JsonIgnoreProperties(ignoreUnknown = true)
 	public record ProgressResponse(
diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java
index e939e69d425..5a23419469b 100644
--- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java
+++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/management/OllamaModelManager.java
@@ -26,6 +26,7 @@
 import org.springframework.ai.ollama.api.OllamaApi.DeleteModelRequest;
 import org.springframework.ai.ollama.api.OllamaApi.ListModelResponse;
 import org.springframework.ai.ollama.api.OllamaApi.PullModelRequest;
+import org.springframework.ai.ollama.api.OllamaApi.CreateModelRequest;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
@@ -127,4 +128,26 @@ public void pullModel(String modelName, PullModelStrategy pullModelStrategy) {
 		// @formatter:on
 	}
 
+	public void createModel(String newModelName, String originalModelName) {
+		// @formatter:off
+
+		logger.info("Start creating model {} from {}", newModelName, originalModelName);
+		this.ollamaApi.createModel(new CreateModelRequest(newModelName, originalModelName))
+			.bufferUntilChanged(OllamaApi.ProgressResponse::status)
+			.doOnEach(signal -> {
+				var progressResponses = signal.get();
+				if (!CollectionUtils.isEmpty(progressResponses) && progressResponses.get(progressResponses.size() - 1) != null) {
+					logger.info("Creating the '{}' model - Status: {}", newModelName, progressResponses.get(progressResponses.size() - 1).status());
+				}
+			})
+			.takeUntil(progressResponses ->
+					progressResponses.get(0) != null && "success".equals(progressResponses.get(0).status()))
+			.timeout(this.options.timeout())
+			.retryWhen(Retry.backoff(this.options.maxRetries(), Duration.ofSeconds(5)))
+			.blockLast();
+		logger.info("Completed creating model {} from {}", newModelName, originalModelName);
+
+		// @formatter:on
+	}
+
 }