diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java index ca71e6fcc4d..146ce89d04e 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import io.micrometer.observation.Observation; @@ -28,12 +29,9 @@ import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.AbstractEmbeddingModel; -import org.springframework.ai.embedding.Embedding; -import org.springframework.ai.embedding.EmbeddingOptions; -import org.springframework.ai.embedding.EmbeddingRequest; -import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.*; import org.springframework.util.Assert; /** @@ -89,6 +87,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { List embeddings = new ArrayList<>(); var indexCounter = new AtomicInteger(0); + int tokenUsage = 0; for (String inputContent : request.getInstructions()) { var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions()); @@ -111,6 +110,10 @@ public EmbeddingResponse call(EmbeddingRequest request) { } embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement())); + + if (response.inputTextTokenCount() != null) { + tokenUsage += response.inputTextTokenCount(); + } } catch (Exception ex) { logger.error("Titan API embedding failed for input at index {}: {}", indexCounter.get(), @@ -120,7 +123,10 @@ public EmbeddingResponse call(EmbeddingRequest request) { } } - return new EmbeddingResponse(embeddings); + EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata("", + getDefaultUsage(tokenUsage)); + + return new EmbeddingResponse(embeddings, embeddingResponseMetadata); } private TitanEmbeddingRequest createTitanEmbeddingRequest(String inputContent, EmbeddingOptions requestOptions) { @@ -155,6 +161,10 @@ private String summarizeInput(String input) { return input.length() > 100 ? input.substring(0, 100) + "..." : input; } + private DefaultUsage getDefaultUsage(int tokens) { + return new DefaultUsage(tokens, 0); + } + public enum InputType { TEXT, IMAGE