Skip to content

Commit d538e00

Browse files
tzolov authored and markpollack committed
Replace the Embedding format from List<Double> to float[]
- Adjust all affected classes, including the Document.
- Update docs.

Related to #405
1 parent 656fa8b commit d538e00

File tree

67 files changed

+442
-412
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

67 files changed

+442
-412
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,21 @@
2929
import org.springframework.ai.embedding.EmbeddingRequest;
3030
import org.springframework.ai.embedding.EmbeddingResponse;
3131
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
32+
import org.springframework.ai.model.EmbeddingUtils;
3233
import org.springframework.util.Assert;
34+
import org.springframework.util.CollectionUtils;
3335

3436
import java.util.ArrayList;
3537
import java.util.List;
3638

39+
/**
40+
* Azure Open AI Embedding Model implementation.
41+
*
42+
* @author Mark Pollack
43+
* @author Christian Tzolov
44+
* @author Thomas Vitale
45+
* @since 1.0.0
46+
*/
3747
public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {
3848

3949
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class);
@@ -64,13 +74,17 @@ public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode me
6474
}
6575

6676
@Override
67-
public List<Double> embed(Document document) {
77+
public float[] embed(Document document) {
6878
logger.debug("Retrieving embeddings");
6979

7080
EmbeddingResponse response = this
7181
.call(new EmbeddingRequest(List.of(document.getFormattedContent(this.metadataMode)), null));
7282
logger.debug("Embeddings retrieved");
73-
return response.getResults().stream().map(embedding -> embedding.getOutput()).flatMap(List::stream).toList();
83+
84+
if (CollectionUtils.isEmpty(response.getResults())) {
85+
return new float[0];
86+
}
87+
return response.getResults().get(0).getOutput();
7488
}
7589

7690
@Override
@@ -108,8 +122,7 @@ private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
108122
for (EmbeddingItem nativeDatum : nativeData) {
109123
List<Float> nativeDatumEmbedding = nativeDatum.getEmbedding();
110124
int nativeIndex = nativeDatum.getPromptIndex();
111-
Embedding embedding = new Embedding(nativeDatumEmbedding.stream().map(f -> f.doubleValue()).toList(),
112-
nativeIndex);
125+
Embedding embedding = new Embedding(EmbeddingUtils.toPrimitive(nativeDatumEmbedding), nativeIndex);
113126
data.add(embedding);
114127
}
115128
return data;

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingModel.java

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -66,33 +66,8 @@ public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedr
6666
this.defaultOptions = options;
6767
}
6868

69-
// /**
70-
// * Cohere Embedding API input types.
71-
// * @param inputType the input type to use.
72-
// * @return this client.
73-
// */
74-
// public BedrockCohereEmbeddingModel withInputType(CohereEmbeddingRequest.InputType
75-
// inputType) {
76-
// this.inputType = inputType;
77-
// return this;
78-
// }
79-
80-
// /**
81-
// * Specifies how the API handles inputs longer than the maximum token length. If you
82-
// specify LEFT or RIGHT, the
83-
// * model discards the input until the remaining input is exactly the maximum input
84-
// token length for the model.
85-
// * @param truncate the truncate option to use.
86-
// * @return this client.
87-
// */
88-
// public BedrockCohereEmbeddingModel withTruncate(CohereEmbeddingRequest.Truncate
89-
// truncate) {
90-
// this.truncate = truncate;
91-
// return this;
92-
// }
93-
9469
@Override
95-
public List<Double> embed(Document document) {
70+
public float[] embed(Document document) {
9671
return embed(document.getContent());
9772
}
9873

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ public enum Truncate {
183183
@JsonInclude(Include.NON_NULL)
184184
public record CohereEmbeddingResponse(
185185
@JsonProperty("id") String id,
186-
@JsonProperty("embeddings") List<List<Double>> embeddings,
186+
@JsonProperty("embeddings") List<float[]> embeddings,
187187
@JsonProperty("texts") List<String> texts,
188188
@JsonProperty("response_type") String responseType,
189189
// For future use: Currently bedrock doesn't return invocationMetrics for the cohere embedding model.

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public BedrockTitanEmbeddingModel withInputType(InputType inputType) {
7575
}
7676

7777
@Override
78-
public List<Double> embed(Document document) {
78+
public float[] embed(Document document) {
7979
return embed(document.getContent());
8080
}
8181

@@ -87,16 +87,13 @@ public EmbeddingResponse call(EmbeddingRequest request) {
8787
"Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
8888
}
8989

90-
List<List<Double>> embeddingList = new ArrayList<>();
90+
List<Embedding> embeddings = new ArrayList<>();
91+
var indexCounter = new AtomicInteger(0);
9192
for (String inputContent : request.getInstructions()) {
9293
var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
9394
TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest);
94-
embeddingList.add(response.embedding());
95+
embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));
9596
}
96-
var indexCounter = new AtomicInteger(0);
97-
List<Embedding> embeddings = embeddingList.stream()
98-
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
99-
.toList();
10097
return new EmbeddingResponse(embeddings);
10198
}
10299

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public TitanEmbeddingRequest build() {
137137
*/
138138
@JsonInclude(Include.NON_NULL)
139139
public record TitanEmbeddingResponse(
140-
@JsonProperty("embedding") List<Double> embedding,
140+
@JsonProperty("embedding") float[] embedding,
141141
@JsonProperty("inputTextTokenCount") Integer inputTextTokenCount,
142142
@JsonProperty("message") Object message) {
143143
}

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode, M
105105
}
106106

