diff --git a/spring-ai-openai/pom.xml b/spring-ai-openai/pom.xml index 51d68cd1417..93e85f0450f 100644 --- a/spring-ai-openai/pom.xml +++ b/spring-ai-openai/pom.xml @@ -66,6 +66,11 @@ spring-boot-starter-logging + + org.springframework.boot + spring-boot-starter-webflux + + org.springframework.experimental.ai diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingWebClient.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingWebClient.java new file mode 100644 index 00000000000..8c6fe7159c2 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingWebClient.java @@ -0,0 +1,135 @@ +package org.springframework.ai.openai.embedding; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingUtil; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.util.Assert; +import org.springframework.web.reactive.function.client.WebClient; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +public class OpenAiEmbeddingWebClient implements EmbeddingClient { + + private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingWebClient.class); + + private final WebClient webClient; + + private final ObjectMapper objectMapper; + + private final String model; + + private final AtomicInteger embeddingDimensions = new AtomicInteger(-1); + + private final MetadataMode metadataMode; + + public OpenAiEmbeddingWebClient(String openAiApiToken) { + this("https://api.openai.com/", openAiApiToken); + } + + public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken) { + this(openAiEndpoint, openAiApiToken, "text-embedding-ada-002"); + } + + public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken, String model) { + this(openAiEndpoint, openAiApiToken, model, MetadataMode.EMBED); + } + + public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken, String model, + MetadataMode metadataMode) { + Assert.notNull(openAiEndpoint, "OpenAiEndpoint must not be null"); + Assert.notNull(model, "Model must not be null"); + Assert.notNull(metadataMode, "metadataMode must not be null"); + this.webClient = WebClient.builder().baseUrl(openAiEndpoint) + .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + openAiApiToken).build(); + this.objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL) + .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) + .configure(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT, true); + this.model = model; + this.metadataMode = metadataMode; + } + + @Override + public List embed(String text) { + return embedWithTexts(List.of(text)); + } + + private List embedWithTexts(List texts) { + OpenAiEmbeddingsRequest embeddingRequest = + new OpenAiEmbeddingsRequest.Builder().input(texts).model(this.model).build(); + return createEmbeddings(embeddingRequest).data().get(0).embedding(); + } + + public OpenAiEmbeddingsResponse createEmbeddings(OpenAiEmbeddingsRequest embeddingsRequest) { + logger.trace("EmbeddingsInput: {}", embeddingsRequest.getInput()); + + OpenAiEmbeddingsResponse openAiEmbeddingsResponse = this.webClient.post().uri("/v1/embeddings") + .bodyValue(objectMapper.convertValue(embeddingsRequest, JsonNode.class)).retrieve() + .bodyToMono(OpenAiEmbeddingsResponse.class).block(); + + logger.trace("EmbeddingsData: {}", openAiEmbeddingsResponse.data()); + + return openAiEmbeddingsResponse; + } + + public List embed(Document document) { + return embedWithTexts(List.of(document.getFormattedContent(this.metadataMode))); + } + + public List> embed(List texts) { + EmbeddingResponse embeddingResponse = embedForResponse(texts); + return embeddingResponse.getData().stream().map(Embedding::getEmbedding).toList(); + } + + @Override + public EmbeddingResponse embedForResponse(List texts) { + OpenAiEmbeddingsRequest embeddingsRequest = + new OpenAiEmbeddingsRequest.Builder().input(texts).model(this.model).build(); + return generateEmbeddingResponse(createEmbeddings(embeddingsRequest)); + } + + private EmbeddingResponse generateEmbeddingResponse(OpenAiEmbeddingsResponse openAiEmbeddingsResponse) { + List data = generateEmbeddingList(openAiEmbeddingsResponse.data()); + Map metadata = + generateMetadata(openAiEmbeddingsResponse.model(), openAiEmbeddingsResponse.usage()); + return new EmbeddingResponse(data, metadata); + } + + private List generateEmbeddingList(List nativeData) { + return nativeData.stream().map(data -> new Embedding(data.embedding(), data.index())) + .collect(Collectors.toList()); + } + + private Map generateMetadata(String model, OpenAiEmbeddingsResponse.Usage usage) { + Map metadata = new HashMap<>(); + metadata.put("model", model); + metadata.put("prompt-tokens", usage.promptTokens()); + metadata.put("completion-tokens", usage.completionTokens()); + metadata.put("total-tokens", usage.totalTokens()); + return metadata; + } + + @Override + public int dimensions() { + if (this.embeddingDimensions.get() < 0) { + this.embeddingDimensions.set(EmbeddingUtil.dimensions(this, this.model)); + } + return this.embeddingDimensions.get(); + } + +} diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsRequest.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsRequest.java new file mode 100644 index 00000000000..4a803653195 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsRequest.java @@ -0,0 +1,77 @@ +package org.springframework.ai.openai.embedding; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + +import java.util.List; + +@JsonDeserialize(builder = OpenAiEmbeddingsRequest.Builder.class) +public class OpenAiEmbeddingsRequest { + + private final List input; + private final String model; + private final String encodingFormat; + private final String user; + + public static class Builder { + + @JsonProperty("input") + private List input; + + @JsonProperty("model") + private String model; + @JsonProperty("encoding_format") + private String encodingFormat; + + @JsonProperty("user") + private String user; + + public Builder input(List input) { + this.input = input; + return this; + } + + public Builder encodingFormat(String encodingFormat) { + this.encodingFormat = encodingFormat; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder user(String user) { + this.user = user; + return this; + } + + public OpenAiEmbeddingsRequest build() { + return new OpenAiEmbeddingsRequest(this); + } + } + + private OpenAiEmbeddingsRequest(Builder builder) { + this.input = builder.input; + this.encodingFormat = builder.encodingFormat; + this.model = builder.model; + this.user = builder.user; + } + + public List getInput() { + return input; + } + + public String getEncodingFormat() { + return encodingFormat; + } + + public String getModel() { + return model; + } + + public String getUser() { + return user; + } + +} \ No newline at end of file diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsResponse.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsResponse.java new file mode 100644 index 00000000000..003e8988536 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingsResponse.java @@ -0,0 +1,48 @@ +package org.springframework.ai.openai.embedding; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public record OpenAiEmbeddingsResponse( + + @JsonProperty("data") + List data, + + @JsonProperty("usage") + Usage usage, + + @JsonProperty("model") + String model, + + @JsonProperty("object") + String object +) { + public record Data( + + @JsonProperty("index") + Integer index, + + @JsonProperty("embedding") + List embedding, + + @JsonProperty("object") + String object + ) { + } + + public record Usage( + + @JsonProperty("prompt_tokens") + long promptTokens, + + @JsonProperty("completion_tokens") + long completionTokens, + + @JsonProperty("total_tokens") + long totalTokens + + ) { + } + +} \ No newline at end of file diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index aa777940fd7..949c048c175 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -1,11 +1,11 @@ package org.springframework.ai.openai; import com.theokanning.openai.service.OpenAiService; - import org.springframework.ai.client.AiClient; import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.openai.client.OpenAiClient; import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient; +import org.springframework.ai.openai.embedding.OpenAiEmbeddingWebClient; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; @@ -17,13 +17,18 @@ public class OpenAiTestConfiguration { @Bean public OpenAiService theoOpenAiService() { + String apiKey = getApiKey(); + OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60)); + return openAiService; + } + + private String getApiKey() { String apiKey = System.getenv("OPENAI_API_KEY"); if (!StringUtils.hasText(apiKey)) { throw new IllegalArgumentException( "You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY"); } - OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60)); - return openAiService; + return apiKey; } @Bean @@ -38,4 +43,9 @@ public EmbeddingClient openAiEmbeddingClient(OpenAiService theoOpenAiService) { return new OpenAiEmbeddingClient(theoOpenAiService); } + @Bean + public EmbeddingClient openAiEmbeddingWebClient() { + return new OpenAiEmbeddingWebClient(getApiKey()); + } + } diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java index 368af4a5775..36855f03eb8 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java @@ -15,6 +15,9 @@ class EmbeddingIT { @Autowired private OpenAiEmbeddingClient embeddingClient; + @Autowired + private OpenAiEmbeddingWebClient embeddingWebClient; + @Test void simpleEmbedding() { assertThat(embeddingClient).isNotNull(); @@ -31,4 +34,20 @@ void simpleEmbedding() { assertThat(embeddingClient.dimensions()).isEqualTo(1536); } + @Test + void simpleEmbeddingWebClient() { + assertThat(embeddingClient).isNotNull(); + + EmbeddingResponse embeddingResponse = embeddingWebClient.embedForResponse(List.of("Hello World")); + System.out.println(embeddingResponse); + assertThat(embeddingResponse.getData()).hasSize(1); + assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty(); + assertThat(embeddingResponse.getMetadata()).containsEntry("model", "text-embedding-ada-002-v2"); + assertThat(embeddingResponse.getMetadata()).containsEntry("completion-tokens", 0L); + assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 2L); + assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 2L); + + assertThat(embeddingClient.dimensions()).isEqualTo(1536); + } + }