diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java
index e322f52cb7d..35513cbde4a 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java
@@ -16,7 +16,9 @@
 package org.springframework.ai.embedding;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import org.springframework.ai.document.ContentFormatter;
 import org.springframework.ai.document.Document;
@@ -41,6 +43,12 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
 	 */
 	private static final int MAX_INPUT_TOKEN_COUNT = 8191;
 
+	/**
+	 * The actual max input token count used will be the original max input minus the
+	 * threshold value multiplied by the original input.
+	 */
+	private static final double DEFAULT_TOKEN_COUNT_THRESHOLD_FACTOR = 0.1;
+
 	private final TokenCountEstimator tokenCountEstimator;
 
 	private final int maxInputTokenCount;
@@ -50,27 +58,31 @@ public class TokenCountBatchingStrategy implements BatchingStrategy {
 	private final MetadataMode metadataMode;
 
 	public TokenCountBatchingStrategy() {
-		this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT);
+		this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT, DEFAULT_TOKEN_COUNT_THRESHOLD_FACTOR);
 	}
 
 	/**
 	 * @param encodingType {@link EncodingType}
+	 * @param thresholdFactor the threshold factor to use on top of the max input token
+	 * count
 	 * @param maxInputTokenCount upper limit for input tokens
 	 */
-	public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount) {
-		this(encodingType, maxInputTokenCount, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
+	public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double thresholdFactor) {
+		this(encodingType, maxInputTokenCount, thresholdFactor, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
 	}
 
 	/**
 	 * @param encodingType {@link EncodingType}
 	 * @param maxInputTokenCount upper limit for input tokens
+	 * @param thresholdFactor the threshold factor to use on top of the max input token
+	 * count
 	 * @param contentFormatter {@link ContentFormatter}
 	 * @param metadataMode {@link MetadataMode}
 	 */
-	public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount,
+	public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double thresholdFactor,
 			ContentFormatter contentFormatter, MetadataMode metadataMode) {
 		this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
-		this.maxInputTokenCount = (int) Math.round(maxInputTokenCount - (maxInputTokenCount * .1));
+		this.maxInputTokenCount = (int) Math.round(maxInputTokenCount - (maxInputTokenCount * thresholdFactor));
 		this.contentFormater = contentFormatter;
 		this.metadataMode = metadataMode;
 	}
@@ -80,6 +92,7 @@ public List<List<Document>> batch(List<Document> documents) {
 		List<List<Document>> batches = new ArrayList<>();
 		int currentSize = 0;
 		List<Document> currentBatch = new ArrayList<>();
+		Map<Document, Integer> documentTokens = new HashMap<>();
 
 		for (Document document : documents) {
 			int tokenCount = this.tokenCountEstimator
@@ -88,6 +101,11 @@ public List<List<Document>> batch(List<Document> documents) {
 				throw new IllegalArgumentException(
 						"Tokens in a single document exceeds the maximum number of allowed input tokens");
 			}
+			documentTokens.put(document, tokenCount);
+		}
+
+		for (Document document : documentTokens.keySet()) {
+			Integer tokenCount = documentTokens.get(document);
 			if (currentSize + tokenCount > maxInputTokenCount) {
 				batches.add(currentBatch);
 				currentBatch.clear();