diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java
index c34335d8b44..cecb6395488 100644
--- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java
@@ -18,6 +18,7 @@
 
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
 
 import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
 import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
@@ -37,10 +38,17 @@
  * this API. If this change in the future we will add it as metadata.
  *
  * @author Christian Tzolov
+ * @author Soby Chacko
  * @since 0.8.0
  */
 public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel {
 
+	/**
+	 * Maximum number of characters (not tokens) sent per input text; longer texts are
+	 * shortened client-side before the API request is built.
+	 */
+	private static final int COHERE_MAX_CHARACTERS = 2048;
+
 	private final CohereEmbeddingBedrockApi embeddingApi;
 
 	private final BedrockCohereEmbeddingOptions defaultOptions;
@@ -74,11 +82,26 @@ public float[] embed(Document document) {
 
 	@Override
 	public EmbeddingResponse call(EmbeddingRequest request) {
-		Assert.notEmpty(request.getInstructions(), "At least one text is required!");
+
+		List<String> instructions = request.getInstructions();
+		Assert.notEmpty(instructions, "At least one text is required!");
 
 		final BedrockCohereEmbeddingOptions optionsToUse = this.mergeOptions(request.getOptions());
 
-		var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), optionsToUse.getInputType(),
+		// START keeps the tail of an over-long text; every other truncate setting
+		// (including an unset one) keeps the head. Resolved once, outside the lambda.
+		final boolean keepTail = optionsToUse.getTruncate() == CohereEmbeddingRequest.Truncate.START;
+		List<String> truncatedInstructions = instructions.stream().map(text -> {
+			if (text == null || text.length() <= COHERE_MAX_CHARACTERS) {
+				return text;
+			}
+			if (keepTail) {
+				return text.substring(text.length() - COHERE_MAX_CHARACTERS);
+			}
+			return text.substring(0, COHERE_MAX_CHARACTERS);
+		}).collect(Collectors.toList());
+
+		var apiRequest = new CohereEmbeddingRequest(truncatedInstructions, optionsToUse.getInputType(),
 				optionsToUse.getTruncate());
 		CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest);
 		var indexCounter = new AtomicInteger(0);
diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java
index 03d3f0145c7..01feeb0a24d 100644
--- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java
+++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java
@@ -22,6 +22,7 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.mockito.ArgumentCaptor;
 import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
 import software.amazon.awssdk.regions.Region;
 
@@ -31,11 +32,14 @@
 import org.springframework.ai.embedding.EmbeddingRequest;
 import org.springframework.ai.embedding.EmbeddingResponse;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Qualifier;
 import org.springframework.boot.SpringBootConfiguration;
 import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.boot.test.mock.mockito.SpyBean;
 import org.springframework.context.annotation.Bean;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.verify;
 
 @SpringBootTest
 @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@@ -45,6 +49,13 @@ class BedrockCohereEmbeddingModelIT {
 	@Autowired
 	private BedrockCohereEmbeddingModel embeddingModel;
 
+	@SpyBean
+	private CohereEmbeddingBedrockApi embeddingApi;
+
+	@Autowired
+	@Qualifier("embeddingModelStartTruncate")
+	private BedrockCohereEmbeddingModel embeddingModelStartTruncate;
+
 	@Test
 	void singleEmbedding() {
 		assertThat(this.embeddingModel).isNotNull();
@@ -54,6 +65,81 @@ void singleEmbedding() {
 		assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
 	}
 
+	// A single text longer than 2048 characters is still embedded by the END-truncating model.
+	@Test
+	void truncatesLongText() {
+		String longText = "Hello World".repeat(300);
+		assertThat(longText.length()).isGreaterThan(2048);
+
+		EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText));
+
+		assertThat(embeddingResponse.getResults()).hasSize(1);
+		assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
+		assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
+	}
+
+	// Every text of a batch is handled independently: two over-long inputs give two results.
+	@Test
+	void truncatesMultipleLongTexts() {
+		String longText1 = "Hello World".repeat(300);
+		String longText2 = "Another Text".repeat(300);
+
+		EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText1, longText2));
+
+		assertThat(embeddingResponse.getResults()).hasSize(2);
+		assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
+		assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
+		assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
+	}
+
+	// The request captured on the spied API must carry a text of exactly 2048 characters.
+	@Test
+	void verifyExactTruncationLength() {
+		String longText = "x".repeat(3000);
+
+		ArgumentCaptor<CohereEmbeddingBedrockApi.CohereEmbeddingRequest> requestCaptor = ArgumentCaptor
+			.forClass(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.class);
+
+		EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText));
+
+		verify(this.embeddingApi).embedding(requestCaptor.capture());
+		CohereEmbeddingBedrockApi.CohereEmbeddingRequest capturedRequest = requestCaptor.getValue();
+
+		assertThat(capturedRequest.texts()).hasSize(1);
+		assertThat(capturedRequest.texts().get(0)).hasSize(2048);
+
+		assertThat(embeddingResponse.getResults()).hasSize(1);
+		assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
+	}
+
+	// START truncation must drop the head of the text and keep its last 2048 characters.
+	@Test
+	void truncatesLongTextFromStart() {
+		String startMarker = "START_MARKER_";
+		String endMarker = "_END_MARKER";
+		String middlePadding = "x".repeat(2500); // Long enough to force truncation
+		String longText = startMarker + middlePadding + endMarker;
+
+		assertThat(longText.length()).isGreaterThan(2048);
+
+		ArgumentCaptor<CohereEmbeddingBedrockApi.CohereEmbeddingRequest> requestCaptor = ArgumentCaptor
+			.forClass(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.class);
+
+		EmbeddingResponse embeddingResponse = this.embeddingModelStartTruncate.embedForResponse(List.of(longText));
+
+		// Verify truncation behavior
+		verify(this.embeddingApi).embedding(requestCaptor.capture());
+		String truncatedText = requestCaptor.getValue().texts().get(0);
+		assertThat(truncatedText).hasSize(2048);
+		assertThat(truncatedText).doesNotContain(startMarker);
+		assertThat(truncatedText).endsWith(endMarker);
+
+		// Verify embedding response
+		assertThat(embeddingResponse.getResults()).hasSize(1);
+		assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
+		assertThat(this.embeddingModelStartTruncate.dimensions()).isEqualTo(1024);
+	}
+
 	@Test
 	void batchEmbedding() {
 		assertThat(this.embeddingModel).isNotNull();
@@ -93,9 +179,27 @@ public CohereEmbeddingBedrockApi cohereEmbeddingApi() {
 					Duration.ofMinutes(2));
 		}
 
-		@Bean
+		@Bean("embeddingModel")
 		public BedrockCohereEmbeddingModel cohereAiEmbedding(CohereEmbeddingBedrockApi cohereEmbeddingApi) {
-			return new BedrockCohereEmbeddingModel(cohereEmbeddingApi);
+			// Model under test for the default cases: uses the END truncation strategy
+			// instead of the default NONE.
+			return new BedrockCohereEmbeddingModel(cohereEmbeddingApi,
+					BedrockCohereEmbeddingOptions.builder()
+						.withInputType(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT)
+						.withTruncate(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.Truncate.END)
+						.build());
+		}
+
+		@Bean("embeddingModelStartTruncate")
+		public BedrockCohereEmbeddingModel cohereAiEmbeddingStartTruncate(
+				CohereEmbeddingBedrockApi cohereEmbeddingApi) {
+			// Second model sharing the same API bean: uses the START truncation strategy
+			// instead of the default NONE.
+			return new BedrockCohereEmbeddingModel(cohereEmbeddingApi,
+					BedrockCohereEmbeddingOptions.builder()
+						.withInputType(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT)
+						.withTruncate(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.Truncate.START)
+						.build());
 		}
 
 	}