Skip to content

Commit 1db04c3

Browse files
committed
Ollama: Added missing fields in API
Signed-off-by: Thomas Vitale <[email protected]>
1 parent ae6a019 commit 1db04c3

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ OllamaApi.EmbeddingRequest ollamaEmbeddingRequest(String inputContent, Embedding
135135
throw new IllegalArgumentException("Model is not set!");
136136
}
137137
String model = mergedOptions.getModel();
138-
return new EmbeddingRequest(model, inputContent, OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()));
138+
return new EmbeddingRequest(model, inputContent, null,
139+
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()));
139140
}
140141

141142
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ public record GenerateRequest(
150150
@JsonProperty("template") String template,
151151
@JsonProperty("context") List<Integer> context,
152152
@JsonProperty("stream") Boolean stream,
153-
@JsonProperty("raw") Boolean raw) {
153+
@JsonProperty("raw") Boolean raw,
154+
@JsonProperty("images") List<String> images,
155+
@JsonProperty("keep_alive") Duration keepAlive) {
154156

155157
/**
156158
* Short cut constructor to create a CompletionRequest without options.
@@ -159,7 +161,7 @@ public record GenerateRequest(
159161
* @param stream Whether to stream the response.
160162
*/
161163
public GenerateRequest(String model, String prompt, Boolean stream) {
162-
this(model, prompt, null, null, null, null, null, stream, null);
164+
this(model, prompt, null, null, null, null, null, stream, null, null, null);
163165
}
164166

165167
/**
@@ -170,7 +172,7 @@ public GenerateRequest(String model, String prompt, Boolean stream) {
170172
* @param stream Whether to stream the response.
171173
*/
172174
public GenerateRequest(String model, String prompt, boolean enableJsonFormat, Boolean stream) {
173-
this(model, prompt, (enableJsonFormat) ? "json" : null, null, null, null, null, stream, null);
175+
this(model, prompt, (enableJsonFormat) ? "json" : null, null, null, null, null, stream, null, null, null);
174176
}
175177

176178
/**
@@ -192,6 +194,8 @@ public static class Builder {
192194
private List<Integer> context;
193195
private Boolean stream;
194196
private Boolean raw;
197+
private List<String> images;
198+
private Duration keepAlive;
195199

196200
public Builder(String prompt) {
197201
this.prompt = prompt;
@@ -242,8 +246,18 @@ public Builder withRaw(Boolean raw) {
242246
return this;
243247
}
244248

249+
public Builder withImages(List<String> images) {
250+
this.images = images;
251+
return this;
252+
}
253+
254+
public Builder withKeepAlive(Duration keepAlive) {
255+
this.keepAlive = keepAlive;
256+
return this;
257+
}
258+
245259
public GenerateRequest build() {
246-
return new GenerateRequest(model, prompt, format, options, system, template, context, stream, raw);
260+
return new GenerateRequest(model, prompt, format, options, system, template, context, stream, raw, images, keepAlive);
247261
}
248262

249263
}
@@ -411,6 +425,7 @@ public record ChatRequest(
411425
@JsonProperty("messages") List<Message> messages,
412426
@JsonProperty("stream") Boolean stream,
413427
@JsonProperty("format") String format,
428+
@JsonProperty("keep_alive") Duration keepAlive,
414429
@JsonProperty("options") Map<String, Object> options) {
415430

416431
public static Builder builder(String model) {
@@ -423,6 +438,7 @@ public static class Builder {
423438
private List<Message> messages = List.of();
424439
private boolean stream = false;
425440
private String format;
441+
private Duration keepAlive;
426442
private Map<String, Object> options = Map.of();
427443

428444
public Builder(String model) {
@@ -445,21 +461,26 @@ public Builder withFormat(String format) {
445461
return this;
446462
}
447463

464+
public Builder withKeepAlive(Duration keepAlive) {
465+
this.keepAlive = keepAlive;
466+
return this;
467+
}
468+
448469
public Builder withOptions(Map<String, Object> options) {
449-
Objects.requireNonNullElse(options, "The options can not be null.");
470+
Objects.requireNonNull(options, "The options can not be null.");
450471

451472
this.options = OllamaOptions.filterNonSupportedFields(options);
452473
return this;
453474
}
454475

455476
public Builder withOptions(OllamaOptions options) {
456-
Objects.requireNonNullElse(options, "The options can not be null.");
477+
Objects.requireNonNull(options, "The options can not be null.");
457478
this.options = OllamaOptions.filterNonSupportedFields(options.toMap());
458479
return this;
459480
}
460481

461482
public ChatRequest build() {
462-
return new ChatRequest(model, messages, stream, format, options);
483+
return new ChatRequest(model, messages, stream, format, keepAlive, options);
463484
}
464485
}
465486
}
@@ -558,6 +579,7 @@ public Flux<ChatResponse> streamingChat(ChatRequest chatRequest) {
558579
public record EmbeddingRequest(
559580
@JsonProperty("model") String model,
560581
@JsonProperty("prompt") String prompt,
582+
@JsonProperty("keep_alive") Duration keepAlive,
561583
@JsonProperty("options") Map<String, Object> options) {
562584

563585
/**
@@ -566,7 +588,7 @@ public record EmbeddingRequest(
566588
* @param prompt The text to generate embeddings for.
567589
*/
568590
public EmbeddingRequest(String model, String prompt) {
569-
this(model, prompt, null);
591+
this(model, prompt, null, null);
570592
}
571593
}
572594

0 commit comments

Comments
 (0)