Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class BedrockTitanChatModel implements ChatModel, StreamingChatModel {
private final BedrockTitanChatOptions defaultOptions;

public BedrockTitanChatModel(TitanChatBedrockApi chatApi) {
this(chatApi, BedrockTitanChatOptions.builder().withTemperature(0.8).build());
this(chatApi, BedrockTitanChatOptions.builder().temperature(0.8).build());
}

public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOptions defaultOptions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ public static Builder builder() {
}

public static BedrockTitanChatOptions fromOptions(BedrockTitanChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withMaxTokenCount(fromOptions.getMaxTokenCount())
.withStopSequences(fromOptions.getStopSequences())
return builder().temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.maxTokenCount(fromOptions.getMaxTokenCount())
.stopSequences(fromOptions.getStopSequences())
.build();
}

Expand Down Expand Up @@ -148,21 +148,57 @@ public static class Builder {

private BedrockTitanChatOptions options = new BedrockTitanChatOptions();

public Builder temperature(Double temperature) {
this.options.temperature = temperature;
return this;
}

public Builder topP(Double topP) {
this.options.topP = topP;
return this;
}

public Builder maxTokenCount(Integer maxTokenCount) {
this.options.maxTokenCount = maxTokenCount;
return this;
}

public Builder stopSequences(List<String> stopSequences) {
this.options.stopSequences = stopSequences;
return this;
}

/**
* @deprecated see {@link #temperature(Double)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withTemperature(Double temperature) {
this.options.temperature = temperature;
return this;
}

/**
* @deprecated see {@link #topP(Double)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withTopP(Double topP) {
this.options.topP = topP;
return this;
}

/**
* @deprecated see {@link #maxTokenCount(Integer)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withMaxTokenCount(Integer maxTokenCount) {
this.options.maxTokenCount = maxTokenCount;
return this;
}

/**
* @deprecated see {@link #stopSequences(List)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withStopSequences(List<String> stopSequences) {
this.options.stopSequences = stopSequences;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.bedrock.titan;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
Expand Down Expand Up @@ -66,6 +68,17 @@ public static class Builder {

private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions();

public Builder inputType(InputType inputType) {
Assert.notNull(inputType, "input type can not be null.");

this.options.setInputType(inputType);
return this;
}

/**
* @deprecated see {@link #inputType(InputType)} (List)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withInputType(InputType inputType) {
Assert.notNull(inputType, "input type can not be null.");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public BedrockAi21Jurassic2ChatModel bedrockAi21Jurassic2ChatModel(
BedrockAi21Jurassic2ChatOptions.builder()
.temperature(0.5)
.maxTokens(500)
// .withTopP(0.9)
// .topP(0.9)
.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ public void createRequestWithChatOptions() {

var model = new BedrockTitanChatModel(this.api,
BedrockTitanChatOptions.builder()
.withTemperature(66.6)
.withTopP(0.66)
.withMaxTokenCount(666)
.withStopSequences(List.of("stop1", "stop2"))
.temperature(66.6)
.topP(0.66)
.maxTokenCount(666)
.stopSequences(List.of("stop1", "stop2"))
.build());

var request = model.createRequest(new Prompt("Test message content"));
Expand All @@ -60,10 +60,10 @@ public void createRequestWithChatOptions() {

request = model.createRequest(new Prompt("Test message content",
BedrockTitanChatOptions.builder()
.withTemperature(99.9)
.withTopP(0.99)
.withMaxTokenCount(999)
.withStopSequences(List.of("stop3", "stop4"))
.temperature(99.9)
.topP(0.99)
.maxTokenCount(999)
.stopSequences(List.of("stop3", "stop4"))
.build()

));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BedrockTitanEmbeddingModelIT {
void singleEmbedding() {
assertThat(this.embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of("Hello World"),
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build()));
BedrockTitanEmbeddingOptions.builder().inputType(InputType.TEXT).build()));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
Expand All @@ -65,7 +65,7 @@ void imageEmbedding() throws IOException {

EmbeddingResponse embeddingResponse = this.embeddingModel
.call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)),
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build()));
BedrockTitanEmbeddingOptions.builder().inputType(InputType.IMAGE).build()));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class BedrockAnthropic3ChatProperties {
.maxTokens(300)
.topK(10)
.anthropicVersion(Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION)
// .withStopSequences(List.of("\n\nHuman:"))
// .stopSequences(List.of("\n\nHuman:"))
.build();

public boolean isEnabled() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class BedrockTitanChatProperties {
private String model = TitanChatModel.TITAN_TEXT_EXPRESS_V1.id();

@NestedConfigurationProperty
private BedrockTitanChatOptions options = BedrockTitanChatOptions.builder().withTemperature(0.7).build();
private BedrockTitanChatOptions options = BedrockTitanChatOptions.builder().temperature(0.7).build();

public boolean isEnabled() {
return this.enabled;
Expand Down
Loading