107107
@Override
108-
public List<Double> embed(Document document) {
108+
public float[] embed(Document document) {
109109
Assert.notNull(document, "Document must not be null");
110110
return this.embed(document.getFormattedContent(this.metadataMode));
111111
}
@@ -137,7 +137,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
137137

138138
List<Embedding> embeddings = new ArrayList<>();
139139
for (int i = 0; i < apiEmbeddingResponse.vectors().size(); i++) {
140-
List<Double> vector = apiEmbeddingResponse.vectors().get(i);
140+
float[] vector = apiEmbeddingResponse.vectors().get(i);
141141
embeddings.add(new Embedding(vector, i));
142142
}
143143
return new EmbeddingResponse(embeddings, metadata);

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/api/MiniMaxApi.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ public EmbeddingRequest(List<String> texts, EmbeddingType type) {
865865
*/
866866
@JsonInclude(Include.NON_NULL)
867867
public record EmbeddingList(
868-
@JsonProperty("vectors") List<List<Double>> vectors,
868+
@JsonProperty("vectors") List<float[]> vectors,
869869
@JsonProperty("model") String model,
870870
@JsonProperty("total_tokens") Integer totalTokens) {
871871
}

models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ public void miniMaxChatStreamNonTransientError() {
157157
@Test
158158
public void miniMaxEmbeddingTransientError() {
159159

160-
EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(List.of(9.9, 8.8)), "model", 10);
160+
EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(new float[] { 9.9f, 8.8f }), "model", 10);
161161

162162
when(miniMaxApi.embeddings(isA(EmbeddingRequest.class)))
163163
.thenThrow(new TransientAiException("Transient Error 1"))
@@ -168,7 +168,7 @@ public void miniMaxEmbeddingTransientError() {
168168
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
169169

170170
assertThat(result).isNotNull();
171-
assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
171+
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
172172
assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
173173
assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
174174
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
116116
}
117117

118118
@Override
119-
public List<Double> embed(Document document) {
119+
public float[] embed(Document document) {
120120
Assert.notNull(document, "Document must not be null");
121121
return this.embed(document.getFormattedContent(this.metadataMode));
122122
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ public record Usage(
196196
public record Embedding(
197197
// @formatter:off
198198
@JsonProperty("index") Integer index,
199-
@JsonProperty("embedding") List<Double> embedding,
199+
@JsonProperty("embedding") float[] embedding,
200200
@JsonProperty("object") String object) {
201201
// @formatter:on
202202

@@ -207,7 +207,7 @@ public record Embedding(
207207
* @param embedding The embedding vector, which is a list of floats. The length of
208208
* vector depends on the model.
209209
*/
210-
public Embedding(Integer index, List<Double> embedding) {
210+
public Embedding(Integer index, float[] embedding) {
211211
this(index, embedding, "embedding");
212212
}
213213
}

0 commit comments

Comments
 (0)