Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public Struct build() {
Struct.Builder textBuilder = Struct.newBuilder();
textBuilder.putFields("content", valueOf(this.content));
if (StringUtils.hasText(this.taskType)) {
textBuilder.putFields("taskType", valueOf(this.taskType));
textBuilder.putFields("task_type", valueOf(this.taskType));
}
if (StringUtils.hasText(this.title)) {
textBuilder.putFields("title", valueOf(this.title));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ public Builder from(VertexAiTextEmbeddingOptions fromOptions) {
if (fromOptions.getTaskType() != null) {
this.options.setTaskType(fromOptions.getTaskType());
}
if (fromOptions.getAutoTruncate() != null) {
this.options.setAutoTruncate(fromOptions.getAutoTruncate());
}
if (StringUtils.hasText(fromOptions.getTitle())) {
this.options.setTitle(fromOptions.getTitle());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,13 @@

import java.util.List;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
Expand All @@ -30,6 +37,7 @@
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class)
Expand Down Expand Up @@ -65,6 +73,116 @@ void defaultEmbedding(String modelName) {
assertThat(this.embeddingModel.dimensions()).isEqualTo(768);
}

// Fixing https://github.com/spring-projects/spring-ai/issues/2168
@Test
void testTaskTypeProperty() {
// Use text-embedding-005 model
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
.model("text-embedding-005")
.taskType(VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT)
.build();

String text = "Test text for embedding";

// Generate embedding using Spring AI with RETRIEVAL_DOCUMENT task type
EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options));

assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotNull();

// Get the embedding result
float[] springAiEmbedding = embeddingResponse.getResults().get(0).getOutput();

// Now generate the same embedding using Google SDK directly with
// RETRIEVAL_DOCUMENT
float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT");

// Also generate embedding using Google SDK with RETRIEVAL_QUERY (which is the
// default)
float[] googleSdkQueryEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_QUERY");

// Spring AI embedding should match with what gets generated by Google SDK with
// RETRIEVAL_DOCUMENT task type.
assertThat(springAiEmbedding)
.as("Spring AI embedding with RETRIEVAL_DOCUMENT should match Google SDK RETRIEVAL_DOCUMENT embedding")
.isEqualTo(googleSdkDocumentEmbedding);

// Spring AI embedding which uses RETRIEVAL_DOCUMENT task_type should not match
// with what gets generated by
// Google SDK with RETRIEVAL_QUERY task type.
assertThat(springAiEmbedding)
.as("Spring AI embedding with RETRIEVAL_DOCUMENT should NOT match Google SDK RETRIEVAL_QUERY embedding")
.isNotEqualTo(googleSdkQueryEmbedding);
}

// Fixing https://github.com/spring-projects/spring-ai/issues/2168
@Test
void testDefaultTaskTypeBehavior() {
// Test default behavior without explicitly setting task type
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
.model("text-embedding-005")
.build();

String text = "Test text for default embedding";

EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options));

assertThat(embeddingResponse.getResults()).hasSize(1);

float[] springAiDefaultEmbedding = embeddingResponse.getResults().get(0).getOutput();

// According to documentation, default should be RETRIEVAL_DOCUMENT
float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT");

assertThat(springAiDefaultEmbedding)
.as("Default Spring AI embedding should match Google SDK RETRIEVAL_DOCUMENT embedding")
.isEqualTo(googleSdkDocumentEmbedding);
}

private float[] getEmbeddingUsingGoogleSdk(String text, String taskType) {
try {
String endpoint = String.format("%s-aiplatform.googleapis.com:443",
System.getenv("VERTEX_AI_GEMINI_LOCATION"));
String project = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");

PredictionServiceSettings settings = PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();

EndpointName endpointName = EndpointName.ofProjectLocationPublisherModelName(project,
System.getenv("VERTEX_AI_GEMINI_LOCATION"), "google", "text-embedding-005");

try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
PredictRequest.Builder request = PredictRequest.newBuilder().setEndpoint(endpointName.toString());

request.addInstances(Value.newBuilder()
.setStructValue(Struct.newBuilder()
.putFields("content", Value.newBuilder().setStringValue(text).build())
.putFields("task_type", Value.newBuilder().setStringValue(taskType).build())
.build())
.build());

var prediction = client.predict(request.build()).getPredictionsList().get(0);
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
Value values = embeddings.getStructValue().getFieldsOrThrow("values");

List<Float> floatList = values.getListValue()
.getValuesList()
.stream()
.map(Value::getNumberValue)
.map(Double::floatValue)
.collect(toList());

float[] floatArray = new float[floatList.size()];
for (int i = 0; i < floatList.size(); i++) {
floatArray[i] = floatList.get(i);
}
return floatArray;
}
}
catch (Exception e) {
throw new RuntimeException("Failed to get embedding from Google SDK", e);
}
}

@SpringBootConfiguration
static class Config {

Expand Down