Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions spring-ai-openai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@
<artifactId>spring-boot-starter-logging</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>

<!-- test dependencies -->
<dependency>
<groupId>org.springframework.experimental.ai</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package org.springframework.ai.openai.embedding;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingUtil;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.client.WebClient;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * {@link EmbeddingClient} that obtains embeddings by POSTing to the {@code /v1/embeddings}
 * path of an OpenAI-style HTTP endpoint through a {@link WebClient}.
 *
 * <p>Every call blocks the calling thread until the HTTP response has been received.
 */
public class OpenAiEmbeddingWebClient implements EmbeddingClient {

	private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingWebClient.class);

	private final WebClient webClient;

	/** Mapper used to turn the request object into a JSON tree before it is sent. */
	private final ObjectMapper objectMapper;

	private final String model;

	/** Cached embedding size; a negative value means it has not been resolved yet. */
	private final AtomicInteger embeddingDimensions = new AtomicInteger(-1);

	/** Mode handed to {@link Document#getFormattedContent(MetadataMode)} when embedding documents. */
	private final MetadataMode metadataMode;

	/**
	 * Creates a client for {@code https://api.openai.com/} with the default model.
	 * @param openAiApiToken token sent as the {@code Bearer} credential; must contain text
	 */
	public OpenAiEmbeddingWebClient(String openAiApiToken) {
		this("https://api.openai.com/", openAiApiToken);
	}

	/**
	 * Creates a client for the given endpoint using the {@code text-embedding-ada-002} model.
	 * @param openAiEndpoint base URL of the endpoint; must not be null
	 * @param openAiApiToken token sent as the {@code Bearer} credential; must contain text
	 */
	public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken) {
		this(openAiEndpoint, openAiApiToken, "text-embedding-ada-002");
	}

	/**
	 * Creates a client for the given endpoint and model using {@link MetadataMode#EMBED}.
	 * @param openAiEndpoint base URL of the endpoint; must not be null
	 * @param openAiApiToken token sent as the {@code Bearer} credential; must contain text
	 * @param model model name placed in every request; must not be null
	 */
	public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken, String model) {
		this(openAiEndpoint, openAiApiToken, model, MetadataMode.EMBED);
	}

	/**
	 * Creates a fully configured client.
	 * @param openAiEndpoint base URL of the endpoint; must not be null
	 * @param openAiApiToken token sent as the {@code Bearer} credential; must contain text
	 * @param model model name placed in every request; must not be null
	 * @param metadataMode mode used when formatting {@link Document} content; must not be null
	 * @throws IllegalArgumentException if any argument violates the constraints above
	 */
	public OpenAiEmbeddingWebClient(String openAiEndpoint, String openAiApiToken, String model,
			MetadataMode metadataMode) {
		Assert.notNull(openAiEndpoint, "OpenAiEndpoint must not be null");
		// Fail fast: without this check a null token would be sent as the header "Bearer null".
		Assert.hasText(openAiApiToken, "OpenAiApiToken must not be empty");
		Assert.notNull(model, "Model must not be null");
		Assert.notNull(metadataMode, "metadataMode must not be null");
		this.webClient = WebClient.builder()
			.baseUrl(openAiEndpoint)
			.defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
			.defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + openAiApiToken)
			.build();
		this.objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL)
			.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
			.configure(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT, true);
		this.model = model;
		this.metadataMode = metadataMode;
	}

	/**
	 * Embeds a single text.
	 * @param text the text to embed
	 * @return the embedding vector of the first entry in the response
	 */
	@Override
	public List<Double> embed(String text) {
		return embedWithTexts(List.of(text));
	}

	/**
	 * Sends the texts in one request and returns the embedding of the first response entry.
	 * @throws IllegalStateException if the response carries no embedding data
	 */
	private List<Double> embedWithTexts(List<String> texts) {
		List<OpenAiEmbeddingsResponse.Data> data = createEmbeddings(buildRequest(texts)).data();
		Assert.state(data != null && !data.isEmpty(), "OpenAI embeddings response contained no embeddings");
		return data.get(0).embedding();
	}

	/** Builds a request for the given texts using this client's configured model. */
	private OpenAiEmbeddingsRequest buildRequest(List<String> texts) {
		return new OpenAiEmbeddingsRequest.Builder().input(texts).model(this.model).build();
	}

	/**
	 * Posts the request to {@code /v1/embeddings} and waits for the decoded response.
	 * @param embeddingsRequest the request to send; must not be null
	 * @return the decoded response, never null
	 * @throws IllegalArgumentException if {@code embeddingsRequest} is null
	 * @throws IllegalStateException if the endpoint answered without a response body
	 */
	public OpenAiEmbeddingsResponse createEmbeddings(OpenAiEmbeddingsRequest embeddingsRequest) {
		Assert.notNull(embeddingsRequest, "embeddingsRequest must not be null");
		logger.trace("EmbeddingsInput: {}", embeddingsRequest.getInput());

		OpenAiEmbeddingsResponse openAiEmbeddingsResponse = this.webClient.post()
			.uri("/v1/embeddings")
			.bodyValue(this.objectMapper.convertValue(embeddingsRequest, JsonNode.class))
			.retrieve()
			.bodyToMono(OpenAiEmbeddingsResponse.class)
			.block();

		// block() yields null when the Mono completes empty (no response body).
		if (openAiEmbeddingsResponse == null) {
			throw new IllegalStateException("No response body received from the embeddings endpoint");
		}

		logger.trace("EmbeddingsData: {}", openAiEmbeddingsResponse.data());

		return openAiEmbeddingsResponse;
	}

	/**
	 * Embeds a document's content, formatted with this client's {@link MetadataMode}.
	 * @param document the document to embed
	 * @return the embedding vector of the first entry in the response
	 */
	public List<Double> embed(Document document) {
		return embedWithTexts(List.of(document.getFormattedContent(this.metadataMode)));
	}

	/**
	 * Embeds several texts with one request.
	 * @param texts the texts to embed
	 * @return one embedding vector per entry of the response data
	 */
	public List<List<Double>> embed(List<String> texts) {
		EmbeddingResponse embeddingResponse = embedForResponse(texts);
		return embeddingResponse.getData().stream().map(Embedding::getEmbedding).toList();
	}

	/**
	 * Embeds several texts and returns the embeddings together with response metadata.
	 * @param texts the texts to embed
	 * @return the converted response
	 */
	@Override
	public EmbeddingResponse embedForResponse(List<String> texts) {
		return generateEmbeddingResponse(createEmbeddings(buildRequest(texts)));
	}

	/** Converts the wire-level response into the generic {@link EmbeddingResponse}. */
	private EmbeddingResponse generateEmbeddingResponse(OpenAiEmbeddingsResponse openAiEmbeddingsResponse) {
		List<Embedding> data = generateEmbeddingList(openAiEmbeddingsResponse.data());
		Map<String, Object> metadata = generateMetadata(openAiEmbeddingsResponse.model(),
				openAiEmbeddingsResponse.usage());
		return new EmbeddingResponse(data, metadata);
	}

	/** Maps each response entry to an {@link Embedding}; a missing data list yields an empty result. */
	private List<Embedding> generateEmbeddingList(List<OpenAiEmbeddingsResponse.Data> nativeData) {
		List<OpenAiEmbeddingsResponse.Data> entries = (nativeData != null) ? nativeData : List.of();
		return entries.stream()
			.map(entry -> new Embedding(entry.embedding(), entry.index()))
			.collect(Collectors.toList());
	}

	/**
	 * Collects the model name and, when the response included usage figures, the token counts.
	 */
	private Map<String, Object> generateMetadata(String model, OpenAiEmbeddingsResponse.Usage usage) {
		Map<String, Object> metadata = new HashMap<>();
		metadata.put("model", model);
		// The endpoint is configurable, so do not assume a usage section is always present.
		if (usage != null) {
			metadata.put("prompt-tokens", usage.promptTokens());
			metadata.put("completion-tokens", usage.completionTokens());
			metadata.put("total-tokens", usage.totalTokens());
		}
		return metadata;
	}

	/**
	 * Returns the embedding size for the configured model. The value is obtained from
	 * {@link EmbeddingUtil#dimensions} on first use and cached for later calls.
	 * @return the number of dimensions
	 */
	@Override
	public int dimensions() {
		int cached = this.embeddingDimensions.get();
		if (cached < 0) {
			cached = EmbeddingUtil.dimensions(this, this.model);
			this.embeddingDimensions.set(cached);
		}
		return cached;
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package org.springframework.ai.openai.embedding;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;

import java.util.List;

/**
 * Immutable request body for the embeddings endpoint, created through {@link Builder}.
 *
 * <p>The getters carry explicit {@link JsonProperty} names so that the JSON written from an
 * instance uses the same property names that the builder fields declare for reading; in
 * particular {@code encoding_format} is written in snake case rather than as the bean-style
 * {@code encodingFormat}.
 */
@JsonDeserialize(builder = OpenAiEmbeddingsRequest.Builder.class)
public class OpenAiEmbeddingsRequest {

	private final List<String> input;

	private final String model;

	private final String encodingFormat;

	private final String user;

	/**
	 * Mutable builder for {@link OpenAiEmbeddingsRequest}. Each setter returns the builder
	 * itself so calls can be chained.
	 */
	public static class Builder {

		@JsonProperty("input")
		private List<String> input;

		@JsonProperty("model")
		private String model;

		@JsonProperty("encoding_format")
		private String encodingFormat;

		@JsonProperty("user")
		private String user;

		/** Sets the texts to embed. */
		public Builder input(List<String> input) {
			this.input = input;
			return this;
		}

		/** Sets the value sent as {@code encoding_format}. */
		public Builder encodingFormat(String encodingFormat) {
			this.encodingFormat = encodingFormat;
			return this;
		}

		/** Sets the model name. */
		public Builder model(String model) {
			this.model = model;
			return this;
		}

		/** Sets the value sent as {@code user}. */
		public Builder user(String user) {
			this.user = user;
			return this;
		}

		/** Creates the request from the values set so far; unset values stay null. */
		public OpenAiEmbeddingsRequest build() {
			return new OpenAiEmbeddingsRequest(this);
		}

	}

	private OpenAiEmbeddingsRequest(Builder builder) {
		this.input = builder.input;
		this.encodingFormat = builder.encodingFormat;
		this.model = builder.model;
		this.user = builder.user;
	}

	/** @return the texts to embed, as given to the builder */
	@JsonProperty("input")
	public List<String> getInput() {
		return input;
	}

	/** @return the encoding format, written to JSON as {@code encoding_format} */
	@JsonProperty("encoding_format")
	public String getEncodingFormat() {
		return encodingFormat;
	}

	/** @return the model name */
	@JsonProperty("model")
	public String getModel() {
		return model;
	}

	/** @return the user value */
	@JsonProperty("user")
	public String getUser() {
		return user;
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.springframework.ai.openai.embedding;

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;

/**
 * Response body of the embeddings endpoint, bound from the JSON properties {@code data},
 * {@code usage}, {@code model} and {@code object}.
 */
public record OpenAiEmbeddingsResponse(
		@JsonProperty("data") List<Data> data,
		@JsonProperty("usage") Usage usage,
		@JsonProperty("model") String model,
		@JsonProperty("object") String object) {

	/**
	 * One entry of the {@code data} array, bound from the JSON properties {@code index},
	 * {@code embedding} and {@code object}.
	 */
	public record Data(
			@JsonProperty("index") Integer index,
			@JsonProperty("embedding") List<Double> embedding,
			@JsonProperty("object") String object) {
	}

	/**
	 * Token counts, bound from the JSON properties {@code prompt_tokens},
	 * {@code completion_tokens} and {@code total_tokens}.
	 */
	public record Usage(
			@JsonProperty("prompt_tokens") long promptTokens,
			@JsonProperty("completion_tokens") long completionTokens,
			@JsonProperty("total_tokens") long totalTokens) {
	}

}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package org.springframework.ai.openai;

import com.theokanning.openai.service.OpenAiService;

import org.springframework.ai.client.AiClient;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.openai.client.OpenAiClient;
import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient;
import org.springframework.ai.openai.embedding.OpenAiEmbeddingWebClient;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.util.StringUtils;
Expand All @@ -17,13 +17,18 @@ public class OpenAiTestConfiguration {

/**
 * Exposes an {@link OpenAiService} built from the configured API key with a 60 second timeout.
 */
@Bean
public OpenAiService theoOpenAiService() {
	return new OpenAiService(getApiKey(), Duration.ofSeconds(60));
}

/**
 * Reads the OpenAI API key from the {@code OPENAI_API_KEY} environment variable.
 * @return the API key, guaranteed to contain text
 * @throws IllegalArgumentException if the variable is unset or blank
 */
private String getApiKey() {
	String apiKey = System.getenv("OPENAI_API_KEY");
	if (!StringUtils.hasText(apiKey)) {
		throw new IllegalArgumentException(
				"You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY");
	}
	// Service construction belongs to the bean methods; this helper only returns the key.
	return apiKey;
}

@Bean
Expand All @@ -38,4 +43,9 @@ public EmbeddingClient openAiEmbeddingClient(OpenAiService theoOpenAiService) {
return new OpenAiEmbeddingClient(theoOpenAiService);
}

/**
 * Registers the WebClient-based embedding client, configured with the API key from
 * {@code getApiKey()}. The concrete type is declared as the return type so the bean can be
 * injected by its implementation class as well as by {@link EmbeddingClient}.
 */
@Bean
public OpenAiEmbeddingWebClient openAiEmbeddingWebClient() {
	return new OpenAiEmbeddingWebClient(getApiKey());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class EmbeddingIT {
@Autowired
private OpenAiEmbeddingClient embeddingClient;

@Autowired
private OpenAiEmbeddingWebClient embeddingWebClient;

@Test
void simpleEmbedding() {
assertThat(embeddingClient).isNotNull();
Expand All @@ -31,4 +34,20 @@ void simpleEmbedding() {
assertThat(embeddingClient.dimensions()).isEqualTo(1536);
}

/**
 * Verifies that the WebClient-based embedding client returns one non-empty embedding with
 * the expected model and token metadata, and reports 1536 dimensions.
 */
@Test
void simpleEmbeddingWebClient() {
	// Check the client under test here, not the OpenAiService-backed embeddingClient.
	assertThat(embeddingWebClient).isNotNull();

	EmbeddingResponse embeddingResponse = embeddingWebClient.embedForResponse(List.of("Hello World"));
	System.out.println(embeddingResponse);
	assertThat(embeddingResponse.getData()).hasSize(1);
	assertThat(embeddingResponse.getData().get(0).getEmbedding()).isNotEmpty();
	assertThat(embeddingResponse.getMetadata()).containsEntry("model", "text-embedding-ada-002-v2");
	assertThat(embeddingResponse.getMetadata()).containsEntry("completion-tokens", 0L);
	assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 2L);
	assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 2L);

	assertThat(embeddingWebClient.dimensions()).isEqualTo(1536);
}

}