Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package org.springframework.ai.azure.openai;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.EmbeddingItem;
Expand All @@ -17,7 +15,9 @@
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.util.Assert;

public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient {
Expand Down Expand Up @@ -47,14 +47,6 @@ public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, String model,
this.metadataMode = metadataMode;
}

/**
 * Embeds a single text by sending it to the Azure OpenAI client as a one-element
 * input list for the configured model.
 * @param text the text to embed
 * @return whatever {@code extractEmbeddingsList} produces from the service reply
 */
@Override
public List<Double> embed(String text) {
	logger.debug("Retrieving embeddings");
	EmbeddingsOptions options = new EmbeddingsOptions(List.of(text));
	Embeddings reply = this.azureOpenAiClient.getEmbeddings(this.model, options);
	logger.debug("Embeddings retrieved");
	return extractEmbeddingsList(reply);
}

@Override
public List<Double> embed(Document document) {
logger.debug("Retrieving embeddings");
Expand All @@ -69,35 +61,20 @@ private List<Double> extractEmbeddingsList(Embeddings embeddings) {
}

@Override
public List<List<Double>> embed(List<String> texts) {
public EmbeddingResponse call(EmbeddingRequest request) {
logger.debug("Retrieving embeddings");
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, new EmbeddingsOptions(texts));
logger.debug("Embeddings retrieved");
return embeddings.getData().stream().map(emb -> emb.getEmbedding()).toList();
}

@Override
public EmbeddingResponse embedForResponse(List<String> texts) {
logger.debug("Retrieving embeddings");
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model, new EmbeddingsOptions(texts));
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(this.model,
new EmbeddingsOptions(request.getInstructions()));
logger.debug("Embeddings retrieved");
return generateEmbeddingResponse(embeddings);
}

private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) {
List<Embedding> data = generateEmbeddingList(embeddings.getData());
Map<String, Object> metadata = generateMetadata(this.model, embeddings.getUsage());
EmbeddingResponseMetadata metadata = generateMetadata(this.model, embeddings.getUsage());
return new EmbeddingResponse(data, metadata);
}

/**
 * Collects the model name and the token counts of a usage object into a metadata map.
 * @param model value stored under the {@code "model"} key
 * @param embeddingsUsage usage object whose prompt and total token counts are recorded
 * @return a new {@link HashMap} holding the keys {@code "model"},
 * {@code "prompt-tokens"} and {@code "total-tokens"}
 */
private Map<String, Object> generateMetadata(String model, EmbeddingsUsage embeddingsUsage) {
	Object promptTokens = embeddingsUsage.getPromptTokens();
	Object totalTokens = embeddingsUsage.getTotalTokens();

	Map<String, Object> usageInfo = new HashMap<>();
	usageInfo.put("model", model);
	usageInfo.put("prompt-tokens", promptTokens);
	usageInfo.put("total-tokens", totalTokens);
	return usageInfo;
}

private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
List<Embedding> data = new ArrayList<>();
for (EmbeddingItem nativeDatum : nativeData) {
Expand All @@ -109,4 +86,12 @@ private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
return data;
}

/**
 * Builds the response metadata from the model name and the token counts of a usage
 * object.
 * @param model value stored under the {@code "model"} key
 * @param embeddingsUsage usage object whose prompt and total token counts are recorded
 * @return a new {@code EmbeddingResponseMetadata} holding the keys {@code "model"},
 * {@code "prompt-tokens"} and {@code "total-tokens"}
 */
private EmbeddingResponseMetadata generateMetadata(String model, EmbeddingsUsage embeddingsUsage) {
	Object promptTokens = embeddingsUsage.getPromptTokens();
	Object totalTokens = embeddingsUsage.getTotalTokens();

	EmbeddingResponseMetadata usageInfo = new EmbeddingResponseMetadata();
	usageInfo.put("model", model);
	usageInfo.put("prompt-tokens", promptTokens);
	usageInfo.put("total-tokens", totalTokens);
	return usageInfo;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class AzureOpenAiEmbeddingClientIT {
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getData()).hasSize(1);
assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
System.out.println(embeddingClient.dimensions());
assertThat(embeddingClient.dimensions()).isEqualTo(1536);
}
Expand All @@ -39,11 +39,11 @@ void batchEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient
.embedForResponse(List.of("Hello World", "World is big and salvation is near"));
assertThat(embeddingResponse.getData()).hasSize(2);
assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);

assertThat(embeddingClient.dimensions()).isEqualTo(1536);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -70,29 +71,19 @@ public BedrockCohereEmbeddingClient withTruncate(CohereEmbeddingRequest.Truncate
return this;
}

/**
 * Embeds a single text by wrapping it in a one-element list, delegating to the
 * list-based {@code embed} overload and returning the first vector of the result.
 * @param text the text to embed
 * @return the first embedding vector returned for the one-element list
 */
@Override
public List<Double> embed(String text) {
	List<List<Double>> vectors = this.embed(List.of(text));
	return vectors.iterator().next();
}

/**
 * Embeds a document by passing its content to the matching {@code embed} overload.
 * @param document the document whose content is embedded
 * @return the embedding vector computed for the document content
 */
@Override
public List<Double> embed(Document document) {
	var content = document.getContent();
	return embed(content);
}

