diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java index 5e4462c038e..6e217f5e5a2 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClient.java @@ -1,9 +1,7 @@ package org.springframework.ai.azure.openai; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.EmbeddingItem; @@ -17,7 +15,9 @@ import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingClient; import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.util.Assert; public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient { @@ -47,14 +47,6 @@ public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, String model, this.metadataMode = metadataMode; } - @Override - public List embed(String text) { - logger.debug("Retrieving embeddings"); - Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, new EmbeddingsOptions(List.of(text))); - logger.debug("Embeddings retrieved"); - return extractEmbeddingsList(embeddings); - } - @Override public List embed(Document document) { logger.debug("Retrieving embeddings"); @@ -69,35 +61,20 @@ private List extractEmbeddingsList(Embeddings embeddings) { } @Override - public List> embed(List texts) { + public EmbeddingResponse call(EmbeddingRequest request) { logger.debug("Retrieving embeddings"); - 
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, new EmbeddingsOptions(texts)); - logger.debug("Embeddings retrieved"); - return embeddings.getData().stream().map(emb -> emb.getEmbedding()).toList(); - } - - @Override - public EmbeddingResponse embedForResponse(List texts) { - logger.debug("Retrieving embeddings"); - Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, new EmbeddingsOptions(texts)); + Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, + new EmbeddingsOptions(request.getInstructions())); logger.debug("Embeddings retrieved"); return generateEmbeddingResponse(embeddings); } private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) { List data = generateEmbeddingList(embeddings.getData()); - Map metadata = generateMetadata(this.model, embeddings.getUsage()); + EmbeddingResponseMetadata metadata = generateMetadata(this.model, embeddings.getUsage()); return new EmbeddingResponse(data, metadata); } - private Map generateMetadata(String model, EmbeddingsUsage embeddingsUsage) { - Map metadata = new HashMap<>(); - metadata.put("model", model); - metadata.put("prompt-tokens", embeddingsUsage.getPromptTokens()); - metadata.put("total-tokens", embeddingsUsage.getTotalTokens()); - return metadata; - } - private List generateEmbeddingList(List nativeData) { List data = new ArrayList<>(); for (EmbeddingItem nativeDatum : nativeData) { @@ -109,4 +86,12 @@ private List generateEmbeddingList(List nativeData) { return data; } + private EmbeddingResponseMetadata generateMetadata(String model, EmbeddingsUsage embeddingsUsage) { + EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); + metadata.put("model", model); + metadata.put("prompt-tokens", embeddingsUsage.getPromptTokens()); + metadata.put("total-tokens", embeddingsUsage.getTotalTokens()); + return metadata; + } + } diff --git 
a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java index bda1421308f..f23e2e6abbb 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingClientIT.java @@ -28,8 +28,8 @@ class AzureOpenAiEmbeddingClientIT { void singleEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); System.out.println(embeddingClient.dimensions()); assertThat(embeddingClient.dimensions()).isEqualTo(1536); } @@ -39,11 +39,11 @@ void batchEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); - assertThat(embeddingResponse.getData()).hasSize(2); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0); - assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + 
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); assertThat(embeddingClient.dimensions()).isEqualTo(1536); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClient.java index f1f97d00e31..8eecc855acf 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClient.java @@ -25,6 +25,7 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingClient; import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.util.Assert; @@ -70,29 +71,19 @@ public BedrockCohereEmbeddingClient withTruncate(CohereEmbeddingRequest.Truncate return this; } - @Override - public List embed(String text) { - return this.embed(List.of(text)).iterator().next(); - } - @Override public List embed(Document document) { return embed(document.getContent()); } @Override - public List> embed(List texts) { - Assert.notEmpty(texts, "At least one text is required!"); + public EmbeddingResponse call(EmbeddingRequest request) { + Assert.notEmpty(request.getInstructions(), "At least one text is required!"); - var request = new CohereEmbeddingRequest(texts, this.inputType, this.truncate); - CohereEmbeddingResponse response = this.embeddingApi.embedding(request); - return response.embeddings(); - } - - @Override - public EmbeddingResponse embedForResponse(List texts) { + var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), this.inputType, this.truncate); + CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest); var 
indexCounter = new AtomicInteger(0); - List embeddings = this.embed(texts) + List embeddings = apiResponse.embeddings() .stream() .map(e -> new Embedding(e, indexCounter.getAndIncrement())) .toList(); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java index 50220c3d868..09bae908336 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java @@ -22,14 +22,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingClient; import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; -import org.springframework.ai.embedding.EmbeddingUtil; import org.springframework.util.Assert; /** @@ -72,50 +73,39 @@ public BedrockTitanEmbeddingClient withInputType(InputType inputType) { return this; } - @Override - public List embed(String inputContent) { - return this.embed(List.of(inputContent)).iterator().next(); - } - @Override public List embed(Document document) { return embed(document.getContent()); } @Override - public EmbeddingResponse embedForResponse(List texts) { - var indexCounter = new AtomicInteger(0); - List embeddings = this.embed(texts) - .stream() - .map(e -> new Embedding(e, indexCounter.getAndIncrement())) - .toList(); - return new 
EmbeddingResponse(embeddings); - } - - @Override - public List> embed(List inputContents) { - Assert.notEmpty(inputContents, "At least one text is required!"); - if (inputContents.size() != 1) { + public EmbeddingResponse call(EmbeddingRequest request) { + Assert.notEmpty(request.getInstructions(), "At least one text is required!"); + if (request.getInstructions().size() != 1) { logger.warn( "Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); } List> embeddingList = new ArrayList<>(); - for (String inputContent : inputContents) { - var request = (this.inputType == InputType.IMAGE) + for (String inputContent : request.getInstructions()) { + var apiRequest = (this.inputType == InputType.IMAGE) ? new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build() : new TitanEmbeddingRequest.Builder().withInputText(inputContent).build(); - TitanEmbeddingResponse response = this.embeddingApi.embedding(request); + TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest); embeddingList.add(response.embedding()); } - return embeddingList; + var indexCounter = new AtomicInteger(0); + List embeddings = embeddingList.stream() + .map(e -> new Embedding(e, indexCounter.getAndIncrement())) + .toList(); + return new EmbeddingResponse(embeddings); } @Override public int dimensions() { if (this.inputType == InputType.IMAGE) { if (this.embeddingDimensions.get() < 0) { - this.embeddingDimensions.set(EmbeddingUtil.dimensions(this, embeddingApi.getModelId(), + this.embeddingDimensions.set(dimensions(this, embeddingApi.getModelId(), // small base64 encoded image "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClientIT.java index 
d3e7f3b5d40..79a57290bce 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClientIT.java @@ -30,8 +30,8 @@ class BedrockCohereEmbeddingClientIT { void singleEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(1024); } @@ -40,11 +40,11 @@ void batchEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); - assertThat(embeddingResponse.getData()).hasSize(2); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0); - assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); assertThat(embeddingClient.dimensions()).isEqualTo(1024); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java 
b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java index 478c0128951..ba2d53723bf 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClientIT.java @@ -18,7 +18,6 @@ import org.springframework.core.io.DefaultResourceLoader; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; @SpringBootTest @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @@ -32,8 +31,8 @@ class BedrockTitanEmbeddingClientIT { void singleEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(1024); } @@ -45,8 +44,8 @@ void imageEmbedding() throws IOException { EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of(Base64.getEncoder().encodeToString(image))); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(1024); } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java index 05c34aacf4d..279dc811dae 100644 --- 
a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java @@ -79,38 +79,27 @@ public OllamaEmbeddingClient withOptions(OllamaOptions options) { return this; } - @Override - public List embed(String text) { - return this.embed(List.of(text)).iterator().next(); - } - @Override public List embed(Document document) { return embed(document.getContent()); } @Override - public List> embed(List texts) { - Assert.notEmpty(texts, "At least one text is required!"); - if (texts.size() != 1) { + public EmbeddingResponse call(org.springframework.ai.embedding.EmbeddingRequest request) { + Assert.notEmpty(request.getInstructions(), "At least one text is required!"); + if (request.getInstructions().size() != 1) { logger.warn( "Ollama Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); } List> embeddingList = new ArrayList<>(); - for (String inputContent : texts) { + for (String inputContent : request.getInstructions()) { OllamaApi.EmbeddingResponse response = this.ollamaApi .embeddings(new EmbeddingRequest(this.model, inputContent, this.clientOptions)); embeddingList.add(response.embedding()); } - return embeddingList; - } - - @Override - public EmbeddingResponse embedForResponse(List texts) { var indexCounter = new AtomicInteger(0); - List embeddings = this.embed(texts) - .stream() + List embeddings = embeddingList.stream() .map(e -> new Embedding(e, indexCounter.getAndIncrement())) .toList(); return new EmbeddingResponse(embeddings); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java index 8f066d05b95..50e8621cdd2 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java +++ 
b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingClientIT.java @@ -50,8 +50,8 @@ public static void beforeAll() throws IOException, InterruptedException { void singleEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(3200); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java index 627c476902f..eff4690e58e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java @@ -17,9 +17,7 @@ package org.springframework.ai.openai; import java.time.Duration; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,10 +26,11 @@ import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingClient; import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; -import org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest; import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException; import 
org.springframework.ai.openai.api.OpenAiApi.Usage; import org.springframework.retry.support.RetryTemplate; @@ -84,53 +83,33 @@ public List embed(Document document) { } @Override - public List embed(String text) { - Assert.notNull(text, "Text must not be null"); - return this.embed(List.of(text)).iterator().next(); - } - - @Override - public List> embed(List texts) { - Assert.notNull(texts, "Texts must not be null"); - EmbeddingRequest> request = new EmbeddingRequest<>(texts, this.embeddingModelName); - return this.retryTemplate.execute(ctx -> { - EmbeddingList body = this.openAiApi.embeddings(request).getBody(); - if (body == null) { - logger.warn("No embeddings returned for request: {}", request); - return List.of(); - } - return body.data().stream().map(embedding -> embedding.embedding()).toList(); - }); - } - - @Override - public EmbeddingResponse embedForResponse(List texts) { - - Assert.notNull(texts, "Texts must not be null"); + public EmbeddingResponse call(EmbeddingRequest request) { return this.retryTemplate.execute(ctx -> { - EmbeddingRequest> request = new EmbeddingRequest<>(texts, this.embeddingModelName); + org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest> apiRequest = new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>( + request.getInstructions(), this.embeddingModelName); - EmbeddingList embeddingResponse = this.openAiApi.embeddings(request).getBody(); + EmbeddingList apiEmbeddingResponse = this.openAiApi.embeddings(apiRequest).getBody(); - if (embeddingResponse == null) { + if (apiEmbeddingResponse == null) { logger.warn("No embeddings returned for request: {}", request); - return new EmbeddingResponse(List.of(), Map.of()); + return new EmbeddingResponse(List.of()); } - Map metadata = generateMetadata(embeddingResponse.model(), embeddingResponse.usage()); + var metadata = generateMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.usage()); - List embeddings = embeddingResponse.data() + List embeddings = 
apiEmbeddingResponse.data() .stream() .map(e -> new Embedding(e.embedding(), e.index())) .toList(); return new EmbeddingResponse(embeddings, metadata); + }); } - private Map generateMetadata(String model, Usage usage) { - Map metadata = new HashMap<>(); + private EmbeddingResponseMetadata generateMetadata(String model, Usage usage) { + EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); metadata.put("model", model); metadata.put("prompt-tokens", usage.promptTokens()); metadata.put("completion-tokens", usage.completionTokens()); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java index 7d2ca92f6ea..894c0df8385 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java @@ -36,8 +36,8 @@ void simpleEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getMetadata()).containsEntry("model", "text-embedding-ada-002"); assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 2); assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 2); diff --git a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java index 192951d1b76..02cd2c877fe 100644 --- 
a/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java +++ b/models/spring-ai-postgresml/src/main/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClient.java @@ -14,7 +14,10 @@ import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingClient; import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.beans.factory.InitializingBean; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; @@ -153,13 +156,20 @@ public List> embed(List texts) { @Override public EmbeddingResponse embedForResponse(List texts) { + return this.call(new EmbeddingRequest(texts, new EmbeddingOptions())); + } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { List data = new ArrayList<>(); - List> embed = this.embed(texts); + List> embed = this.embed(request.getInstructions()); for (int i = 0; i < embed.size(); i++) { data.add(new Embedding(embed.get(i), i)); } - return new EmbeddingResponse(data, + var metadata = new EmbeddingResponseMetadata( Map.of("transformer", this.transformer, "vector-type", this.vectorType.name(), "kwargs", this.kwargs)); + + return new EmbeddingResponse(data, metadata); } @Override diff --git a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java index d89678b2704..18a26372d59 100644 --- a/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java +++ 
b/models/spring-ai-postgresml/src/test/java/org/springframework/ai/postgresml/PostgresMlEmbeddingClientIT.java @@ -108,15 +108,15 @@ void embedForResponse(String vectorType) { EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World!", "Spring AI!", "LLM!")); assertThat(embeddingResponse).isNotNull(); - assertThat(embeddingResponse.getData()).hasSize(3); + assertThat(embeddingResponse.getResults()).hasSize(3); assertThat(embeddingResponse.getMetadata()).containsExactlyEntriesOf( Map.of("transformer", "distilbert-base-uncased", "vector-type", vectorType, "kwargs", "{}")); - assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).hasSize(768); - assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1); - assertThat(embeddingResponse.getData().get(1).getEmbedding()).hasSize(768); - assertThat(embeddingResponse.getData().get(2).getIndex()).isEqualTo(2); - assertThat(embeddingResponse.getData().get(2).getEmbedding()).hasSize(768); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(768); + assertThat(embeddingResponse.getResults().get(2).getIndex()).isEqualTo(2); + assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(768); // embeddingClient.dropPgmlExtension(); } diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java index 087510f279f..3f302010004 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java +++ 
b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingClient.java @@ -5,10 +5,12 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.modality.nlp.preprocess.Tokenizer; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; @@ -25,6 +27,8 @@ import org.springframework.ai.document.MetadataMode; import org.springframework.ai.embedding.AbstractEmbeddingClient; import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.beans.factory.InitializingBean; import org.springframework.core.io.DefaultResourceLoader; @@ -205,17 +209,26 @@ public EmbeddingResponse embedForResponse(List texts) { for (int i = 0; i < embed.size(); i++) { data.add(new Embedding(embed.get(i), i)); } - return new EmbeddingResponse(data, Map.of()); + return new EmbeddingResponse(data); } @Override public List> embed(List texts) { + return this.call(new EmbeddingRequest(texts, new EmbeddingOptions())) + .getResults() + .stream() + .map(e -> e.getOutput()) + .toList(); + } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { List> resultEmbeddings = new ArrayList<>(); try { - Encoding[] encodings = this.tokenizer.batchEncode(texts); + Encoding[] encodings = this.tokenizer.batchEncode(request.getInstructions()); long[][] input_ids0 = new long[encodings.length][]; long[][] attention_mask0 = new long[encodings.length][]; @@ -265,7 +278,9 @@ public List> embed(List texts) { throw new RuntimeException(ex); } - return resultEmbeddings; + var indexCounter = new AtomicInteger(0); + 
return new EmbeddingResponse( + resultEmbeddings.stream().map(e -> new Embedding(e, indexCounter.getAndIncrement())).toList()); } private Map removeUnknownModelInputs(Map modelInputs) { diff --git a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingClientTests.java b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingClientTests.java index 2c4f6743626..e895bb8f183 100644 --- a/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingClientTests.java +++ b/models/spring-ai-transformers/src/test/java/org/springframework/ai/transformers/TransformersEmbeddingClientTests.java @@ -76,16 +76,16 @@ void embedForResponse() throws Exception { TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient(); embeddingClient.afterPropertiesSet(); EmbeddingResponse embed = embeddingClient.embedForResponse(List.of("Hello world", "World is big")); - assertThat(embed.getData()).hasSize(2); + assertThat(embed.getResults()).hasSize(2); assertThat(embed.getMetadata()).isEmpty(); - assertThat(embed.getData().get(0).getEmbedding()).hasSize(384); - assertThat(DF.format(embed.getData().get(0).getEmbedding().get(0))).isEqualTo(DF.format(-0.19744634628295898)); - assertThat(DF.format(embed.getData().get(0).getEmbedding().get(383))).isEqualTo(DF.format(0.17298996448516846)); + assertThat(embed.getResults().get(0).getOutput()).hasSize(384); + assertThat(DF.format(embed.getResults().get(0).getOutput().get(0))).isEqualTo(DF.format(-0.19744634628295898)); + assertThat(DF.format(embed.getResults().get(0).getOutput().get(383))).isEqualTo(DF.format(0.17298996448516846)); - assertThat(embed.getData().get(1).getEmbedding()).hasSize(384); - assertThat(DF.format(embed.getData().get(1).getEmbedding().get(0))).isEqualTo(DF.format(0.4293745160102844)); - 
assertThat(DF.format(embed.getData().get(1).getEmbedding().get(383))).isEqualTo(DF.format(0.05501303821802139)); + assertThat(embed.getResults().get(1).getOutput()).hasSize(384); + assertThat(DF.format(embed.getResults().get(1).getOutput().get(0))).isEqualTo(DF.format(0.4293745160102844)); + assertThat(DF.format(embed.getResults().get(1).getOutput().get(383))).isEqualTo(DF.format(0.05501303821802139)); } @Test diff --git a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiEmbeddingClient.java b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiEmbeddingClient.java index 5acd5f12db2..583cc090df7 100644 --- a/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiEmbeddingClient.java +++ b/models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiEmbeddingClient.java @@ -17,12 +17,12 @@ package org.springframework.ai.vertex; import java.util.List; -import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingClient; import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.vertex.api.VertexAiApi; @@ -37,30 +37,20 @@ public VertexAiEmbeddingClient(VertexAiApi vertexAiApi) { this.vertexAiApi = vertexAiApi; } - @Override - public List embed(String text) { - return this.vertexAiApi.embedText(text).value(); - } - @Override public List embed(Document document) { return embed(document.getContent()); } @Override - public List> embed(List texts) { - List vertexEmbeddings = this.vertexAiApi.batchEmbedText(texts); - return vertexEmbeddings.stream().map(e -> e.value()).toList(); - } - - @Override - public EmbeddingResponse embedForResponse(List texts) { - List vertexEmbeddings = this.vertexAiApi.batchEmbedText(texts); + 
public EmbeddingResponse call(EmbeddingRequest request) { + List vertexEmbeddings = this.vertexAiApi.batchEmbedText(request.getInstructions()); AtomicInteger indexCounter = new AtomicInteger(0); List embeddings = vertexEmbeddings.stream() .map(vm -> new Embedding(vm.value(), indexCounter.getAndIncrement())) .toList(); - return new EmbeddingResponse(embeddings, Map.of()); + return new EmbeddingResponse(embeddings); + } } diff --git a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java index faa70bfc28e..cf1d2aae5b0 100644 --- a/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java +++ b/models/spring-ai-vertex-ai/src/test/java/org/springframework/ai/vertex/embedding/VertexAiEmbeddingClientIT.java @@ -26,8 +26,8 @@ class VertexAiEmbeddingClientIT { void simpleEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(768); } @@ -36,11 +36,11 @@ void batchEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); - assertThat(embeddingResponse.getData()).hasSize(2); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0); - assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty(); - 
assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); assertThat(embeddingClient.dimensions()).isEqualTo(768); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingClient.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingClient.java index 63ee8902143..61960d1cfca 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingClient.java @@ -16,7 +16,13 @@ package org.springframework.ai.embedding; +import java.io.IOException; +import java.util.Map; +import java.util.Properties; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import org.springframework.core.io.DefaultResourceLoader; /** * Abstract implementation of the {@link EmbeddingClient} interface that provides @@ -28,10 +34,49 @@ public abstract class AbstractEmbeddingClient implements EmbeddingClient { protected final AtomicInteger embeddingDimensions = new AtomicInteger(-1); + private static Map KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions(); + + /** + * Return the dimension of the requested embedding generative name. If the generative + * name is unknown uses the EmbeddingClient to perform a dummy EmbeddingClient#embed + * and count the response dimensions. + * @param embeddingClient Fall-back client to determine, empirically the dimensions. + * @param modelName Embedding generative name to retrieve the dimensions for. 
+ * @param dummyContent Dummy content to use for the empirical dimension calculation. + * @return Returns the embedding dimensions for the modelName. + */ + public static int dimensions(EmbeddingClient embeddingClient, String modelName, String dummyContent) { + + if (KNOWN_EMBEDDING_DIMENSIONS.containsKey(modelName)) { + // Retrieve the dimension from a pre-configured file. + return KNOWN_EMBEDDING_DIMENSIONS.get(modelName); + } + else { + // Determine the dimensions empirically. + // Generate an embedding and count the dimension size; + return embeddingClient.embed(dummyContent).size(); + } + } + + private static Map loadKnownModelDimensions() { + try { + Properties properties = new Properties(); + properties.load(new DefaultResourceLoader() + .getResource("classpath:/embedding/embedding-model-dimensions.properties") + .getInputStream()); + return properties.entrySet() + .stream() + .collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString()))); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Override public int dimensions() { if (this.embeddingDimensions.get() < 0) { - this.embeddingDimensions.set(EmbeddingUtil.dimensions(this, "Test", "Hello World")); + this.embeddingDimensions.set(dimensions(this, "Test", "Hello World")); } return this.embeddingDimensions.get(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java index 01721d6bf3c..935f0071e85 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/Embedding.java @@ -3,15 +3,19 @@ import java.util.List; import java.util.Objects; +import org.springframework.ai.model.ModelResult; + /** * Represents a single embedding vector. 
*/ -public class Embedding { +public class Embedding implements ModelResult> { private List embedding; private Integer index; + private EmbeddingResultMetadata metadata; + /** * Creates a new {@link Embedding} instance. * @param embedding the embedding vector values. @@ -25,7 +29,8 @@ public Embedding(List embedding, Integer index) { /** * @return Get the embedding vector values. */ - public List getEmbedding() { + @Override + public List getOutput() { return embedding; } @@ -36,6 +41,13 @@ public Integer getIndex() { return index; } + /** + * @return Get the metadata associated with the embedding. + */ + public EmbeddingResultMetadata getMetadata() { + return metadata; + } + @Override public boolean equals(Object o) { if (this == o) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java index fdc519ad667..92af4a716a0 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingClient.java @@ -1,20 +1,25 @@ package org.springframework.ai.embedding; import org.springframework.ai.document.Document; +import org.springframework.ai.model.ModelClient; +import org.springframework.util.Assert; import java.util.List; /** * EmbeddingClient is a generic interface for embedding clients. */ -public interface EmbeddingClient { +public interface EmbeddingClient extends ModelClient { /** * Embeds the given text into a vector. * @param text the text to embed. * @return the embedded vector. */ - List embed(String text); + default List embed(String text) { + Assert.notNull(text, "Text must not be null"); + return this.embed(List.of(text)).iterator().next(); + } /** * Embeds the given document's content into a vector. @@ -28,14 +33,24 @@ public interface EmbeddingClient { * @param texts list of texts to embed. 
* @return list of list of embedded vectors. */ - List> embed(List texts); + default List> embed(List texts) { + Assert.notNull(texts, "Texts must not be null"); + return this.call(new EmbeddingRequest(texts, new EmbeddingOptions())) + .getResults() + .stream() + .map(Embedding::getOutput) + .toList(); + } /** * Embeds a batch of texts into vectors and returns the {@link EmbeddingResponse}. * @param texts list of texts to embed. * @return the embedding response. */ - EmbeddingResponse embedForResponse(List texts); + default EmbeddingResponse embedForResponse(List texts) { + Assert.notNull(texts, "Texts must not be null"); + return this.call(new EmbeddingRequest(texts, new EmbeddingOptions())); + } /** * @return the number of dimensions of the embedded vectors. It is generative diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java new file mode 100644 index 00000000000..64d930e502b --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingOptions.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.embedding; + +import org.springframework.ai.model.ModelOptions; + +/** + * @author Christian Tzolov + */ +public class EmbeddingOptions implements ModelOptions { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java new file mode 100644 index 00000000000..17cc478bcc3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.embedding; + +import java.util.List; + +import org.springframework.ai.model.ModelRequest; + +/** + * @author Christian Tzolov + */ +public class EmbeddingRequest implements ModelRequest> { + + private final List inputs; + + private final EmbeddingOptions options; + + public EmbeddingRequest(List inputs, EmbeddingOptions options) { + this.inputs = inputs; + this.options = options; + } + + @Override + public List getInstructions() { + return this.inputs; + } + + @Override + public EmbeddingOptions getOptions() { + return this.options; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java index a9fb69c283c..ef1ed6d7382 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponse.java @@ -1,55 +1,63 @@ package org.springframework.ai.embedding; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Objects; +import org.springframework.ai.model.ModelResponse; +import org.springframework.util.Assert; + /** * Embedding response object. */ -public class EmbeddingResponse { +public class EmbeddingResponse implements ModelResponse { /** * Embedding data. */ - private List data; + private List embeddings; /** * Embedding metadata. */ - private Map metadata = new HashMap<>(); + private EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); /** * Creates a new {@link EmbeddingResponse} instance with empty metadata. - * @param data the embedding data. + * @param embeddings the embedding data. */ - public EmbeddingResponse(List data) { - this(data, new HashMap<>()); + public EmbeddingResponse(List embeddings) { + this(embeddings, new EmbeddingResponseMetadata()); } /** * Creates a new {@link EmbeddingResponse} instance. 
- * @param data the embedding data. + * @param embeddings the embedding data. * @param metadata the embedding metadata. */ - public EmbeddingResponse(List data, Map metadata) { - this.data = data; + public EmbeddingResponse(List embeddings, EmbeddingResponseMetadata metadata) { + this.embeddings = embeddings; this.metadata = metadata; } /** - * @return Get the embedding data. + * @return Get the embedding metadata. */ - public List getData() { - return data; + public EmbeddingResponseMetadata getMetadata() { + return metadata; + } + + @Override + public Embedding getResult() { + Assert.notEmpty(embeddings, "No embedding data available."); + return embeddings.get(0); } /** - * @return Get the embedding metadata. + * @return Get the embedding data. */ - public Map getMetadata() { - return metadata; + @Override + public List getResults() { + return embeddings; } @Override @@ -59,17 +67,17 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; EmbeddingResponse that = (EmbeddingResponse) o; - return Objects.equals(data, that.data) && Objects.equals(metadata, that.metadata); + return Objects.equals(embeddings, that.embeddings) && Objects.equals(metadata, that.metadata); } @Override public int hashCode() { - return Objects.hash(data, metadata); + return Objects.hash(embeddings, metadata); } @Override public String toString() { - return "EmbeddingResult{" + "data=" + data + ", metadata=" + metadata + '}'; + return "EmbeddingResult{" + "data=" + embeddings + ", metadata=" + metadata + '}'; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java new file mode 100644 index 00000000000..e5f7b903a4d --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResponseMetadata.java @@ -0,0 +1,46 @@ +/* + * Copyright 2024-2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.embedding; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.ai.model.ResponseMetadata; + +/** + * @author Christian Tzolov + */ +public class EmbeddingResponseMetadata extends HashMap implements ResponseMetadata { + + private static final long serialVersionUID = 1L; + + public EmbeddingResponseMetadata() { + } + + public EmbeddingResponseMetadata(int initialCapacity) { + super(initialCapacity); + } + + public EmbeddingResponseMetadata(int initialCapacity, float loadFactor) { + super(initialCapacity, loadFactor); + } + + public EmbeddingResponseMetadata(Map metadata) { + super(metadata); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java new file mode 100644 index 00000000000..04f2e2bb053 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingResultMetadata.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.embedding; + +import org.springframework.ai.model.ResultMetadata; + +/** + * @author Christian Tzolov + */ +public class EmbeddingResultMetadata implements ResultMetadata { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java b/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java deleted file mode 100644 index 91d621b0caf..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingUtil.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2023-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.springframework.ai.embedding; - -import java.io.IOException; -import java.util.Map; -import java.util.Properties; -import java.util.stream.Collectors; - -import org.springframework.core.io.DefaultResourceLoader; - -/** - * @author Christian Tzolov - */ -public class EmbeddingUtil { - - private static Map KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions(); - - /** - * Return the dimension of the requested embedding generative name. If the generative - * name is unknown uses the EmbeddingClient to perform a dummy EmbeddingClient#embed - * and count the response dimensions. - * @param embeddingClient Fall-back client to determine, empirically the dimensions. - * @param modelName Embedding generative name to retrieve the dimensions for. - * @param dummyContent Dummy content to use for the empirical dimension calculation. - * @return Returns the embedding dimensions for the modelName. - */ - public static int dimensions(EmbeddingClient embeddingClient, String modelName, String dummyContent) { - - if (KNOWN_EMBEDDING_DIMENSIONS.containsKey(modelName)) { - // Retrieve the dimension from a pre-configured file. - return KNOWN_EMBEDDING_DIMENSIONS.get(modelName); - } - else { - // Determine the dimensions empirically. 
- // Generate an embedding and count the dimension size; - return embeddingClient.embed(dummyContent).size(); - } - } - - private static Map loadKnownModelDimensions() { - try { - Properties properties = new Properties(); - properties.load(new DefaultResourceLoader() - .getResource("classpath:/embedding/embedding-model-dimensions.properties") - .getInputStream()); - return properties.entrySet() - .stream() - .collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString()))); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - -} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java b/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingClientTests.java similarity index 86% rename from spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java rename to spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingClientTests.java index c8f88d0a1e6..202775b1cbd 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/embedding/EmbeddingUtilTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/embedding/AbstractEmbeddingClientTests.java @@ -18,7 +18,6 @@ import java.util.List; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; @@ -39,7 +38,7 @@ * @author Christian Tzolov */ @ExtendWith(MockitoExtension.class) -public class EmbeddingUtilTests { +public class AbstractEmbeddingClientTests { @Mock private EmbeddingClient embeddingClient; @@ -68,6 +67,11 @@ public List> embed(List texts) { public EmbeddingResponse embedForResponse(List texts) { throw new UnsupportedOperationException("Unimplemented method 'embedForResponse'"); } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + throw new UnsupportedOperationException("Unimplemented 
method 'call'"); + } }; assertThat(dummy.dimensions()).isEqualTo(3); @@ -76,7 +80,7 @@ public EmbeddingResponse embedForResponse(List texts) { @ParameterizedTest @CsvFileSource(resources = "/embedding/embedding-model-dimensions.properties", numLinesToSkip = 1, delimiter = '=') public void testKnownEmbeddingModelDimensions(String model, String dimension) { - assertThat(EmbeddingUtil.dimensions(embeddingClient, model, "Hello world!")) + assertThat(AbstractEmbeddingClient.dimensions(embeddingClient, model, "Hello world!")) .isEqualTo(Integer.valueOf(dimension)); verify(embeddingClient, never()).embed(any(String.class)); verify(embeddingClient, never()).embed(any(Document.class)); @@ -85,7 +89,7 @@ public void testKnownEmbeddingModelDimensions(String model, String dimension) { @Test public void testUnknownModelDimension() { when(embeddingClient.embed(eq("Hello world!"))).thenReturn(List.of(0.1, 0.1, 0.1)); - assertThat(EmbeddingUtil.dimensions(embeddingClient, "unknown_model", "Hello world!")).isEqualTo(3); + assertThat(AbstractEmbeddingClient.dimensions(embeddingClient, "unknown_model", "Hello world!")).isEqualTo(3); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 4184358ab7d..7fa78ab94b4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -113,11 +113,11 @@ void embedding() { EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); - assertThat(embeddingResponse.getData()).hasSize(2); - 
assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0); - assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); assertThat(embeddingClient.dimensions()).isEqualTo(1536); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java index 7fe72743a83..4aba1f8c84f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereEmbeddingAutoConfigurationIT.java @@ -57,8 +57,8 @@ public void singleEmbedding() { BedrockCohereEmbeddingClient embeddingClient = context.getBean(BedrockCohereEmbeddingClient.class); assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); 
assertThat(embeddingClient.dimensions()).isEqualTo(1024); }); } @@ -72,11 +72,11 @@ public void batchEmbedding() { assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); - assertThat(embeddingResponse.getData()).hasSize(2); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0); - assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); assertThat(embeddingClient.dimensions()).isEqualTo(1024); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java index 59169990532..ff3da4740f7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanEmbeddingAutoConfigurationIT.java @@ -56,8 +56,8 @@ public void singleTextEmbedding() { BedrockTitanEmbeddingClient embeddingClient = context.getBean(BedrockTitanEmbeddingClient.class); assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = 
embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(1024); }); } @@ -75,8 +75,8 @@ public void singleImageEmbedding() { EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of(base64Image)); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(1024); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java index 6cae39e9700..9f7d825c9c1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaEmbeddingAutoConfigurationIT.java @@ -70,8 +70,8 @@ public void singleTextEmbedding() { OllamaEmbeddingClient embeddingClient = context.getBean(OllamaEmbeddingClient.class); assertThat(embeddingClient).isNotNull(); EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World")); - assertThat(embeddingResponse.getData()).hasSize(1); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getResults()).hasSize(1); + 
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); assertThat(embeddingClient.dimensions()).isEqualTo(3200); }); } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 748bc822a5d..ff96466ef79 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -76,11 +76,11 @@ void embedding() { EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); - assertThat(embeddingResponse.getData()).hasSize(2); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0); - assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); assertThat(embeddingClient.dimensions()).isEqualTo(1536); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java index 429e4d385fc..dd0cbaa98c5 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/VertexAiAutoConfigurationIT.java @@ -62,11 +62,11 @@ void embedding() { EmbeddingResponse embeddingResponse = embeddingClient .embedForResponse(List.of("Hello World", "World is big and salvation is near")); - assertThat(embeddingResponse.getData()).hasSize(2); - assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0); - assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty(); - assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1); + assertThat(embeddingResponse.getResults()).hasSize(2); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0); + assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty(); + assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1); assertThat(embeddingClient.dimensions()).isEqualTo(768); });