|  | 
| 1 | 1 | /* | 
| 2 |  | - * Copyright 2023-2024 the original author or authors. | 
|  | 2 | + * Copyright 2023-2025 the original author or authors. | 
| 3 | 3 |  * | 
| 4 | 4 |  * Licensed under the Apache License, Version 2.0 (the "License"); | 
| 5 | 5 |  * you may not use this file except in compliance with the License. | 
|  | 
| 58 | 58 |  * | 
| 59 | 59 |  * @author Christian Tzolov | 
| 60 | 60 |  * @author Mark Pollack | 
|  | 61 | + * @author Rodrigo Malara | 
| 61 | 62 |  * @since 1.0.0 | 
| 62 | 63 |  */ | 
| 63 | 64 | public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { | 
| @@ -128,37 +129,38 @@ public EmbeddingResponse call(EmbeddingRequest request) { | 
| 128 | 129 | 			.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, | 
| 129 | 130 | 					this.observationRegistry) | 
| 130 | 131 | 			.observe(() -> { | 
| 131 |  | -				PredictionServiceClient client = createPredictionServiceClient(); | 
|  | 132 | +				try (PredictionServiceClient client = createPredictionServiceClient()) { | 
| 132 | 133 | 
 | 
| 133 |  | -				EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); | 
|  | 134 | +					EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); | 
| 134 | 135 | 
 | 
| 135 |  | -				PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, | 
| 136 |  | -						finalOptions); | 
|  | 136 | +					PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, | 
|  | 137 | +							finalOptions); | 
| 137 | 138 | 
 | 
| 138 |  | -				PredictResponse embeddingResponse = this.retryTemplate | 
| 139 |  | -					.execute(context -> getPredictResponse(client, predictRequestBuilder)); | 
|  | 139 | +					PredictResponse embeddingResponse = this.retryTemplate | 
|  | 140 | +						.execute(context -> getPredictResponse(client, predictRequestBuilder)); | 
| 140 | 141 | 
 | 
| 141 |  | -				int index = 0; | 
| 142 |  | -				int totalTokenCount = 0; | 
| 143 |  | -				List<Embedding> embeddingList = new ArrayList<>(); | 
| 144 |  | -				for (Value prediction : embeddingResponse.getPredictionsList()) { | 
| 145 |  | -					Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings"); | 
| 146 |  | -					Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics"); | 
| 147 |  | -					Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count"); | 
| 148 |  | -					totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue(); | 
|  | 142 | +					int index = 0; | 
|  | 143 | +					int totalTokenCount = 0; | 
|  | 144 | +					List<Embedding> embeddingList = new ArrayList<>(); | 
|  | 145 | +					for (Value prediction : embeddingResponse.getPredictionsList()) { | 
|  | 146 | +						Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings"); | 
|  | 147 | +						Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics"); | 
|  | 148 | +						Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count"); | 
|  | 149 | +						totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue(); | 
| 149 | 150 | 
 | 
| 150 |  | -					Value values = embeddings.getStructValue().getFieldsOrThrow("values"); | 
|  | 151 | +						Value values = embeddings.getStructValue().getFieldsOrThrow("values"); | 
| 151 | 152 | 
 | 
| 152 |  | -					float[] vectorValues = VertexAiEmbeddingUtils.toVector(values); | 
|  | 153 | +						float[] vectorValues = VertexAiEmbeddingUtils.toVector(values); | 
| 153 | 154 | 
 | 
| 154 |  | -					embeddingList.add(new Embedding(vectorValues, index++)); | 
| 155 |  | -				} | 
| 156 |  | -				EmbeddingResponse response = new EmbeddingResponse(embeddingList, | 
| 157 |  | -						generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); | 
|  | 155 | +						embeddingList.add(new Embedding(vectorValues, index++)); | 
|  | 156 | +					} | 
|  | 157 | +					EmbeddingResponse response = new EmbeddingResponse(embeddingList, | 
|  | 158 | +							generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); | 
| 158 | 159 | 
 | 
| 159 |  | -				observationContext.setResponse(response); | 
|  | 160 | +					observationContext.setResponse(response); | 
| 160 | 161 | 
 | 
| 161 |  | -				return response; | 
|  | 162 | +					return response; | 
|  | 163 | +				} | 
| 162 | 164 | 			}); | 
| 163 | 165 | 	} | 
| 164 | 166 | 
 | 
|  | 
0 commit comments