Skip to content
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public WatsonxAiChatClient(WatsonxAiApi watsonxAiApi) {
.withMaxNewTokens(20)
.withMinNewTokens(0)
.withRepetitionPenalty(1.0f)
.withStopSequences(List.of())
.build());
}

Expand Down Expand Up @@ -114,7 +115,10 @@ public WatsonxAiRequest request(Prompt prompt) {
}

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
if (prompt.getOptions() instanceof WatsonxAiChatOptions runtimeOptions) {
options = ModelOptionsUtils.merge(runtimeOptions, options, WatsonxAiChatOptions.class);
}
else if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class,
WatsonxAiChatOptions.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
*/
package org.springframework.ai.watsonx;

import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonAnyGetter;
import com.fasterxml.jackson.annotation.JsonAnySetter;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.ai.chat.prompt.ChatOptions;

/**
Expand All @@ -37,6 +40,7 @@
* valid Parameters and values</a>
*/
// @formatter:off

public class WatsonxAiChatOptions implements ChatOptions {

/**
Expand Down Expand Up @@ -85,14 +89,14 @@ public class WatsonxAiChatOptions implements ChatOptions {
/**
* Sets how many tokens the LLM must generate. (Default: 0)
*/
@JsonProperty("min_new_tokens") private Integer minNewTokens = 0;
@JsonProperty("min_new_tokens") private Integer minNewTokens;

/**
* Sets when the LLM should stop.
* For example, with ["\n\n\n"] the LLM terminates as soon as it generates three consecutive line breaks.
* Stop sequences are ignored until the number of tokens specified in the Min tokens parameter has been generated.
*/
@JsonProperty("stop_sequences") private List<String> stopSequences = List.of();
@JsonProperty("stop_sequences") private List<String> stopSequences;

/**
* Sets how strongly to penalize repetitions. A higher value
Expand All @@ -111,6 +115,14 @@ public class WatsonxAiChatOptions implements ChatOptions {
*/
@JsonProperty("model") private String model;

/**
* Sets additional request parameters (some models have options that are not predefined).
*/
@JsonProperty("additional")
private Map<String, Object> additional = new HashMap<>();

@JsonIgnore
private ObjectMapper mapper = new ObjectMapper();

public Float getTemperature() {
return temperature;
Expand Down Expand Up @@ -192,6 +204,20 @@ public void setModel(String model) {
this.model = model;
}

/**
 * Returns a copy of the additional (non-predefined) parameters with every key
 * converted to snake_case. Because of {@code @JsonAnyGetter}, Jackson writes each
 * returned entry as a top-level property when this object is serialized.
 * <p>
 * Null values are preserved. If two keys convert to the same snake_case name, only
 * one of the entries is kept; which one is unspecified, because the backing map has
 * no defined iteration order.
 * @return a new map of snake_case parameter names to their values, never {@code null}
 */
@JsonAnyGetter
public Map<String, Object> getAdditionalProperties() {
	// A plain loop is used instead of Collectors.toMap: that collector throws a
	// NullPointerException for null values and an IllegalStateException when two
	// keys map to the same snake_case name.
	Map<String, Object> converted = new HashMap<>();
	for (Map.Entry<String, Object> entry : this.additional.entrySet()) {
		converted.put(toSnakeCase(entry.getKey()), entry.getValue());
	}
	return converted;
}

/**
 * Stores a parameter that has no dedicated field. Because of {@code @JsonAnySetter},
 * Jackson calls this for every otherwise unmapped property during deserialization.
 * @param key the parameter name; an existing entry with the same name is replaced
 * @param value the parameter value
 */
@JsonAnySetter
public void addAdditionalProperty(String key, Object value) {
	this.additional.put(key, value);
}

/**
 * Creates a fresh {@link Builder} for assembling a {@code WatsonxAiChatOptions}.
 * @return a new builder instance
 */
public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -250,6 +276,16 @@ public Builder withModel(String model) {
return this;
}

/**
 * Adds a single non-predefined parameter to the options being built.
 * An existing entry with the same key is replaced.
 * @param key the parameter name
 * @param value the parameter value
 * @return this builder, for chaining
 */
public Builder withAdditionalProperty(String key, Object value) {
this.options.additional.put(key, value);
return this;
}

/**
 * Adds every entry of the given map to the additional parameters of the options
 * being built. Entries whose key is already present are replaced.
 * @param properties the parameter names and values to add; must not be {@code null}
 * @return this builder, for chaining
 */
public Builder withAdditionalProperties(Map<String, Object> properties) {
	Map<String, Object> target = this.options.additional;
	properties.forEach(target::put);
	return this;
}

/**
 * Returns the options instance this builder has been populating.
 * No copy is made, so further calls on this builder affect the returned object.
 * @return the configured {@code WatsonxAiChatOptions}
 */
public WatsonxAiChatOptions build() {
return this.options;
}
Expand All @@ -261,9 +297,11 @@ public WatsonxAiChatOptions build() {
*/
public Map<String, Object> toMap() {
try {
var json = new ObjectMapper().writeValueAsString(this);
return new ObjectMapper().readValue(json, new TypeReference<Map<String, Object>>() {
});
var json = mapper.writeValueAsString(this);
var map = mapper.readValue(json, new TypeReference<Map<String, Object>>() {});
map.remove("additional");

return map;
}
catch (JsonProcessingException e) {
throw new RuntimeException(e);
Expand All @@ -282,5 +320,9 @@ public static Map<String, Object> filterNonSupportedFields(Map<String, Object> o
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

/**
 * Converts a camelCase name to snake_case, e.g. {@code typicalP} becomes
 * {@code typical_p} and {@code HAP} becomes {@code hap}. An underscore is inserted
 * between a lowercase letter and the run of uppercase letters that follows it, then
 * the whole string is lowercased.
 * @param input the name to convert, may be {@code null}
 * @return the snake_case form, or {@code null} if {@code input} is {@code null}
 */
private String toSnakeCase(String input) {
	if (input == null) {
		return null;
	}
	// Lowercase with Locale.ROOT: the default-locale variant would, for example,
	// turn "I" into a dotless "ı" under a Turkish locale and corrupt the names.
	return input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase(Locale.ROOT);
}

}
// @formatter:on
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.watsonx.WatsonxAiChatOptions;
import org.springframework.util.Assert;

// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
Expand Down Expand Up @@ -62,7 +63,7 @@ public WatsonxAiRequest withProjectId(String projectId) {
public static Builder builder(String input) { return new Builder(input); }

public static class Builder {
private final String input;
public static final String MODEL_PARAMETER_IS_REQUIRED = "Model parameter is required";private final String input;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks awkward.
Please run ./mvnw spring-javaformat:apply

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, apologies. I ran the build and it passed, so I assumed the code linting was fine.

private Map<String, Object> parameters;
private String model = "";

Expand All @@ -71,6 +72,7 @@ public Builder(String input) {
}

public Builder withParameters(Map<String, Object> parameters) {
Assert.notNull(parameters.get("model"), MODEL_PARAMETER_IS_REQUIRED);
this.model = parameters.get("model").toString();
this.parameters = WatsonxAiChatOptions.filterNonSupportedFields(parameters);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.ai.watsonx.WatsonxAiChatOptions;

import java.util.List;
import java.util.Map;

/**
* @author Pablo Sanchidrian Herrera
Expand Down Expand Up @@ -56,6 +57,68 @@ public void testOptions() {
assertThat(optionsMap).containsEntry("random_seed", 4);
}

/**
 * Checks that parameters added one at a time through
 * {@code withAdditionalProperty} appear in {@code toMap()} next to the predefined
 * options, with their keys converted to snake_case.
 */
@Test
public void testOptionsWithAdditionalParamsOneByOne() {
	WatsonxAiChatOptions.Builder optionsBuilder = WatsonxAiChatOptions.builder()
		.withDecodingMethod("sample")
		.withTemperature(1.2f)
		.withTopK(20)
		.withTopP(0.5f)
		.withMaxNewTokens(100)
		.withMinNewTokens(20)
		.withStopSequences(List.of("\n\n\n"))
		.withRepetitionPenalty(1.1f)
		.withRandomSeed(4);

	// Non-predefined parameters, registered individually.
	optionsBuilder.withAdditionalProperty("HAP", true);
	optionsBuilder.withAdditionalProperty("typicalP", 0.5f);

	Map<String, Object> optionsMap = optionsBuilder.build().toMap();

	// Predefined options.
	assertThat(optionsMap).containsEntry("decoding_method", "sample")
		.containsEntry("temperature", 1.2)
		.containsEntry("top_k", 20)
		.containsEntry("top_p", 0.5)
		.containsEntry("max_new_tokens", 100)
		.containsEntry("min_new_tokens", 20)
		.containsEntry("stop_sequences", List.of("\n\n\n"))
		.containsEntry("repetition_penalty", 1.1)
		.containsEntry("random_seed", 4);

	// Additional parameters, exposed under snake_case keys.
	assertThat(optionsMap).containsEntry("hap", true)
		.containsEntry("typical_p", 0.5);
}

/**
 * Checks that parameters added in bulk through {@code withAdditionalProperties}
 * appear in {@code toMap()} next to the predefined options, with their keys
 * converted to snake_case.
 */
@Test
public void testOptionsWithAdditionalParamsMap() {
	// Non-predefined parameters, supplied as a single map.
	Map<String, Object> extraParams = Map.of("HAP", true, "typicalP", 0.5f, "test_value", "test");

	WatsonxAiChatOptions options = WatsonxAiChatOptions.builder()
		.withDecodingMethod("sample")
		.withTemperature(1.2f)
		.withTopK(20)
		.withTopP(0.5f)
		.withMaxNewTokens(100)
		.withMinNewTokens(20)
		.withStopSequences(List.of("\n\n\n"))
		.withRepetitionPenalty(1.1f)
		.withRandomSeed(4)
		.withAdditionalProperties(extraParams)
		.build();

	Map<String, Object> optionsMap = options.toMap();

	// Predefined options.
	assertThat(optionsMap).containsEntry("decoding_method", "sample")
		.containsEntry("temperature", 1.2)
		.containsEntry("top_k", 20)
		.containsEntry("top_p", 0.5)
		.containsEntry("max_new_tokens", 100)
		.containsEntry("min_new_tokens", 20)
		.containsEntry("stop_sequences", List.of("\n\n\n"))
		.containsEntry("repetition_penalty", 1.1)
		.containsEntry("random_seed", 4);

	// Additional parameters, exposed under snake_case keys.
	assertThat(optionsMap).containsEntry("hap", true)
		.containsEntry("typical_p", 0.5)
		.containsEntry("test_value", "test");
}

@Test
public void testFilterOut() {
WatsonxAiChatOptions options = WatsonxAiChatOptions.builder().withModel("google/flan-ul2").build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;

import java.util.List;

/**
* Chat properties for Watsonx.AI Chat.
*
Expand Down Expand Up @@ -48,6 +50,7 @@ public class WatsonxAiChatProperties {
.withMaxNewTokens(20)
.withMinNewTokens(0)
.withRepetitionPenalty(1.0f)
.withStopSequences(List.of())
.build();

public boolean isEnabled() {
Expand Down