From de2927d31f0e270fdaf3d4567cbd2e32dccb90e9 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Tue, 19 Nov 2024 21:01:55 -0500 Subject: [PATCH] GH-1753: Add truncation support for Cohere embeddings Fixes: #1753 https://github.com/spring-projects/spring-ai/issues/1753 - Add character-based truncation (max 2048 chars) for Cohere embedding requests - Support both START and END truncation strategies - Add unit tests verifying truncation behavior for both strategies Truncation is applied before sending requests to Bedrock API to avoid ValidationException when text exceeds maximum length. The END strategy (default) keeps the first 2048 characters while START keeps the last 2048 characters. --- .../cohere/BedrockCohereEmbeddingModel.java | 31 +++++- .../cohere/BedrockCohereEmbeddingModelIT.java | 104 +++++++++++++++++- 2 files changed, 131 insertions(+), 4 deletions(-) diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java index c34335d8b44..cecb6395488 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest; @@ -37,10 +38,13 @@ * this API. If this change in the future we will add it as metadata. * * @author Christian Tzolov + * @author Soby Chacko * @since 0.8.0 */ public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel { + private static final int COHERE_MAX_TOKENS = 2048; + private final CohereEmbeddingBedrockApi embeddingApi; private final BedrockCohereEmbeddingOptions defaultOptions; @@ -74,11 +78,34 @@ public float[] embed(Document document) { @Override public EmbeddingResponse call(EmbeddingRequest request) { - Assert.notEmpty(request.getInstructions(), "At least one text is required!"); + + List instructions = request.getInstructions(); + Assert.notEmpty(instructions, "At least one text is required!"); final BedrockCohereEmbeddingOptions optionsToUse = this.mergeOptions(request.getOptions()); - var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), optionsToUse.getInputType(), + List truncatedInstructions = instructions.stream().map(text -> { + if (text == null || text.isEmpty()) { + return text; + } + + if (text.length() <= COHERE_MAX_TOKENS) { + return text; + } + + // Handle truncation based on option + return switch (optionsToUse.getTruncate()) { + case END -> text.substring(0, COHERE_MAX_TOKENS); // Keep first 2048 chars + case START -> text.substring(text.length() - COHERE_MAX_TOKENS); // Keep + // last + // 2048 + // chars + default -> text.substring(0, COHERE_MAX_TOKENS); // Default to END + // behavior + }; + }).collect(Collectors.toList()); + + var apiRequest = new CohereEmbeddingRequest(truncatedInstructions, optionsToUse.getInputType(), optionsToUse.getTruncate()); CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest); var indexCounter = new AtomicInteger(0); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java index 03d3f0145c7..01feeb0a24d 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModelIT.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.mockito.ArgumentCaptor; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -31,11 +32,14 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.SpyBean; import org.springframework.context.annotation.Bean; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verify; @SpringBootTest @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -45,6 +49,13 @@ class BedrockCohereEmbeddingModelIT { @Autowired private BedrockCohereEmbeddingModel embeddingModel; + @SpyBean + private CohereEmbeddingBedrockApi embeddingApi; + + @Autowired + @Qualifier("embeddingModelStartTruncate") + private BedrockCohereEmbeddingModel embeddingModelStartTruncate; + @Test void singleEmbedding() { assertThat(this.embeddingModel).isNotNull(); @@ -54,6 +65,77 @@ void singleEmbedding() { assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); } + @Test + void truncatesLongText() { + String longText = "Hello World".repeat(300); + assertThat(longText.length()).isGreaterThan(2048); + + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText)); + + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); + } + + @Test + void truncatesMultipleLongTexts() { + String longText1 = "Hello World".repeat(300); + String longText2 = "Another Text".repeat(300); + + EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of(longText1, longText2)); + + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(this.embeddingModel.dimensions()).isEqualTo(1024); + } + + @Test + void verifyExactTruncationLength() { + String longText = "x".repeat(3000); + + ArgumentCaptor requestCaptor = ArgumentCaptor + .forClass(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.class); + + EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of(longText)); + + verify(embeddingApi).embedding(requestCaptor.capture()); + CohereEmbeddingBedrockApi.CohereEmbeddingRequest capturedRequest = requestCaptor.getValue(); + + assertThat(capturedRequest.texts()).hasSize(1); + assertThat(capturedRequest.texts().get(0).length()).isLessThanOrEqualTo(2048); + + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + } + + @Test + void truncatesLongTextFromStart() { + String startMarker = "START_MARKER_"; + String endMarker = "_END_MARKER"; + String middlePadding = "x".repeat(2500); // Long enough to force truncation + String longText = startMarker + middlePadding + endMarker; + + assertThat(longText.length()).isGreaterThan(2048); + + ArgumentCaptor requestCaptor = ArgumentCaptor + .forClass(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.class); + + EmbeddingResponse embeddingResponse = this.embeddingModelStartTruncate.embedForResponse(List.of(longText)); + + // Verify truncation behavior + verify(embeddingApi).embedding(requestCaptor.capture()); + String truncatedText = requestCaptor.getValue().texts().get(0); + assertThat(truncatedText.length()).isLessThanOrEqualTo(2048); + assertThat(truncatedText).doesNotContain(startMarker); + assertThat(truncatedText).endsWith(endMarker); + + // Verify embedding response + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(this.embeddingModelStartTruncate.dimensions()).isEqualTo(1024); + } + @Test void batchEmbedding() { assertThat(this.embeddingModel).isNotNull(); @@ -93,9 +175,27 @@ public CohereEmbeddingBedrockApi cohereEmbeddingApi() { Duration.ofMinutes(2)); } - @Bean + @Bean("embeddingModel") public BedrockCohereEmbeddingModel cohereAiEmbedding(CohereEmbeddingBedrockApi cohereEmbeddingApi) { - return new BedrockCohereEmbeddingModel(cohereEmbeddingApi); + // custom model that uses the END truncation strategy, instead of the default + // NONE. + return new BedrockCohereEmbeddingModel(cohereEmbeddingApi, + BedrockCohereEmbeddingOptions.builder() + .withInputType(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT) + .withTruncate(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.Truncate.END) + .build()); + } + + @Bean("embeddingModelStartTruncate") + public BedrockCohereEmbeddingModel cohereAiEmbeddingStartTruncate( + CohereEmbeddingBedrockApi cohereEmbeddingApi) { + // custom model that uses the START truncation strategy, instead of the + // default NONE. + return new BedrockCohereEmbeddingModel(cohereEmbeddingApi, + BedrockCohereEmbeddingOptions.builder() + .withInputType(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT) + .withTruncate(CohereEmbeddingBedrockApi.CohereEmbeddingRequest.Truncate.START) + .build()); } }