Skip to content

Commit e601f14

Browse files
ThomasVitale authored and markpollack committed
Add model name and token usage to EmbeddingResponseMetadata
- Moves these from free-text key/value pairs to interface
- Enables programmatic use for evaluation and observability

Signed-off-by: Thomas Vitale <[email protected]>
1 parent 41eab27 commit e601f14

File tree

12 files changed: +269 additions, -65 deletions

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,13 @@
1515
*/
1616
package org.springframework.ai.azure.openai;
1717

18-
import java.util.ArrayList;
19-
import java.util.List;
20-
18+
import com.azure.ai.openai.OpenAIClient;
19+
import com.azure.ai.openai.models.EmbeddingItem;
20+
import com.azure.ai.openai.models.Embeddings;
21+
import com.azure.ai.openai.models.EmbeddingsOptions;
2122
import org.slf4j.Logger;
2223
import org.slf4j.LoggerFactory;
24+
import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage;
2325
import org.springframework.ai.document.Document;
2426
import org.springframework.ai.document.MetadataMode;
2527
import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -29,11 +31,8 @@
2931
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
3032
import org.springframework.util.Assert;
3133

32-
import com.azure.ai.openai.OpenAIClient;
33-
import com.azure.ai.openai.models.EmbeddingItem;
34-
import com.azure.ai.openai.models.Embeddings;
35-
import com.azure.ai.openai.models.EmbeddingsOptions;
36-
import com.azure.ai.openai.models.EmbeddingsUsage;
34+
import java.util.ArrayList;
35+
import java.util.List;
3736

3837
public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {
3938

@@ -99,7 +98,8 @@ EmbeddingsOptions toEmbeddingOptions(EmbeddingRequest embeddingRequest) {
9998

10099
private EmbeddingResponse generateEmbeddingResponse(Embeddings embeddings) {
101100
List<Embedding> data = generateEmbeddingList(embeddings.getData());
102-
EmbeddingResponseMetadata metadata = generateMetadata(embeddings.getUsage());
101+
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
102+
metadata.setUsage(AzureOpenAiEmbeddingUsage.from(embeddings.getUsage()));
103103
return new EmbeddingResponse(data, metadata);
104104
}
105105

@@ -115,14 +115,6 @@ private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
115115
return data;
116116
}
117117

118-
private EmbeddingResponseMetadata generateMetadata(EmbeddingsUsage embeddingsUsage) {
119-
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
120-
// metadata.put("model", model);
121-
metadata.put("prompt-tokens", embeddingsUsage.getPromptTokens());
122-
metadata.put("total-tokens", embeddingsUsage.getTotalTokens());
123-
return metadata;
124-
}
125-
126118
public AzureOpenAiEmbeddingOptions getDefaultOptions() {
127119
return this.defaultOptions;
128120
}
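
models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiEmbeddingUsage.java

Lines changed: 66 additions & 0 deletions
Original file line number | Diff line number | Diff line change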
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.azure.openai.metadata;
17+
18+
import com.azure.ai.openai.models.EmbeddingsUsage;
19+
import org.springframework.ai.chat.metadata.Usage;
20+
import org.springframework.util.Assert;
21+
22+
/**
23+
* {@link Usage} implementation for {@literal Microsoft Azure OpenAI Service} embedding.
24+
*
25+
* @author Thomas Vitale
26+
* @see EmbeddingsUsage
27+
*/
28+
public class AzureOpenAiEmbeddingUsage implements Usage {
29+
30+
public static AzureOpenAiEmbeddingUsage from(EmbeddingsUsage usage) {
31+
Assert.notNull(usage, "EmbeddingsUsage must not be null");
32+
return new AzureOpenAiEmbeddingUsage(usage);
33+
}
34+
35+
private final EmbeddingsUsage usage;
36+
37+
public AzureOpenAiEmbeddingUsage(EmbeddingsUsage usage) {
38+
Assert.notNull(usage, "EmbeddingsUsage must not be null");
39+
this.usage = usage;
40+
}
41+
42+
protected EmbeddingsUsage getUsage() {
43+
return this.usage;
44+
}
45+
46+
@Override
47+
public Long getPromptTokens() {
48+
return (long) getUsage().getPromptTokens();
49+
}
50+
51+
@Override
52+
public Long getGenerationTokens() {
53+
return 0L;
54+
}
55+
56+
@Override
57+
public Long getTotalTokens() {
58+
return (long) getUsage().getTotalTokens();
59+
}
60+
61+
@Override
62+
public String toString() {
63+
return getUsage().toString();
64+
}
65+
66+
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
import org.springframework.util.Assert;
2323

2424
/**
25-
* {@link Usage} implementation for {@literal Microsoft Azure OpenAI Service}.
25+
* {@link Usage} implementation for {@literal Microsoft Azure OpenAI Service} chat.
2626
*
2727
* @author John Blum
2828
* @see com.azure.ai.openai.models.CompletionsUsage

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.ai.embedding.EmbeddingResponse;
2727
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
2828
import org.springframework.ai.minimax.api.MiniMaxApi;
29+
import org.springframework.ai.minimax.metadata.MiniMaxUsage;
2930
import org.springframework.ai.model.ModelOptionsUtils;
3031
import org.springframework.ai.retry.RetryUtils;
3132
import org.springframework.retry.support.RetryTemplate;
@@ -38,6 +39,7 @@
3839
* MiniMax Embedding Model implementation.
3940
*
4041
* @author Geng Rong
42+
* @author Thomas Vitale
4143
* @since 1.0.0 M1
4244
*/
4345
public class MiniMaxEmbeddingModel extends AbstractEmbeddingModel {
@@ -130,7 +132,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
130132
return new EmbeddingResponse(List.of());
131133
}
132134

133-
var metadata = generateResponseMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.totalTokens());
135+
var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
136+
MiniMaxUsage.from(new MiniMaxApi.Usage(0, 0, apiEmbeddingResponse.totalTokens())));
134137

135138
List<Embedding> embeddings = new ArrayList<>();
136139
for (int i = 0; i < apiEmbeddingResponse.vectors().size(); i++) {
@@ -141,11 +144,4 @@ public EmbeddingResponse call(EmbeddingRequest request) {
141144
});
142145
}
143146

144-
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) {
145-
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
146-
metadata.put("model", model);
147-
metadata.put("total-tokens", totalTokens);
148-
return metadata;
149-
}
150-
151147
}
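
models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/metadata/MiniMaxUsage.java

Lines changed: 64 additions & 0 deletions
Original file line number | Diff line number | Diff line change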
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.minimax.metadata;
17+
18+
import org.springframework.ai.chat.metadata.Usage;
19+
import org.springframework.ai.minimax.api.MiniMaxApi;
20+
import org.springframework.util.Assert;
21+
22+
/**
23+
* {@link Usage} implementation for {@literal MiniMax}.
24+
*
25+
* @author Thomas Vitale
26+
*/
27+
public class MiniMaxUsage implements Usage {
28+
29+
public static MiniMaxUsage from(MiniMaxApi.Usage usage) {
30+
return new MiniMaxUsage(usage);
31+
}
32+
33+
private final MiniMaxApi.Usage usage;
34+
35+
protected MiniMaxUsage(MiniMaxApi.Usage usage) {
36+
Assert.notNull(usage, "MiniMax Usage must not be null");
37+
this.usage = usage;
38+
}
39+
40+
protected MiniMaxApi.Usage getUsage() {
41+
return this.usage;
42+
}
43+
44+
@Override
45+
public Long getPromptTokens() {
46+
return getUsage().promptTokens().longValue();
47+
}
48+
49+
@Override
50+
public Long getGenerationTokens() {
51+
return getUsage().completionTokens().longValue();
52+
}
53+
54+
@Override
55+
public Long getTotalTokens() {
56+
return getUsage().totalTokens().longValue();
57+
}
58+
59+
@Override
60+
public String toString() {
61+
return getUsage().toString();
62+
}
63+
64+
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -29,13 +29,15 @@
2929
import org.springframework.ai.embedding.EmbeddingResponse;
3030
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
3131
import org.springframework.ai.mistralai.api.MistralAiApi;
32+
import org.springframework.ai.mistralai.metadata.MistralAiUsage;
3233
import org.springframework.ai.model.ModelOptionsUtils;
3334
import org.springframework.ai.retry.RetryUtils;
3435
import org.springframework.retry.support.RetryTemplate;
3536
import org.springframework.util.Assert;
3637

3738
/**
3839
* @author Ricken Bazolo
40+
* @author Thomas Vitale
3941
* @since 0.8.1
4042
*/
4143
public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
@@ -100,7 +102,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
100102
return new EmbeddingResponse(List.of());
101103
}
102104

103-
var metadata = generateResponseMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.usage());
105+
var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
106+
MistralAiUsage.from(apiEmbeddingResponse.usage()));
104107

105108
var embeddings = apiEmbeddingResponse.data()
106109
.stream()
@@ -118,12 +121,4 @@ public List<Double> embed(Document document) {
118121
return this.embed(document.getFormattedContent(this.metadataMode));
119122
}
120123

121-
private EmbeddingResponseMetadata generateResponseMetadata(String model, MistralAiApi.Usage usage) {
122-
var metadata = new EmbeddingResponseMetadata();
123-
metadata.put("model", model);
124-
metadata.put("prompt-tokens", usage.promptTokens());
125-
metadata.put("total-tokens", usage.totalTokens());
126-
return metadata;
127-
}
128-
129124
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
import org.springframework.ai.model.ModelOptionsUtils;
3232
import org.springframework.ai.openai.api.OpenAiApi;
3333
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
34-
import org.springframework.ai.openai.api.OpenAiApi.Usage;
34+
import org.springframework.ai.openai.metadata.OpenAiUsage;
3535
import org.springframework.ai.retry.RetryUtils;
3636
import org.springframework.retry.support.RetryTemplate;
3737
import org.springframework.util.Assert;
@@ -40,6 +40,7 @@
4040
* Open AI Embedding Model implementation.
4141
*
4242
* @author Christian Tzolov
43+
* @author Thomas Vitale
4344
*/
4445
public class OpenAiEmbeddingModel extends AbstractEmbeddingModel {
4546

@@ -134,7 +135,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
134135
return new EmbeddingResponse(List.of());
135136
}
136137

137-
var metadata = generateResponseMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.usage());
138+
var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
139+
OpenAiUsage.from(apiEmbeddingResponse.usage()));
138140

139141
List<Embedding> embeddings = apiEmbeddingResponse.data()
140142
.stream()
@@ -146,13 +148,4 @@ public EmbeddingResponse call(EmbeddingRequest request) {
146148
});
147149
}
148150

149-
private EmbeddingResponseMetadata generateResponseMetadata(String model, Usage usage) {
150-
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
151-
metadata.put("model", model);
152-
metadata.put("prompt-tokens", usage.promptTokens());
153-
metadata.put("completion-tokens", usage.completionTokens());
154-
metadata.put("total-tokens", usage.totalTokens());
155-
return metadata;
156-
}
157-
158151
}

models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanEmbeddingModel.java

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.ai.model.ModelOptionsUtils;
2929
import org.springframework.ai.qianfan.api.QianFanApi;
3030
import org.springframework.ai.qianfan.api.QianFanApi.EmbeddingList;
31+
import org.springframework.ai.qianfan.metadata.QianFanUsage;
3132
import org.springframework.ai.retry.RetryUtils;
3233
import org.springframework.retry.support.RetryTemplate;
3334
import org.springframework.util.Assert;
@@ -38,6 +39,7 @@
3839
* QianFan Embedding Client implementation.
3940
*
4041
* @author Geng Rong
42+
* @author Thomas Vitale
4143
* @since 1.0
4244
*/
4345
public class QianFanEmbeddingModel extends AbstractEmbeddingModel {
@@ -135,7 +137,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
135137
+ ", message:" + apiEmbeddingResponse.errorNsg());
136138
}
137139

138-
var metadata = generateResponseMetadata(apiEmbeddingResponse.model(), apiEmbeddingResponse.usage());
140+
var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
141+
QianFanUsage.from(apiEmbeddingResponse.usage()));
139142

140143
List<Embedding> embeddings = apiEmbeddingResponse.data()
141144
.stream()
@@ -146,12 +149,4 @@ public EmbeddingResponse call(EmbeddingRequest request) {
146149
});
147150
}
148151

149-
private EmbeddingResponseMetadata generateResponseMetadata(String model, QianFanApi.Usage usage) {
150-
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
151-
metadata.put("model", model);
152-
metadata.put("prompt-tokens", usage.promptTokens());
153-
metadata.put("total-tokens", usage.totalTokens());
154-
return metadata;
155-
}
156-
157152
}

0 commit comments

Comments
 (0)