Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ public static class Builder {

private BedrockTitanEmbeddingOptions options = new BedrockTitanEmbeddingOptions();

@Deprecated
public Builder withInputType(InputType inputType) {
return this.inputType(inputType);
}

public Builder inputType(InputType inputType) {
Assert.notNull(inputType, "input type can not be null.");

this.options.setInputType(inputType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class BedrockTitanEmbeddingModelIT {
void singleEmbedding() {
assertThat(this.embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of("Hello World"),
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build()));
BedrockTitanEmbeddingOptions.builder().inputType(InputType.TEXT).build()));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
Expand All @@ -69,7 +69,7 @@ void imageEmbedding() throws IOException {

EmbeddingResponse embeddingResponse = this.embeddingModel
.call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)),
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build()));
BedrockTitanEmbeddingOptions.builder().inputType(InputType.IMAGE).build()));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ public class TransformersEmbeddingModelObservationTests {

@Test
void observationForEmbeddingOperation() {

var options = EmbeddingOptionsBuilder.builder().withModel("bert-base-uncased").build();
var options = EmbeddingOptionsBuilder.builder().model("bert-base-uncased").build();

EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,26 @@ public static EmbeddingOptionsBuilder builder() {
return new EmbeddingOptionsBuilder();
}

public EmbeddingOptionsBuilder withModel(String model) {
public EmbeddingOptionsBuilder model(String model) {
this.embeddingOptions.setModel(model);
return this;
}

public EmbeddingOptionsBuilder withDimensions(Integer dimensions) {
@Deprecated
public EmbeddingOptionsBuilder withModel(String model) {
return model(model);
}

public EmbeddingOptionsBuilder dimensions(Integer dimensions) {
this.embeddingOptions.setDimensions(dimensions);
return this;
}

@Deprecated
public EmbeddingOptionsBuilder withDimensions(Integer dimensions) {
return dimensions(dimensions);
}

public EmbeddingOptions build() {
return this.embeddingOptions;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void shouldHaveName() {
@Test
void contextualNameWhenModelIsDefined() {
EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("mistral").build()))
.embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().model("mistral").build()))
.provider("superprovider")
.build();
assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("embedding mistral");
Expand All @@ -71,8 +71,7 @@ void contextualNameWhenModelIsNotDefined() {
@Test
void supportsOnlyEmbeddingModelObservationContext() {
EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(
generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("supermodel").build()))
.embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().model("supermodel").build()))
.provider("superprovider")
.build();
assertThat(this.observationConvention.supportsContext(observationContext)).isTrue();
Expand All @@ -82,7 +81,7 @@ void supportsOnlyEmbeddingModelObservationContext() {
@Test
void shouldHaveLowCardinalityKeyValuesWhenDefined() {
EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("mistral").build()))
.embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().model("mistral").build()))
.provider("superprovider")
.build();
assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains(
Expand All @@ -95,7 +94,7 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() {
void shouldHaveLowCardinalityKeyValuesWhenDefinedAndResponse() {
EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(generateEmbeddingRequest(
EmbeddingOptionsBuilder.builder().withModel("mistral").withDimensions(1492).build()))
EmbeddingOptionsBuilder.builder().model("mistral").dimensions(1492).build()))
.provider("superprovider")
.build();
observationContext.setResponse(new EmbeddingResponse(List.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void shouldCreateAllMetersDuringAnObservation() {

private EmbeddingModelObservationContext generateObservationContext() {
return EmbeddingModelObservationContext.builder()
.embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("mistral").build()))
.embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().model("mistral").build()))
.provider("superprovider")
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ class EmbeddingModelObservationContextTests {
@Test
void whenMandatoryRequestOptionsThenReturn() {
var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(
generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("supermodel").build()))
.embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().model("supermodel").build()))
.provider("superprovider")
.build();

Expand Down