Skip to content

Commit f11fb87

Browse files
committed
optimize ZhiPu Embedding to support batch embedding.
- support batch embedding - Make test adjustments based on the official demo Signed-off-by: YuJie Wan <[email protected]>
1 parent eda3c74 commit f11fb87

File tree

3 files changed

+38
-46
lines changed

3 files changed

+38
-46
lines changed

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java

Lines changed: 27 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717
package org.springframework.ai.zhipuai;
1818

19-
import java.util.ArrayList;
2019
import java.util.List;
21-
import java.util.concurrent.atomic.AtomicInteger;
2220

2321
import io.micrometer.observation.ObservationRegistry;
2422
import org.slf4j.Logger;
2523
import org.slf4j.LoggerFactory;
2624

2725
import org.springframework.ai.chat.metadata.DefaultUsage;
26+
import org.springframework.ai.chat.metadata.EmptyUsage;
27+
import org.springframework.ai.chat.metadata.Usage;
2828
import org.springframework.ai.document.Document;
2929
import org.springframework.ai.document.MetadataMode;
3030
import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -43,13 +43,15 @@
4343
import org.springframework.ai.zhipuai.api.ZhiPuApiConstants;
4444
import org.springframework.retry.support.RetryTemplate;
4545
import org.springframework.util.Assert;
46+
import org.springframework.util.CollectionUtils;
4647
import org.springframework.util.StringUtils;
4748

4849
/**
4950
* ZhiPuAI Embedding Model implementation.
5051
*
5152
* @author Geng Rong
5253
* @author Soby Chacko
54+
* @author YuJie Wan
5355
* @since 1.0.0
5456
*/
5557
public class ZhiPuAiEmbeddingModel extends AbstractEmbeddingModel {
@@ -150,12 +152,9 @@ public float[] embed(Document document) {
150152
@Override
151153
public EmbeddingResponse call(EmbeddingRequest request) {
152154
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
153-
if (request.getInstructions().size() != 1) {
154-
logger.warn(
155-
"ZhiPu Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
156-
}
157155

158156
EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);
157+
var zhipuEmbeddingRequest = zhipuEmbeddingRequest(embeddingRequest);
159158

160159
var observationContext = EmbeddingModelObservationContext.builder()
161160
.embeddingRequest(embeddingRequest)
@@ -166,47 +165,37 @@ public EmbeddingResponse call(EmbeddingRequest request) {
166165
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
167166
this.observationRegistry)
168167
.observe(() -> {
169-
List<float[]> embeddingList = new ArrayList<>();
170-
171-
var totalUsage = new ZhiPuAiApi.Usage(0, 0, 0);
172-
173-
for (String inputContent : request.getInstructions()) {
174-
var apiRequest = createEmbeddingRequest(inputContent, embeddingRequest.getOptions());
175-
176-
ZhiPuAiApi.EmbeddingList<ZhiPuAiApi.Embedding> response = this.retryTemplate
177-
.execute(ctx -> this.zhiPuAiApi.embeddings(apiRequest).getBody());
178-
if (response == null || response.data() == null || response.data().isEmpty()) {
179-
logger.warn("No embeddings returned for input: {}", inputContent);
180-
embeddingList.add(new float[0]);
181-
}
182-
else {
183-
int completionTokens = totalUsage.completionTokens() + response.usage().completionTokens();
184-
int promptTokens = totalUsage.promptTokens() + response.usage().promptTokens();
185-
int totalTokens = totalUsage.totalTokens() + response.usage().totalTokens();
186-
totalUsage = new ZhiPuAiApi.Usage(completionTokens, promptTokens, totalTokens);
187-
embeddingList.add(response.data().get(0).embedding());
188-
}
189-
}
168+
var embeddingResponse = this.retryTemplate
169+
.execute(ctx -> this.zhiPuAiApi.embeddings(zhipuEmbeddingRequest));
190170

191-
String model = (request.getOptions() != null && request.getOptions().getModel() != null)
192-
? request.getOptions().getModel() : "unknown";
171+
if (embeddingResponse == null || embeddingResponse.getBody() == null
172+
|| CollectionUtils.isEmpty(embeddingResponse.getBody().data())) {
173+
logger.warn("No embeddings returned for request: {}", request);
174+
return new EmbeddingResponse(List.of());
175+
}
193176

194-
var metadata = new EmbeddingResponseMetadata(model, getDefaultUsage(totalUsage));
177+
ZhiPuAiApi.Usage usage = embeddingResponse.getBody().usage();
178+
Usage usageResponse = usage != null ? getDefaultUsage(usage) : new EmptyUsage();
195179

196-
var indexCounter = new AtomicInteger(0);
180+
var metadata = new EmbeddingResponseMetadata(embeddingResponse.getBody().model(), usageResponse);
197181

198-
List<Embedding> embeddings = embeddingList.stream()
199-
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
182+
List<Embedding> embeddings = embeddingResponse.getBody()
183+
.data()
184+
.stream()
185+
.map(e -> new Embedding(e.embedding(), e.index()))
200186
.toList();
201187

202-
EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, metadata);
203-
204-
observationContext.setResponse(embeddingResponse);
205-
206-
return embeddingResponse;
188+
EmbeddingResponse response = new EmbeddingResponse(embeddings, metadata);
189+
observationContext.setResponse(response);
190+
return response;
207191
});
208192
}
209193

194+
private ZhiPuAiApi.EmbeddingRequest<List<String>> zhipuEmbeddingRequest(EmbeddingRequest embeddingRequest) {
195+
return new ZhiPuAiApi.EmbeddingRequest<>(embeddingRequest.getInstructions(),
196+
embeddingRequest.getOptions().getModel(), embeddingRequest.getOptions().getDimensions());
197+
}
198+
210199
private DefaultUsage getDefaultUsage(ZhiPuAiApi.Usage usage) {
211200
return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
212201
}
@@ -231,10 +220,6 @@ EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
231220
return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
232221
}
233222