@Override
public List<List<Double>> embed(List<String> texts) {
Assert.notEmpty(texts, "At least one text is required!");
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");

var request = new CohereEmbeddingRequest(texts, this.inputType, this.truncate);
CohereEmbeddingResponse response = this.embeddingApi.embedding(request);
return response.embeddings();
}

@Override
public EmbeddingResponse embedForResponse(List<String> texts) {
var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), this.inputType, this.truncate);
CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest);
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = this.embed(texts)
List<Embedding> embeddings = apiResponse.embeddings()
.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingUtil;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -72,43 +74,32 @@ public BedrockTitanEmbeddingClient withInputType(InputType inputType) {
return this;
}

/**
 * Embeds a single input by wrapping it in a one-element list, delegating to the
 * list-based {@code embed} overload and returning the first vector of the result.
 * @param inputContent the input to embed
 * @return the first embedding vector returned for the one-element list
 */
@Override
public List<Double> embed(String inputContent) {
	List<List<Double>> vectors = this.embed(List.of(inputContent));
	return vectors.iterator().next();
}

/**
 * Embeds a document by passing its content to the matching {@code embed} overload.
 * @param document the document whose content is embedded
 * @return the embedding vector computed for the document content
 */
@Override
public List<Double> embed(Document document) {
	var content = document.getContent();
	return embed(content);
}

/**
 * Embeds the given texts through the list-based {@code embed} overload and wraps
 * each resulting vector in an {@code Embedding} numbered 0, 1, 2, ... in the order
 * the vectors are returned.
 * @param texts the texts to embed
 * @return an {@code EmbeddingResponse} built from the numbered embeddings
 */
@Override
public EmbeddingResponse embedForResponse(List<String> texts) {
	var nextIndex = new AtomicInteger(0);
	List<List<Double>> vectors = this.embed(texts);
	List<Embedding> numbered = vectors.stream()
		.map(vector -> new Embedding(vector, nextIndex.getAndIncrement()))
		.toList();
	return new EmbeddingResponse(numbered);
}

@Override
public List<List<Double>> embed(List<String> inputContents) {
Assert.notEmpty(inputContents, "At least one text is required!");
if (inputContents.size() != 1) {
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
if (request.getInstructions().size() != 1) {
logger.warn(
"Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
}

List<List<Double>> embeddingList = new ArrayList<>();
for (String inputContent : inputContents) {
var request = (this.inputType == InputType.IMAGE)
for (String inputContent : request.getInstructions()) {
var apiRequest = (this.inputType == InputType.IMAGE)
? new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build()
: new TitanEmbeddingRequest.Builder().withInputText(inputContent).build();
TitanEmbeddingResponse response = this.embeddingApi.embedding(request);
TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest);
embeddingList.add(response.embedding());
}
return embeddingList;
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = embeddingList.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class BedrockCohereEmbeddingClientIT {
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getData()).hasSize(1);
assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
}

Expand All @@ -40,11 +40,11 @@ void batchEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient
.embedForResponse(List.of("Hello World", "World is big and salvation is near"));
assertThat(embeddingResponse.getData()).hasSize(2);
assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getData().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getData().get(1).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getData().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);

assertThat(embeddingClient.dimensions()).isEqualTo(1024);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.springframework.core.io.DefaultResourceLoader;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
Expand All @@ -32,8 +31,8 @@ class BedrockTitanEmbeddingClientIT {
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getData()).hasSize(1);
assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
}

Expand All @@ -45,8 +44,8 @@ void imageEmbedding() throws IOException {

EmbeddingResponse embeddingResponse = embeddingClient
.embedForResponse(List.of(Base64.getEncoder().encodeToString(image)));
assertThat(embeddingResponse.getData()).hasSize(1);
assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,38 +79,27 @@ public OllamaEmbeddingClient withOptions(OllamaOptions options) {
return this;
}

/**
 * Embeds a single text by wrapping it in a one-element list, delegating to the
 * list-based {@code embed} overload and returning the first vector of the result.
 * @param text the text to embed
 * @return the first embedding vector returned for the one-element list
 */
@Override
public List<Double> embed(String text) {
	List<List<Double>> vectors = this.embed(List.of(text));
	return vectors.iterator().next();
}

/**
 * Embeds a document by passing its content to the matching {@code embed} overload.
 * @param document the document whose content is embedded
 * @return the embedding vector computed for the document content
 */
@Override
public List<Double> embed(Document document) {
	var content = document.getContent();
	return embed(content);
}

@Override
public List<List<Double>> embed(List<String> texts) {
Assert.notEmpty(texts, "At least one text is required!");
if (texts.size() != 1) {
public EmbeddingResponse call(org.springframework.ai.embedding.EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
if (request.getInstructions().size() != 1) {
logger.warn(
"Ollama Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
}

List<List<Double>> embeddingList = new ArrayList<>();
for (String inputContent : texts) {
for (String inputContent : request.getInstructions()) {
OllamaApi.EmbeddingResponse response = this.ollamaApi
.embeddings(new EmbeddingRequest(this.model, inputContent, this.clientOptions));
embeddingList.add(response.embedding());
}
return embeddingList;
}

@Override
public EmbeddingResponse embedForResponse(List<String> texts) {
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = this.embed(texts)
.stream()
List<Embedding> embeddings = embeddingList.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ public static void beforeAll() throws IOException, InterruptedException {
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getData()).hasSize(1);
assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(3200);
}

Expand Down
Loading