diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatClient.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatClient.java index 6ec0c5072a1..b4f5c14ff30 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatClient.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatClient.java @@ -64,6 +64,7 @@ public WatsonxAiChatClient(WatsonxAiApi watsonxAiApi) { .withMaxNewTokens(20) .withMinNewTokens(0) .withRepetitionPenalty(1.0f) + .withStopSequences(List.of()) .build()); } @@ -114,7 +115,10 @@ public WatsonxAiRequest request(Prompt prompt) { } if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { + if (prompt.getOptions() instanceof WatsonxAiChatOptions runtimeOptions) { + options = ModelOptionsUtils.merge(runtimeOptions, options, WatsonxAiChatOptions.class); + } + else if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class, WatsonxAiChatOptions.class); diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java index 14943504973..0b10febdbaf 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java @@ -15,15 +15,18 @@ */ package org.springframework.ai.watsonx; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonAnyGetter; +import com.fasterxml.jackson.annotation.JsonAnySetter; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; - import org.springframework.ai.chat.prompt.ChatOptions; /** @@ -37,6 +40,7 @@ * valid Parameters and values */ // @formatter:off + public class WatsonxAiChatOptions implements ChatOptions { /** @@ -85,14 +89,14 @@ public class WatsonxAiChatOptions implements ChatOptions { /** * Sets how many tokens must the LLM generate. (Default: 0) */ - @JsonProperty("min_new_tokens") private Integer minNewTokens = 0; + @JsonProperty("min_new_tokens") private Integer minNewTokens; /** * Sets when the LLM should stop. * (e.g., ["\n\n\n"]) then when the LLM generates three consecutive line breaks it will terminate. * Stop sequences are ignored until after the number of tokens that are specified in the Min tokens parameter are generated. */ - @JsonProperty("stop_sequences") private List stopSequences = List.of(); + @JsonProperty("stop_sequences") private List stopSequences; /** * Sets how strongly to penalize repetitions. A higher value @@ -111,6 +115,14 @@ public class WatsonxAiChatOptions implements ChatOptions { */ @JsonProperty("model") private String model; + /** + * Set additional request params (some model have non-predefined options) + */ + @JsonProperty("additional") + private Map additional = new HashMap<>(); + + @JsonIgnore + private ObjectMapper mapper = new ObjectMapper(); public Float getTemperature() { return temperature; @@ -192,6 +204,20 @@ public void setModel(String model) { this.model = model; } + @JsonAnyGetter + public Map getAdditionalProperties() { + return additional.entrySet().stream() + .collect(Collectors.toMap( + entry -> toSnakeCase(entry.getKey()), + Map.Entry::getValue + )); + } + + @JsonAnySetter + public void addAdditionalProperty(String key, Object value) { + additional.put(key, value); + } + public static Builder builder() { return new Builder(); } @@ -250,6 +276,16 @@ public Builder withModel(String model) { return this; } + public Builder withAdditionalProperty(String key, Object value) { + this.options.additional.put(key, value); + return this; + } + + public Builder withAdditionalProperties(Map properties) { + this.options.additional.putAll(properties); + return this; + } + public WatsonxAiChatOptions build() { return this.options; } @@ -261,9 +297,11 @@ public WatsonxAiChatOptions build() { */ public Map toMap() { try { - var json = new ObjectMapper().writeValueAsString(this); - return new ObjectMapper().readValue(json, new TypeReference>() { - }); + var json = mapper.writeValueAsString(this); + var map = mapper.readValue(json, new TypeReference>() {}); + map.remove("additional"); + + return map; } catch (JsonProcessingException e) { throw new RuntimeException(e); @@ -282,5 +320,9 @@ public static Map filterNonSupportedFields(Map o .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } + private String toSnakeCase(String input) { + return input != null ? input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase() : null; + } + } // @formatter:on \ No newline at end of file diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java index 7f67df726ec..2ca88e25864 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.springframework.ai.watsonx.WatsonxAiChatOptions; +import org.springframework.util.Assert; // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) @@ -62,6 +63,7 @@ public WatsonxAiRequest withProjectId(String projectId) { public static Builder builder(String input) { return new Builder(input); } public static class Builder { + public static final String MODEL_PARAMETER_IS_REQUIRED = "Model parameter is required"; private final String input; private Map parameters; private String model = ""; @@ -71,6 +73,7 @@ public Builder(String input) { } public Builder withParameters(Map parameters) { + Assert.notNull(parameters.get("model"), MODEL_PARAMETER_IS_REQUIRED); this.model = parameters.get("model").toString(); this.parameters = WatsonxAiChatOptions.filterNonSupportedFields(parameters); return this; diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java index de5347fa091..5011cc3e07a 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java @@ -22,6 +22,7 @@ import org.springframework.ai.watsonx.WatsonxAiChatOptions; import java.util.List; +import java.util.Map; /** * @author Pablo Sanchidrian Herrera @@ -56,6 +57,68 @@ public void testOptions() { assertThat(optionsMap).containsEntry("random_seed", 4); } + @Test + public void testOptionsWithAdditionalParamsOneByOne() { + WatsonxAiChatOptions options = WatsonxAiChatOptions.builder() + .withDecodingMethod("sample") + .withTemperature(1.2f) + .withTopK(20) + .withTopP(0.5f) + .withMaxNewTokens(100) + .withMinNewTokens(20) + .withStopSequences(List.of("\n\n\n")) + .withRepetitionPenalty(1.1f) + .withRandomSeed(4) + .withAdditionalProperty("HAP", true) + .withAdditionalProperty("typicalP", 0.5f) + .build(); + + var optionsMap = options.toMap(); + + assertThat(optionsMap).containsEntry("decoding_method", "sample"); + assertThat(optionsMap).containsEntry("temperature", 1.2); + assertThat(optionsMap).containsEntry("top_k", 20); + assertThat(optionsMap).containsEntry("top_p", 0.5); + assertThat(optionsMap).containsEntry("max_new_tokens", 100); + assertThat(optionsMap).containsEntry("min_new_tokens", 20); + assertThat(optionsMap).containsEntry("stop_sequences", List.of("\n\n\n")); + assertThat(optionsMap).containsEntry("repetition_penalty", 1.1); + assertThat(optionsMap).containsEntry("random_seed", 4); + assertThat(optionsMap).containsEntry("hap", true); + assertThat(optionsMap).containsEntry("typical_p", 0.5); + } + + @Test + public void testOptionsWithAdditionalParamsMap() { + WatsonxAiChatOptions options = WatsonxAiChatOptions.builder() + .withDecodingMethod("sample") + .withTemperature(1.2f) + .withTopK(20) + .withTopP(0.5f) + .withMaxNewTokens(100) + .withMinNewTokens(20) + .withStopSequences(List.of("\n\n\n")) + .withRepetitionPenalty(1.1f) + .withRandomSeed(4) + .withAdditionalProperties(Map.of("HAP", true, "typicalP", 0.5f, "test_value", "test")) + .build(); + + var optionsMap = options.toMap(); + + assertThat(optionsMap).containsEntry("decoding_method", "sample"); + assertThat(optionsMap).containsEntry("temperature", 1.2); + assertThat(optionsMap).containsEntry("top_k", 20); + assertThat(optionsMap).containsEntry("top_p", 0.5); + assertThat(optionsMap).containsEntry("max_new_tokens", 100); + assertThat(optionsMap).containsEntry("min_new_tokens", 20); + assertThat(optionsMap).containsEntry("stop_sequences", List.of("\n\n\n")); + assertThat(optionsMap).containsEntry("repetition_penalty", 1.1); + assertThat(optionsMap).containsEntry("random_seed", 4); + assertThat(optionsMap).containsEntry("hap", true); + assertThat(optionsMap).containsEntry("typical_p", 0.5); + assertThat(optionsMap).containsEntry("test_value", "test"); + } + @Test public void testFilterOut() { WatsonxAiChatOptions options = WatsonxAiChatOptions.builder().withModel("google/flan-ul2").build(); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java index f19579d5598..3da222e9309 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java @@ -19,6 +19,8 @@ import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +import java.util.List; + /** * Chat properties for Watsonx.AI Chat. * @@ -48,6 +50,7 @@ public class WatsonxAiChatProperties { .withMaxNewTokens(20) .withMinNewTokens(0) .withRepetitionPenalty(1.0f) + .withStopSequences(List.of()) .build(); public boolean isEnabled() {