1616
1717package org .springframework .ai .zhipuai ;
1818
19- import java .util .ArrayList ;
2019import java .util .List ;
21- import java .util .concurrent .atomic .AtomicInteger ;
2220
2321import io .micrometer .observation .ObservationRegistry ;
2422import org .slf4j .Logger ;
2523import org .slf4j .LoggerFactory ;
2624
2725import org .springframework .ai .chat .metadata .DefaultUsage ;
26+ import org .springframework .ai .chat .metadata .EmptyUsage ;
27+ import org .springframework .ai .chat .metadata .Usage ;
2828import org .springframework .ai .document .Document ;
2929import org .springframework .ai .document .MetadataMode ;
3030import org .springframework .ai .embedding .AbstractEmbeddingModel ;
4343import org .springframework .ai .zhipuai .api .ZhiPuApiConstants ;
4444import org .springframework .retry .support .RetryTemplate ;
4545import org .springframework .util .Assert ;
46+ import org .springframework .util .CollectionUtils ;
4647import org .springframework .util .StringUtils ;
4748
4849/**
4950 * ZhiPuAI Embedding Model implementation.
5051 *
5152 * @author Geng Rong
5253 * @author Soby Chacko
54+ * @author YuJie Wan
5355 * @since 1.0.0
5456 */
5557public class ZhiPuAiEmbeddingModel extends AbstractEmbeddingModel {
@@ -150,12 +152,9 @@ public float[] embed(Document document) {
150152 @ Override
151153 public EmbeddingResponse call (EmbeddingRequest request ) {
152154 Assert .notEmpty (request .getInstructions (), "At least one text is required!" );
153- if (request .getInstructions ().size () != 1 ) {
154- logger .warn (
155- "ZhiPu Embedding does not support batch embedding. Will make multiple API calls to embed(Document)" );
156- }
157155
158156 EmbeddingRequest embeddingRequest = buildEmbeddingRequest (request );
157+ var zhipuEmbeddingRequest = zhipuEmbeddingRequest (embeddingRequest );
159158
160159 var observationContext = EmbeddingModelObservationContext .builder ()
161160 .embeddingRequest (embeddingRequest )
@@ -166,47 +165,37 @@ public EmbeddingResponse call(EmbeddingRequest request) {
166165 .observation (this .observationConvention , DEFAULT_OBSERVATION_CONVENTION , () -> observationContext ,
167166 this .observationRegistry )
168167 .observe (() -> {
169- List <float []> embeddingList = new ArrayList <>();
170-
171- var totalUsage = new ZhiPuAiApi .Usage (0 , 0 , 0 );
172-
173- for (String inputContent : request .getInstructions ()) {
174- var apiRequest = createEmbeddingRequest (inputContent , embeddingRequest .getOptions ());
175-
176- ZhiPuAiApi .EmbeddingList <ZhiPuAiApi .Embedding > response = this .retryTemplate
177- .execute (ctx -> this .zhiPuAiApi .embeddings (apiRequest ).getBody ());
178- if (response == null || response .data () == null || response .data ().isEmpty ()) {
179- logger .warn ("No embeddings returned for input: {}" , inputContent );
180- embeddingList .add (new float [0 ]);
181- }
182- else {
183- int completionTokens = totalUsage .completionTokens () + response .usage ().completionTokens ();
184- int promptTokens = totalUsage .promptTokens () + response .usage ().promptTokens ();
185- int totalTokens = totalUsage .totalTokens () + response .usage ().totalTokens ();
186- totalUsage = new ZhiPuAiApi .Usage (completionTokens , promptTokens , totalTokens );
187- embeddingList .add (response .data ().get (0 ).embedding ());
188- }
189- }
168+ var embeddingResponse = this .retryTemplate
169+ .execute (ctx -> this .zhiPuAiApi .embeddings (zhipuEmbeddingRequest ));
190170
191- String model = (request .getOptions () != null && request .getOptions ().getModel () != null )
192- ? request .getOptions ().getModel () : "unknown" ;
171+ if (embeddingResponse == null || embeddingResponse .getBody () == null
172+ || CollectionUtils .isEmpty (embeddingResponse .getBody ().data ())) {
173+ logger .warn ("No embeddings returned for request: {}" , request );
174+ return new EmbeddingResponse (List .of ());
175+ }
193176
194- var metadata = new EmbeddingResponseMetadata (model , getDefaultUsage (totalUsage ));
177+ ZhiPuAiApi .Usage usage = embeddingResponse .getBody ().usage ();
178+ Usage usageResponse = usage != null ? getDefaultUsage (usage ) : new EmptyUsage ();
195179
196- var indexCounter = new AtomicInteger ( 0 );
180+ var metadata = new EmbeddingResponseMetadata ( embeddingResponse . getBody (). model (), usageResponse );
197181
198- List <Embedding > embeddings = embeddingList .stream ()
199- .map (e -> new Embedding (e , indexCounter .getAndIncrement ()))
182+ List <Embedding > embeddings = embeddingResponse .getBody ()
183+ .data ()
184+ .stream ()
185+ .map (e -> new Embedding (e .embedding (), e .index ()))
200186 .toList ();
201187
202- EmbeddingResponse embeddingResponse = new EmbeddingResponse (embeddings , metadata );
203-
204- observationContext .setResponse (embeddingResponse );
205-
206- return embeddingResponse ;
188+ EmbeddingResponse response = new EmbeddingResponse (embeddings , metadata );
189+ observationContext .setResponse (response );
190+ return response ;
207191 });
208192 }
209193
194+ private ZhiPuAiApi .EmbeddingRequest <List <String >> zhipuEmbeddingRequest (EmbeddingRequest embeddingRequest ) {
195+ return new ZhiPuAiApi .EmbeddingRequest <>(embeddingRequest .getInstructions (),
196+ embeddingRequest .getOptions ().getModel (), embeddingRequest .getOptions ().getDimensions ());
197+ }
198+
210199 private DefaultUsage getDefaultUsage (ZhiPuAiApi .Usage usage ) {
211200 return new DefaultUsage (usage .promptTokens (), usage .completionTokens (), usage .totalTokens (), usage );
212201 }
@@ -231,10 +220,6 @@ EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
231220 return new EmbeddingRequest (embeddingRequest .getInstructions (), requestOptions );
232221 }
233222
234- private ZhiPuAiApi .EmbeddingRequest <String > createEmbeddingRequest (String text , EmbeddingOptions requestOptions ) {
235- return new ZhiPuAiApi .EmbeddingRequest <>(text , requestOptions .getModel (), requestOptions .getDimensions ());
236- }
237-
238223 public void setObservationConvention (EmbeddingModelObservationConvention observationConvention ) {
239224 this .observationConvention = observationConvention ;
240225 }
0 commit comments