234-
private ZhiPuAiApi.EmbeddingRequest<String> createEmbeddingRequest(String text, EmbeddingOptions requestOptions) {
235-
return new ZhiPuAiApi.EmbeddingRequest<>(text, requestOptions.getModel(), requestOptions.getDimensions());
236-
}
237-
238223
public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
239224
this.observationConvention = observationConvention;
240225
}

models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ public void zhiPuAiChatStreamNonTransientError() {
159159
public void zhiPuAiEmbeddingTransientError() {
160160

161161
EmbeddingList<Embedding> expectedEmbeddings = new EmbeddingList<>("list",
162-
List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new ZhiPuAiApi.Usage(10, 10, 10));
162+
List.of(new Embedding(0, new float[] { 9.9f, 8.8f }), new Embedding(0, new float[] { 9.9f, 8.8f })),
163+
"model", new ZhiPuAiApi.Usage(10, 10, 10));
163164

164165
given(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class)))
165166
.willThrow(new TransientAiException("Transient Error 1"))
@@ -169,9 +170,11 @@ public void zhiPuAiEmbeddingTransientError() {
169170
var result = this.embeddingModel
170171
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options));
171172

173+
assertThat(result.getResults().size()).isEqualTo(2);
172174
assertThat(result).isNotNull();
175+
// choose the first result
173176
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
174-
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(0);
177+
assertThat(this.retryListener.onSuccessRetryCount).isEqualTo(2);
175178
assertThat(this.retryListener.onErrorRetryCount).isEqualTo(2);
176179
}
177180

models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/embedding/EmbeddingIT.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,20 @@ void embeddingV3WithCustomDimension() {
8585
void batchEmbedding() {
8686
assertThat(this.embeddingModel).isNotNull();
8787

88-
EmbeddingResponse embeddingResponse = this.embeddingModel.embedForResponse(List.of("Hello World", "HI"));
88+
EmbeddingResponse embeddingResponse = this.embeddingModel
89+
.embedForResponse(List.of("Hello world", "How are you?", "How is the weather today?"));
8990

90-
assertThat(embeddingResponse.getResults()).hasSize(2);
91+
assertThat(embeddingResponse.getResults()).hasSize(3);
9192

9293
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
9394
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
9495

9596
assertThat(embeddingResponse.getResults().get(1)).isNotNull();
9697
assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1024);
9798

99+
assertThat(embeddingResponse.getResults().get(2)).isNotNull();
100+
assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(1024);
101+
98102
assertThat(this.embeddingModel.dimensions()).isEqualTo(1024);
99103
}
100104

0 commit comments

Comments
 (0)