Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
Expand All @@ -37,10 +38,13 @@
* this API. If this changes in the future, we will add it as metadata.
*
* @author Christian Tzolov
* @author Soby Chacko
* @since 0.8.0
*/
public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel {

/**
 * Upper bound applied to each input text before the request is sent.
 * <p>
 * NOTE(review): despite the "TOKENS" suffix, the visible code compares this value
 * against {@code String.length()} and uses it as a {@code substring} bound, i.e. it
 * is counted in UTF-16 chars, not model tokens - confirm the intended unit.
 */
private static final int COHERE_MAX_TOKENS = 2048;

private final CohereEmbeddingBedrockApi embeddingApi;

private final BedrockCohereEmbeddingOptions defaultOptions;
Expand Down Expand Up @@ -74,11 +78,34 @@ public float[] embed(Document document) {

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");

List<String> instructions = request.getInstructions();
Assert.notEmpty(instructions, "At least one text is required!");

final BedrockCohereEmbeddingOptions optionsToUse = this.mergeOptions(request.getOptions());

var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), optionsToUse.getInputType(),
List<String> truncatedInstructions = instructions.stream().map(text -> {
if (text == null || text.isEmpty()) {
return text;
}

if (text.length() <= COHERE_MAX_TOKENS) {
return text;
}

// Handle truncation based on option
return switch (optionsToUse.getTruncate()) {
case END -> text.substring(0, COHERE_MAX_TOKENS); // Keep first 2048 chars
case START -> text.substring(text.length() - COHERE_MAX_TOKENS); // Keep
// last
// 2048
// chars
default -> text.substring(0, COHERE_MAX_TOKENS); // Default to END
// behavior
};
}).collect(Collectors.toList());

var apiRequest = new CohereEmbeddingRequest(truncatedInstructions, optionsToUse.getInputType(),
optionsToUse.getTruncate());
CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest);
var indexCounter = new AtomicInteger(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.mockito.ArgumentCaptor;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;

Expand All @@ -31,11 +32,14 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.SpyBean;
import org.springframework.context.annotation.Bean;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.verify;

@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
Expand All @@ -45,6 +49,13 @@ class BedrockCohereEmbeddingModelIT {
@Autowired
private BedrockCohereEmbeddingModel embeddingModel;

@SpyBean
private CohereEmbeddingBedrockApi embeddingApi;

@Autowired
@Qualifier("embeddingModelStartTruncate")
private BedrockCohereEmbeddingModel embeddingModelStartTruncate;

@Test
void singleEmbedding() {
assertThat(this.embeddingModel).isNotNull();
Expand All @@ -54,6 +65,77 @@ void singleEmbedding() {
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
}

/**
 * Embeds a single input longer than 2048 characters and expects exactly one
 * non-empty embedding back from the default test model.
 */
@Test
void truncatesLongText() {
	// 11 chars * 300 = 3300 chars; the assertion documents that the input is oversized.
	final String oversizedInput = "Hello World".repeat(300);
	assertThat(oversizedInput.length()).isGreaterThan(2048);

	final EmbeddingResponse response = this.embeddingModel.embedForResponse(List.of(oversizedInput));

	// One input in, one embedding out.
	assertThat(response.getResults()).hasSize(1);
	assertThat(response.getResults().get(0).getOutput()).isNotEmpty();
	assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
}

/**
 * Embeds two inputs that both exceed 2048 characters in a single request and
 * expects one non-empty embedding per input.
 */
@Test
void truncatesMultipleLongTexts() {
	String longText1 = "Hello World".repeat(300);
	String longText2 = "Another Text".repeat(300);

	// Guard the premise of the test: if either fixture is ever shortened below the
	// limit, the test would stop exercising the oversized-input path without failing.
	assertThat(longText1.length()).isGreaterThan(2048);
	assertThat(longText2.length()).isGreaterThan(2048);

	EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText1, longText2));

	assertThat(embeddingResponse.getResults()).hasSize(2);
	assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
	assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
	assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
}

/**
 * Captures the request handed to the spied API and checks that a 3000-character
 * input is cut down to exactly 2048 characters before it is sent.
 */
@Test
void verifyExactTruncationLength() {
	String longText = "x".repeat(3000);

	ArgumentCaptor<CohereEmbeddingBedrockApi.CohereEmbeddingRequest> requestCaptor = ArgumentCaptor
		.forClass(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.class);

	EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText));

	verify(this.embeddingApi).embedding(requestCaptor.capture());
	CohereEmbeddingBedrockApi.CohereEmbeddingRequest capturedRequest = requestCaptor.getValue();

	assertThat(capturedRequest.texts()).hasSize(1);
	// The input is longer than the limit, so the text must be exactly at the limit:
	// a "less than or equal" bound would also accept an over-aggressive truncation.
	assertThat(capturedRequest.texts().get(0)).hasSize(2048);

	assertThat(embeddingResponse.getResults()).hasSize(1);
	assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
}

