diff --git a/spring-ai-openai/pom.xml b/spring-ai-openai/pom.xml
index 51d68cd1417..93e85f0450f 100644
--- a/spring-ai-openai/pom.xml
+++ b/spring-ai-openai/pom.xml
@@ -66,6 +66,11 @@
spring-boot-starter-logging
+
+ org.springframework.boot
+ spring-boot-starter-webflux
+
+
org.springframework.experimental.ai
diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingWebClient.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingWebClient.java
new file mode 100644
index 00000000000..8c6fe7159c2
--- /dev/null
+++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingWebClient.java
@@ -0,0 +1,135 @@
+package org.springframework.ai.openai.embedding;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.document.MetadataMode;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingClient;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.EmbeddingUtil;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.util.Assert;
+import org.springframework.web.reactive.function.client.WebClient;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+public class OpenAiEmbeddingWebClient implements EmbeddingClient {
+
+ private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingWebClient.class);
+
+ private final WebClient webClient;
+
+ private final ObjectMapper objectMapper;
+
+ private final String model;
+
+ private final AtomicInteger embeddingDimensions = new AtomicInteger(-1);
+
+ private final MetadataMode metadataMode;
+
+ public OpenAiEmbeddingWebClient(String openAiApiToken) {
+ this("https://api.openai.com/", openAiApiToken);
+ }
+
+ public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken) {
+ this(openAiEndpoint, openAiApiToken, "text-embedding-ada-002");
+ }
+
+ public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken, String model) {
+ this(openAiEndpoint, openAiApiToken, model, MetadataMode.EMBED);
+ }
+
+ public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken, String model,
+ MetadataMode metadataMode) {
+ Assert.notNull(openAiEndpoint, "OpenAiEndpoint must not be null");
+ Assert.notNull(model, "Model must not be null");
+ Assert.notNull(metadataMode, "metadataMode must not be null");
+ this.webClient = WebClient.builder().baseUrl(openAiEndpoint)
+ .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+ .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + openAiApiToken).build();
+ this.objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL)
+ .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
+ .configure(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT, true);
+ this.model = model;
+ this.metadataMode = metadataMode;
+ }
+
+ @Override
+ public List embed(String text) {
+ return embedWithTexts(List.of(text));
+ }
+
+ private List embedWithTexts(List texts) {
+ OpenAiEmbeddingsRequest embeddingRequest =
+ new OpenAiEmbeddingsRequest.Builder().input(texts).model(this.model).build();
+ return createEmbeddings(embeddingRequest).data().get(0).embedding();
+ }
+
+ public OpenAiEmbeddingsResponse createEmbeddings(OpenAiEmbeddingsRequest embeddingsRequest) {
+ logger.trace("EmbeddingsInput: {}", embeddingsRequest.getInput());
+
+ OpenAiEmbeddingsResponse openAiEmbeddingsResponse = this.webClient.post().uri("/v1/embeddings")
+ .bodyValue(objectMapper.convertValue(embeddingsRequest, JsonNode.class)).retrieve()
+ .bodyToMono(OpenAiEmbeddingsResponse.class).block();
+
+ logger.trace("EmbeddingsData: {}", openAiEmbeddingsResponse.data());
+
+ return openAiEmbeddingsResponse;
+ }
+
+ public List embed(Document document) {
+ return embedWithTexts(List.of(document.getFormattedContent(this.metadataMode)));
+ }
+
+ public List> embed(List texts) {
+ EmbeddingResponse embeddingResponse = embedForResponse(texts);
+ return embeddingResponse.getData().stream().map(Embedding::getEmbedding).toList();
+ }
+
+ @Override
+ public EmbeddingResponse embedForResponse(List texts) {
+ OpenAiEmbeddingsRequest embeddingsRequest =
+ new OpenAiEmbeddingsRequest.Builder().input(texts).model(this.model).build();
+ return generateEmbeddingResponse(createEmbeddings(embeddingsRequest));
+ }
+
+ private EmbeddingResponse generateEmbeddingResponse(OpenAiEmbeddingsResponse openAiEmbeddingsResponse) {
+ List data = generateEmbeddingList(openAiEmbeddingsResponse.data());
+ Map metadata =
+ generateMetadata(openAiEmbeddingsResponse.model(), openAiEmbeddingsResponse.usage());
+ return new EmbeddingResponse(data, metadata);
+ }
+
+ private List generateEmbeddingList(List nativeData) {
+ return nativeData.stream().map(data -> new Embedding(data.embedding(), data.index()))
+ .collect(Collectors.toList());
+ }
+
+ private Map generateMetadata(String model, OpenAiEmbeddingsResponse.Usage usage) {
+ Map metadata = new HashMap<>();
+ metadata.put("model", model);
+ metadata.put("prompt-tokens", usage.promptTokens());
+ metadata.put("completion-tokens", usage.completionTokens());
+ metadata.put("total-tokens", usage.totalTokens());
+ return metadata;
+ }
+
+ @Override
+ public int dimensions() {
+ if (this.embeddingDimensions.get() < 0) {
+ this.embeddingDimensions.set(EmbeddingUtil.dimensions(this, this.model));
+ }
+ return this.embeddingDimensions.get();
+ }
+
+}
diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsRequest.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsRequest.java
new file mode 100644
index 00000000000..4a803653195
--- /dev/null
+++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsRequest.java
@@ -0,0 +1,77 @@
+package org.springframework.ai.openai.embedding;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
+
+import java.util.List;
+
+@JsonDeserialize(builder = OpenAiEmbeddingsRequest.Builder.class)
+public class OpenAiEmbeddingsRequest {
+
+ private final List input;
+ private final String model;
+ private final String encodingFormat;
+ private final String user;
+
+ public static class Builder {
+
+ @JsonProperty("input")
+ private List input;
+
+ @JsonProperty("model")
+ private String model;
+ @JsonProperty("encoding_format")
+ private String encodingFormat;
+
+ @JsonProperty("user")
+ private String user;
+
+ public Builder input(List input) {
+ this.input = input;
+ return this;
+ }
+
+ public Builder encodingFormat(String encodingFormat) {
+ this.encodingFormat = encodingFormat;
+ return this;
+ }
+
+ public Builder model(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public Builder user(String user) {
+ this.user = user;
+ return this;
+ }
+
+ public OpenAiEmbeddingsRequest build() {
+ return new OpenAiEmbeddingsRequest(this);
+ }
+ }
+
+ private OpenAiEmbeddingsRequest(Builder builder) {
+ this.input = builder.input;
+ this.encodingFormat = builder.encodingFormat;
+ this.model = builder.model;
+ this.user = builder.user;
+ }
+
+ public List getInput() {
+ return input;
+ }
+
+ public String getEncodingFormat() {
+ return encodingFormat;
+ }
+
+ public String getModel() {
+ return model;
+ }
+
+ public String getUser() {
+ return user;
+ }
+
+}
\ No newline at end of file
diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsResponse.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsResponse.java
new file mode 100644
index 00000000000..003e8988536
--- /dev/null
+++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsResponse.java
@@ -0,0 +1,48 @@
+package org.springframework.ai.openai.embedding;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.util.List;
+
+public record OpenAiEmbeddingsResponse(
+
+ @JsonProperty("data")
+ List data,
+
+ @JsonProperty("usage")
+ Usage usage,
+
+ @JsonProperty("model")
+ String model,
+
+ @JsonProperty("object")
+ String object
+) {
+ public record Data(
+
+ @JsonProperty("index")
+ Integer index,
+
+ @JsonProperty("embedding")
+ List embedding,
+
+ @JsonProperty("object")
+ String object
+ ) {
+ }
+
+ public record Usage(
+
+ @JsonProperty("prompt_tokens")
+ long promptTokens,
+
+ @JsonProperty("completion_tokens")
+ long completionTokens,
+
+ @JsonProperty("total_tokens")
+ long totalTokens
+
+ ) {
+ }
+
+}
\ No newline at end of file
diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java
index aa777940fd7..949c048c175 100644
--- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java
+++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java
@@ -1,11 +1,11 @@
package org.springframework.ai.openai;
import com.theokanning.openai.service.OpenAiService;
-
import org.springframework.ai.client.AiClient;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.openai.client.OpenAiClient;
import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient;
+import org.springframework.ai.openai.embedding.OpenAiEmbeddingWebClient;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.util.StringUtils;
@@ -17,13 +17,18 @@ public class OpenAiTestConfiguration {
@Bean
public OpenAiService theoOpenAiService() {
+ String apiKey = getApiKey();
+ OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60));
+ return openAiService;
+ }
+
+ private String getApiKey() {
String apiKey = System.getenv("OPENAI_API_KEY");
if (!StringUtils.hasText(apiKey)) {
throw new IllegalArgumentException(
"You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY");
}
- OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60));
- return openAiService;
+ return apiKey;
}
@Bean
@@ -38,4 +43,9 @@ public EmbeddingClient openAiEmbeddingClient(OpenAiService theoOpenAiService) {
return new OpenAiEmbeddingClient(theoOpenAiService);
}
+ @Bean
+ public EmbeddingClient openAiEmbeddingWebClient() {
+ return new OpenAiEmbeddingWebClient(getApiKey());
+ }
+
}
diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java
index 368af4a5775..36855f03eb8 100644
--- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java
+++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java
@@ -15,6 +15,9 @@ class EmbeddingIT {
@Autowired
private OpenAiEmbeddingClient embeddingClient;
+ @Autowired
+ private OpenAiEmbeddingWebClient embeddingWebClient;
+
@Test
void simpleEmbedding() {
assertThat(embeddingClient).isNotNull();
@@ -31,4 +34,20 @@ void simpleEmbedding() {
assertThat(embeddingClient.dimensions()).isEqualTo(1536);
}
+ @Test
+ void simpleEmbeddingWebClient() {
+ assertThat(embeddingClient).isNotNull();
+
+ EmbeddingResponse embeddingResponse = embeddingWebClient.embedForResponse(List.of("Hello World"));
+ System.out.println(embeddingResponse);
+ assertThat(embeddingResponse.getData()).hasSize(1);
+ assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
+ assertThat(embeddingResponse.getMetadata()).containsEntry("model", "text-embedding-ada-002-v2");
+ assertThat(embeddingResponse.getMetadata()).containsEntry("completion-tokens", 0L);
+ assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 2L);
+ assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 2L);
+
+ assertThat(embeddingClient.dimensions()).isEqualTo(1536);
+ }
+
}