/**
 * Uses the model configured with START truncation and checks, via the spied API,
 * that the beginning of an oversized input is dropped while its tail is kept.
 */
@Test
void truncatesLongTextFromStart() {
	String startMarker = "START_MARKER_";
	String endMarker = "_END_MARKER";
	String middlePadding = "x".repeat(2500); // Long enough to force truncation
	String longText = startMarker + middlePadding + endMarker;

	assertThat(longText.length()).isGreaterThan(2048);

	ArgumentCaptor<CohereEmbeddingBedrockApi.CohereEmbeddingRequest> requestCaptor = ArgumentCaptor
		.forClass(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.class);

	EmbeddingResponse embeddingResponse = this.embeddingModelStartTruncate.embedForResponse(List.of(longText));

	// Verify truncation behavior: the oversized input must be cut to exactly the
	// limit, losing its head (start marker) and keeping its tail (end marker).
	verify(this.embeddingApi).embedding(requestCaptor.capture());
	String truncatedText = requestCaptor.getValue().texts().get(0);
	assertThat(truncatedText).hasSize(2048);
	assertThat(truncatedText).doesNotContain(startMarker);
	assertThat(truncatedText).endsWith(endMarker);

	// Verify embedding response
	assertThat(embeddingResponse.getResults()).hasSize(1);
	assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
	assertThat(this.embeddingModelStartTruncate.dimensions()).isEqualTo(1024);
}

@Test
void batchEmbedding() {
assertThat(this.embeddingModel).isNotNull();
Expand Down Expand Up @@ -93,9 +175,27 @@ public CohereEmbeddingBedrockApi cohereEmbeddingApi() {
Duration.ofMinutes(2));
}

@Bean
@Bean("embeddingModel")
public BedrockCohereEmbeddingModel cohereAiEmbedding(CohereEmbeddingBedrockApi cohereEmbeddingApi) {
return new BedrockCohereEmbeddingModel(cohereEmbeddingApi);
// custom model that uses the END truncation strategy, instead of the default
// NONE.
return new BedrockCohereEmbeddingModel(cohereEmbeddingApi,
BedrockCohereEmbeddingOptions.builder()
.withInputType(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT)
.withTruncate(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.Truncate.END)
.build());
}

/**
 * Registers an embedding model under the name {@code embeddingModelStartTruncate}
 * whose options select the SEARCH_DOCUMENT input type and START truncation.
 * @param cohereEmbeddingApi the API client the model delegates to
 * @return the configured embedding model
 */
@Bean("embeddingModelStartTruncate")
public BedrockCohereEmbeddingModel cohereAiEmbeddingStartTruncate(
		CohereEmbeddingBedrockApi cohereEmbeddingApi) {
	// Custom options: START truncation strategy instead of the default NONE.
	BedrockCohereEmbeddingOptions startTruncateOptions = BedrockCohereEmbeddingOptions.builder()
		.withInputType(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT)
		.withTruncate(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.Truncate.START)
		.build();
	return new BedrockCohereEmbeddingModel(cohereEmbeddingApi, startTruncateOptions);
}

}
Expand Down