diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java index 91a0cda70f2..62a2734b329 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java @@ -29,8 +29,8 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.watsonx.api.WatsonxAiApi; -import org.springframework.ai.watsonx.api.WatsonxAiRequest; -import org.springframework.ai.watsonx.api.WatsonxAiResponse; +import org.springframework.ai.watsonx.api.WatsonxAiChatRequest; +import org.springframework.ai.watsonx.api.WatsonxAiChatResponse; import org.springframework.ai.watsonx.utils.MessageToPromptConverter; import org.springframework.util.Assert; @@ -78,9 +78,9 @@ public WatsonxAiChatModel(WatsonxAiApi watsonxAiApi, WatsonxAiChatOptions defaul @Override public ChatResponse call(Prompt prompt) { - WatsonxAiRequest request = request(prompt); + WatsonxAiChatRequest request = request(prompt); - WatsonxAiResponse response = this.watsonxAiApi.generate(request).getBody(); + WatsonxAiChatResponse response = this.watsonxAiApi.generate(request).getBody(); var generator = new Generation(response.results().get(0).generatedText()); generator = generator.withGenerationMetadata( @@ -92,9 +92,9 @@ public ChatResponse call(Prompt prompt) { @Override public Flux stream(Prompt prompt) { - WatsonxAiRequest request = request(prompt); + WatsonxAiChatRequest request = request(prompt); - Flux response = this.watsonxAiApi.generateStreaming(request); + Flux response = this.watsonxAiApi.generateStreaming(request); return response.map(chunk -> { Generation generation = new Generation(chunk.results().get(0).generatedText()); @@ -106,7 +106,7 @@ public Flux 
stream(Prompt prompt) { }); } - public WatsonxAiRequest request(Prompt prompt) { + public WatsonxAiChatRequest request(Prompt prompt) { WatsonxAiChatOptions options = WatsonxAiChatOptions.builder().build(); @@ -133,7 +133,7 @@ public WatsonxAiRequest request(Prompt prompt) { .withHumanPrompt("") .toPrompt(prompt.getInstructions()); - return WatsonxAiRequest.builder(convertedPrompt).withParameters(parameters).build(); + return WatsonxAiChatRequest.builder(convertedPrompt).withParameters(parameters).build(); } @Override diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java new file mode 100644 index 00000000000..00aaa1d0985 --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModel.java @@ -0,0 +1,85 @@ +package org.springframework.ai.watsonx; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.*; +import org.springframework.ai.watsonx.api.WatsonxAiApi; +import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingRequest; +import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResponse; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * {@link EmbeddingModel} implementation for {@literal Watsonx.ai}. + * + * Watsonx.ai allows developers to run large language models and generate embeddings. It + * supports open-source models available on [Watsonx.ai + * models](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx). + * + * Please refer to the official + * Watsonx.ai website for the most up-to-date information on available models. 
+ * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ +public class WatsonxAiEmbeddingModel extends AbstractEmbeddingModel { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final WatsonxAiApi watsonxAiApi; + + /** + * Default options to be used for all embedding requests. + */ + private WatsonxAiEmbeddingOptions defaultOptions = WatsonxAiEmbeddingOptions.create() + .withModel(WatsonxAiEmbeddingOptions.DEFAULT_MODEL); + + public WatsonxAiEmbeddingModel(WatsonxAiApi watsonxAiApi) { + this.watsonxAiApi = watsonxAiApi; + } + + public WatsonxAiEmbeddingModel(WatsonxAiApi watsonxAiApi, WatsonxAiEmbeddingOptions defaultOptions) { + this.watsonxAiApi = watsonxAiApi; + this.defaultOptions = defaultOptions; + } + + @Override + public float[] embed(Document document) { + return embed(document.getContent()); + } + + @Override + public EmbeddingResponse call(EmbeddingRequest request) { + Assert.notEmpty(request.getInstructions(), "At least one text is required!"); + + WatsonxAiEmbeddingRequest embeddingRequest = watsonxAiEmbeddingRequest(request.getInstructions(), + request.getOptions()); + WatsonxAiEmbeddingResponse response = this.watsonxAiApi.embeddings(embeddingRequest).getBody(); + + AtomicInteger indexCounter = new AtomicInteger(0); + List embeddings = response.results() + .stream() + .map(e -> new Embedding(e.embedding(), indexCounter.getAndIncrement())) + .toList(); + + return new EmbeddingResponse(embeddings); + } + + WatsonxAiEmbeddingRequest watsonxAiEmbeddingRequest(List inputs, EmbeddingOptions options) { + + WatsonxAiEmbeddingOptions runtimeOptions = (options instanceof WatsonxAiEmbeddingOptions) + ? 
(WatsonxAiEmbeddingOptions) options : this.defaultOptions; + + if (!StringUtils.hasText(runtimeOptions.getModel())) { + this.logger.warn("The model cannot be null, using default model instead"); + runtimeOptions = this.defaultOptions; + } + + return WatsonxAiEmbeddingRequest.builder(inputs).withModel(runtimeOptions.getModel()).build(); + } + +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java new file mode 100644 index 00000000000..9db6b6dd517 --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java @@ -0,0 +1,56 @@ +package org.springframework.ai.watsonx; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.embedding.EmbeddingOptions; + +/** + * The configuration information for the embedding requests. + * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class WatsonxAiEmbeddingOptions implements EmbeddingOptions { + + public static final String DEFAULT_MODEL = "ibm/slate-30m-english-rtrvr"; + + /** + * The embedding model identifier + */ + @JsonProperty("model_id") + private String model; + + public WatsonxAiEmbeddingOptions withModel(String model) { + this.model = model; + return this; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + @JsonIgnore + public Integer getDimensions() { + return null; + } + + /** + * Helper factory method to create a new {@link WatsonxAiEmbeddingOptions} instance. + * @return A new {@link WatsonxAiEmbeddingOptions} instance. 
+ */ + public static WatsonxAiEmbeddingOptions create() { + return new WatsonxAiEmbeddingOptions(); + } + + public static WatsonxAiEmbeddingOptions fromOptions(WatsonxAiEmbeddingOptions fromOptions) { + return new WatsonxAiEmbeddingOptions().withModel(fromOptions.getModel()); + } + +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java index 30673ce617b..2de2f36fd06 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java @@ -52,6 +52,7 @@ public class WatsonxAiApi { private final IamAuthenticator iamAuthenticator; private final String streamEndpoint; private final String textEndpoint; + private final String embeddingEndpoint; private final String projectId; private IamToken token; @@ -60,6 +61,7 @@ public class WatsonxAiApi { * @param baseUrl api base URL. * @param streamEndpoint streaming generation. * @param textEndpoint text generation. + * @param embeddingEndpoint embedding generation * @param projectId watsonx.ai project identifier. * @param IAMToken IBM Cloud IAM token. * @param restClientBuilder rest client builder. 
@@ -68,12 +70,14 @@ public WatsonxAiApi( String baseUrl, String streamEndpoint, String textEndpoint, + String embeddingEndpoint, String projectId, String IAMToken, RestClient.Builder restClientBuilder ) { this.streamEndpoint = streamEndpoint; this.textEndpoint = textEndpoint; + this.embeddingEndpoint = embeddingEndpoint; this.projectId = projectId; this.iamAuthenticator = IamAuthenticator.fromConfiguration(Map.of("APIKEY", IAMToken)); this.token = this.iamAuthenticator.requestToken(); @@ -94,8 +98,8 @@ public WatsonxAiApi( } @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5)) - public ResponseEntity generate(WatsonxAiRequest watsonxAiRequest) { - Assert.notNull(watsonxAiRequest, WATSONX_REQUEST_CANNOT_BE_NULL); + public ResponseEntity generate(WatsonxAiChatRequest watsonxAiChatRequest) { + Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL); if(this.token.needsRefresh()) { this.token = this.iamAuthenticator.requestToken(); @@ -104,14 +108,14 @@ public ResponseEntity generate(WatsonxAiRequest watsonxAiRequ return this.restClient.post() .uri(this.textEndpoint) .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) - .body(watsonxAiRequest.withProjectId(projectId)) + .body(watsonxAiChatRequest.withProjectId(projectId)) .retrieve() - .toEntity(WatsonxAiResponse.class); + .toEntity(WatsonxAiChatResponse.class); } @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5)) - public Flux generateStreaming(WatsonxAiRequest watsonxAiRequest) { - Assert.notNull(watsonxAiRequest, WATSONX_REQUEST_CANNOT_BE_NULL); + public Flux generateStreaming(WatsonxAiChatRequest watsonxAiChatRequest) { + Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL); if(this.token.needsRefresh()) { this.token = this.iamAuthenticator.requestToken(); @@ -120,9 +124,9 @@ public Flux 
generateStreaming(WatsonxAiRequest watsonxAiReque return this.webClient.post() .uri(this.streamEndpoint) .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) - .bodyValue(watsonxAiRequest.withProjectId(this.projectId)) + .bodyValue(watsonxAiChatRequest.withProjectId(this.projectId)) .retrieve() - .bodyToFlux(WatsonxAiResponse.class) + .bodyToFlux(WatsonxAiChatResponse.class) .handle((data, sink) -> { if (logger.isTraceEnabled()) { logger.trace(data); @@ -131,4 +135,21 @@ public Flux generateStreaming(WatsonxAiRequest watsonxAiReque }); } + @Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5)) + public ResponseEntity embeddings(WatsonxAiEmbeddingRequest request) { + Assert.notNull(request, WATSONX_REQUEST_CANNOT_BE_NULL); + + if(this.token.needsRefresh()) { + this.token = this.iamAuthenticator.requestToken(); + } + + return this.restClient.post() + .uri(this.embeddingEndpoint) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken()) + .body(request.withProjectId(projectId)) + .retrieve() + .toEntity(WatsonxAiEmbeddingResponse.class); + } + + } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java similarity index 83% rename from models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java rename to models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java index 2ca88e25864..c228372cbcd 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java @@ -23,9 +23,15 @@ import org.springframework.ai.watsonx.WatsonxAiChatOptions; import 
org.springframework.util.Assert; +/** + * Java class for Watsonx.ai Chat Request object. + * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) -public class WatsonxAiRequest { +public class WatsonxAiChatRequest { @JsonProperty("input") private String input; @@ -36,19 +42,14 @@ public class WatsonxAiRequest { @JsonProperty("project_id") private String projectId = ""; - private WatsonxAiRequest(String input, Map parameters, String modelId, String projectId) { + private WatsonxAiChatRequest(String input, Map parameters, String modelId, String projectId) { this.input = input; this.parameters = parameters; this.modelId = modelId; this.projectId = projectId; } - public WatsonxAiRequest withModelId(String modelId) { - this.modelId = modelId; - return this; - } - - public WatsonxAiRequest withProjectId(String projectId) { + public WatsonxAiChatRequest withProjectId(String projectId) { this.projectId = projectId; return this; } @@ -79,8 +80,8 @@ public Builder withParameters(Map parameters) { return this; } - public WatsonxAiRequest build() { - return new WatsonxAiRequest(input, parameters, model, ""); + public WatsonxAiChatRequest build() { + return new WatsonxAiChatRequest(input, parameters, model, ""); } } diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiResponse.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java similarity index 82% rename from models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiResponse.java rename to models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java index dd776626867..36127771b35 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiResponse.java +++ 
b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResponse.java @@ -22,11 +22,17 @@ import java.util.List; import java.util.Map; +/** + * Java class for Watsonx.ai Chat Response object. + * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) -public record WatsonxAiResponse( +public record WatsonxAiChatResponse( @JsonProperty("model_id") String modelId, @JsonProperty("created_at") Date createdAt, - @JsonProperty("results") List results, + @JsonProperty("results") List results, @JsonProperty("system") Map system ) {} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiResults.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java similarity index 88% rename from models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiResults.java rename to models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java index a0d28b90060..316f8c2e478 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiResults.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatResults.java @@ -18,9 +18,15 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +/** + * Java class for Watsonx.ai Chat Results object. 
+ * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ // @formatter:off @JsonInclude(JsonInclude.Include.NON_NULL) -public record WatsonxAiResults( +public record WatsonxAiChatResults( @JsonProperty("generated_text") String generatedText, @JsonProperty("generated_token_count") Integer generatedTokenCount, @JsonProperty("input_token_count") Integer inputTokenCount, diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java new file mode 100644 index 00000000000..331dfa0a1af --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingRequest.java @@ -0,0 +1,71 @@ +package org.springframework.ai.watsonx.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; + +import java.util.List; + +/** + * Java class for Watsonx.ai Embedding Request object. 
+ * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class WatsonxAiEmbeddingRequest { + + @JsonProperty("model_id") + String model; + + @JsonProperty("inputs") + List inputs; + + @JsonProperty("project_id") + String projectId; + + public String getModel() { + return model; + } + + public List getInputs() { + return inputs; + } + + private WatsonxAiEmbeddingRequest(String model, List inputs, String projectId) { + this.model = model; + this.inputs = inputs; + this.projectId = projectId; + } + + public WatsonxAiEmbeddingRequest withProjectId(String projectId) { + this.projectId = projectId; + return this; + } + + public static Builder builder(List inputs) { + return new Builder(inputs); + } + + public static class Builder { + + private String model = WatsonxAiEmbeddingOptions.DEFAULT_MODEL; + + private final List inputs; + + public Builder(List inputs) { + this.inputs = inputs; + } + + public Builder withModel(String model) { + this.model = model; + return this; + } + + public WatsonxAiEmbeddingRequest build() { + return new WatsonxAiEmbeddingRequest(model, inputs, ""); + } + + } + +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java new file mode 100644 index 00000000000..ec1ae022605 --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResponse.java @@ -0,0 +1,19 @@ +package org.springframework.ai.watsonx.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Date; +import java.util.List; + +/** + * Java class for Watsonx.ai Embedding Response object. 
+ * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public record WatsonxAiEmbeddingResponse(@JsonProperty("model_id") String model, + @JsonProperty("created_at") Date createdAt, @JsonProperty("results") List results, + @JsonProperty("input_token_count") Integer inputTokenCount) { +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java new file mode 100644 index 00000000000..975a1195e9e --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingResults.java @@ -0,0 +1,16 @@ +package org.springframework.ai.watsonx.api; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +/** + * Java class for Watsonx.ai Embedding Results object. 
+ * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public record WatsonxAiEmbeddingResults(@JsonProperty("embedding") float[] embedding) { +} diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java index cf4f1c739d1..2d860422f32 100644 --- a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatModelTest.java @@ -32,9 +32,9 @@ import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.watsonx.api.WatsonxAiApi; -import org.springframework.ai.watsonx.api.WatsonxAiRequest; -import org.springframework.ai.watsonx.api.WatsonxAiResponse; -import org.springframework.ai.watsonx.api.WatsonxAiResults; +import org.springframework.ai.watsonx.api.WatsonxAiChatRequest; +import org.springframework.ai.watsonx.api.WatsonxAiChatResponse; +import org.springframework.ai.watsonx.api.WatsonxAiChatResults; import org.springframework.http.ResponseEntity; import static org.assertj.core.api.Assertions.assertThat; @@ -57,7 +57,7 @@ public void testCreateRequestWithNoModelId() { Prompt prompt = new Prompt("Test message", options); Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> { - WatsonxAiRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = chatModel.request(prompt); }); } @@ -71,7 +71,7 @@ public void testCreateRequestSuccessfullyWithDefaultParams() { .build(); Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); 
assertThat(request.getParameters().get("decoding_method")).isEqualTo("greedy"); @@ -105,7 +105,7 @@ public void testCreateRequestSuccessfullyWithNonDefaultParams() { Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); assertThat(request.getParameters().get("decoding_method")).isEqualTo("sample"); @@ -139,7 +139,7 @@ public void testCreateRequestSuccessfullyWithChatDisabled() { Prompt prompt = new Prompt(msg, modelOptions); - WatsonxAiRequest request = chatModel.request(prompt); + WatsonxAiChatRequest request = chatModel.request(prompt); Assert.assertEquals(request.getModelId(), "meta-llama/llama-2-70b-chat"); assertThat(request.getInput()).isEqualTo(msg); @@ -164,12 +164,13 @@ public void testCallMethod() { WatsonxAiChatOptions parameters = WatsonxAiChatOptions.builder().withModel("google/flan-ul2").build(); - WatsonxAiResults fakeResults = new WatsonxAiResults("LLM response", 4, 3, "max_tokens"); + WatsonxAiChatResults fakeResults = new WatsonxAiChatResults("LLM response", 4, 3, "max_tokens"); - WatsonxAiResponse fakeResponse = new WatsonxAiResponse("google/flan-ul2", new Date(), List.of(fakeResults), + WatsonxAiChatResponse fakeResponse = new WatsonxAiChatResponse("google/flan-ul2", new Date(), + List.of(fakeResults), Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning")))); - when(mockChatApi.generate(any(WatsonxAiRequest.class))) + when(mockChatApi.generate(any(WatsonxAiChatRequest.class))) .thenReturn(ResponseEntity.of(Optional.of(fakeResponse))); Generation expectedGenerator = new Generation("LLM response") @@ -193,17 +194,17 @@ public void testStreamMethod() { WatsonxAiChatOptions parameters = WatsonxAiChatOptions.builder().withModel("google/flan-ul2").build(); - WatsonxAiResults fakeResultsFirst = new WatsonxAiResults("LLM resp", 0, 0, 
"max_tokens"); - WatsonxAiResults fakeResultsSecond = new WatsonxAiResults("onse", 4, 3, "not_finished"); + WatsonxAiChatResults fakeResultsFirst = new WatsonxAiChatResults("LLM resp", 0, 0, "max_tokens"); + WatsonxAiChatResults fakeResultsSecond = new WatsonxAiChatResults("onse", 4, 3, "not_finished"); - WatsonxAiResponse fakeResponseFirst = new WatsonxAiResponse("google/flan-ul2", new Date(), + WatsonxAiChatResponse fakeResponseFirst = new WatsonxAiChatResponse("google/flan-ul2", new Date(), List.of(fakeResultsFirst), Map.of("warnings", List.of(Map.of("message", "the message", "id", "disclaimer_warning")))); - WatsonxAiResponse fakeResponseSecond = new WatsonxAiResponse("google/flan-ul2", new Date(), + WatsonxAiChatResponse fakeResponseSecond = new WatsonxAiChatResponse("google/flan-ul2", new Date(), List.of(fakeResultsSecond), null); - Flux fakeResponse = Flux.just(fakeResponseFirst, fakeResponseSecond); - when(mockChatApi.generateStreaming(any(WatsonxAiRequest.class))).thenReturn(fakeResponse); + Flux fakeResponse = Flux.just(fakeResponseFirst, fakeResponseSecond); + when(mockChatApi.generateStreaming(any(WatsonxAiChatRequest.class))).thenReturn(fakeResponse); Generation firstGen = new Generation("LLM resp") .withGenerationMetadata(ChatGenerationMetadata.from("max_tokens", diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java new file mode 100644 index 00000000000..42e6c0cc5bf --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingModelTest.java @@ -0,0 +1,86 @@ +package org.springframework.ai.watsonx; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.watsonx.api.WatsonxAiApi; +import 
org.springframework.ai.watsonx.api.WatsonxAiEmbeddingRequest; +import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResponse; +import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResults; +import org.springframework.http.ResponseEntity; + +import java.util.Date; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class WatsonxAiEmbeddingModelTest { + + private WatsonxAiApi watsonxAiApiMock; + + private final WatsonxAiEmbeddingModel embeddingModel; + + public WatsonxAiEmbeddingModelTest() { + this.watsonxAiApiMock = mock(WatsonxAiApi.class); + this.embeddingModel = new WatsonxAiEmbeddingModel(watsonxAiApiMock); + } + + @Test + void createRequestWithOptions() { + String MODEL = "custom-model"; + List inputs = List.of("test"); + WatsonxAiEmbeddingOptions options = WatsonxAiEmbeddingOptions.create().withModel(MODEL); + + WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, options); + + assertThat(request.getModel()).isEqualTo(MODEL); + assertThat(request.getInputs().size()).isEqualTo(inputs.size()); + } + + @Test + void createRequestWithOptionsAndInvalidModel() { + String MODEL = ""; + List inputs = List.of("test"); + WatsonxAiEmbeddingOptions options = WatsonxAiEmbeddingOptions.create().withModel(MODEL); + + WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, options); + + assertThat(request.getModel()).isEqualTo(WatsonxAiEmbeddingOptions.DEFAULT_MODEL); + assertThat(request.getInputs().size()).isEqualTo(inputs.size()); + } + + @Test + void createRequestWithNoOptions() { + List inputs = List.of("test"); + WatsonxAiEmbeddingRequest request = embeddingModel.watsonxAiEmbeddingRequest(inputs, EmbeddingOptions.EMPTY); + + assertThat(request.getModel()).isEqualTo(WatsonxAiEmbeddingOptions.DEFAULT_MODEL); + 
assertThat(request.getInputs().size()).isEqualTo(inputs.size()); + } + + @Test + void singleEmbeddingWithOptions() { + List inputs = List.of("test"); + + String modelId = "mockId"; + Integer inputTokenCount = 2; + float[] vector = new float[] { 1.0f, 2.0f }; + List mockResults = List.of(new WatsonxAiEmbeddingResults(vector)); + WatsonxAiEmbeddingResponse mockResponse = new WatsonxAiEmbeddingResponse(modelId, new Date(), mockResults, + inputTokenCount); + + ResponseEntity mockResponseEntity = ResponseEntity.ok(mockResponse); + when(watsonxAiApiMock.embeddings(any(WatsonxAiEmbeddingRequest.class))).thenReturn(mockResponseEntity); + + assertThat(embeddingModel).isNotNull(); + + EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World")); + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty(); + assertThat(embeddingModel.dimensions()).isEqualTo(2); + } + +} diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java new file mode 100644 index 00000000000..98d63092f7b --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiEmbeddingOptionTest.java @@ -0,0 +1,36 @@ +package org.springframework.ai.watsonx.api; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.Test; +import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; + +/** + * @author Pablo Sanchidrian Herrera + */ +public class WatsonxAiEmbeddingOptionTest { + + @Test + public void testWithModel() { + WatsonxAiEmbeddingOptions options = new WatsonxAiEmbeddingOptions(); + options.withModel("test-model"); + assertThat("test-model").isEqualTo(options.getModel()); + } + + @Test + public void testCreateFactoryMethod() { + WatsonxAiEmbeddingOptions 
options = WatsonxAiEmbeddingOptions.create(); + assertThat(options).isNotNull(); + assertThat(options.getModel()).isNull(); + } + + @Test + public void testFromOptionsFactoryMethod() { + WatsonxAiEmbeddingOptions originalOptions = new WatsonxAiEmbeddingOptions().withModel("original-model"); + WatsonxAiEmbeddingOptions newOptions = WatsonxAiEmbeddingOptions.fromOptions(originalOptions); + + assertThat(newOptions).isNotNull(); + assertThat("original-model").isEqualTo(newOptions.getModel()); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc index 4b085b72b81..e7e8a23b775 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/watsonx-ai-chat.adoc @@ -44,11 +44,11 @@ The prefix `spring.ai.watsonx.ai` is used as the property prefix that lets you c |==== | Property | Description | Default -| spring.ai.watsonx.ai.base-url | The URL to connect to | https://us-south.ml.cloud.ibm.com -| spring.ai.watsonx.ai.stream-endpoint | The streaming endpoint | generation/stream?version=2023-05-29 -| spring.ai.watsonx.ai.text-endpoint | The text endpoint | generation/text?version=2023-05-29 -| spring.ai.watsonx.ai.project-id | The project ID | - -| spring.ai.watsonx.ai.iam-token | The IBM Cloud account IAM token | - +| spring.ai.watsonx.ai.base-url | The URL to connect to | https://us-south.ml.cloud.ibm.com +| spring.ai.watsonx.ai.stream-endpoint | The streaming endpoint | ml/v1/text/generation_stream?version=2023-05-29 +| spring.ai.watsonx.ai.text-endpoint | The text endpoint | ml/v1/text/generation?version=2023-05-29 +| spring.ai.watsonx.ai.project-id | The project ID | - +| spring.ai.watsonx.ai.iam-token | The IBM Cloud account IAM token | - |==== ==== Configuration Properties diff --git 
a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/watsonx-ai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/watsonx-ai-embeddings.adoc new file mode 100644 index 00000000000..fbdc8c53ad3 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/watsonx-ai-embeddings.adoc @@ -0,0 +1,115 @@ += Watsonx.ai Embeddings + +With https://www.ibm.com/products/watsonx-ai[Watsonx.ai] you can run various Large Language Models (LLMs) and generate embeddings from them. +Spring AI supports the Watsonx.ai text embeddings with `WatsonxAiEmbeddingModel`. + +An embedding is a vector (list) of floating point numbers. +The distance between two vectors measures their relatedness. +Small distances suggest high relatedness and large distances suggest low relatedness. + +== Prerequisites + +You first need to have a SaaS instance of watsonx.ai (as well as an IBM Cloud account). + +Refer to https://eu-de.dataplatform.cloud.ibm.com/registration/stepone?context=wx&preselect_region=true[free-trial] to try watsonx.ai for free + +TIP: More info can be found https://www.ibm.com/products/watsonx-ai/info/trial[here] + +=== Add Repositories and BOM + +Spring AI artifacts are published in Spring Milestone and Snapshot repositories. Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system. + +To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system. + + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the Watsonx.ai Embedding Model. 
+To enable it add the following dependency to your Maven `pom.xml` file: + +[source,xml] +---- +<dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-watsonx-ai-spring-boot-starter</artifactId> +</dependency> +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-watsonx-ai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +The `spring.ai.watsonx.ai.embedding.options.*` properties are used to configure the default options used for all embedding requests. + +=== Embedding Properties + +The prefix `spring.ai.watsonx.ai` is used as the property prefix that lets you connect to watsonx.ai. + +[cols="4,3,3"] +|==== +| Property | Description | Default + +| spring.ai.watsonx.ai.base-url | The URL to connect to | https://us-south.ml.cloud.ibm.com +| spring.ai.watsonx.ai.embedding-endpoint | The embedding endpoint | ml/v1/text/embeddings?version=2023-05-29 +| spring.ai.watsonx.ai.project-id | The project ID | - +| spring.ai.watsonx.ai.iam-token | The IBM Cloud account IAM token | - +|==== + +The prefix `spring.ai.watsonx.ai.embedding` is the property prefix that configures the `EmbeddingModel` implementation for Watsonx.ai. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.watsonx.ai.embedding.enabled | Enable Watsonx.ai embedding model | true +| spring.ai.watsonx.ai.embedding.options.model | The embedding model to be used | ibm/slate-30m-english-rtrvr +|==== + + +== Runtime Options [[embedding-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiEmbeddingOptions.java[WatsonxAiEmbeddingOptions.java] provides the Watsonx.ai configurations, such as the model to use. + +The default options can be configured using the `spring.ai.watsonx.ai.embedding.options` properties as well.
+ + +[source,java] +---- +EmbeddingResponse embeddingResponse = embeddingModel.call( + new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), + WatsonxAiEmbeddingOptions.create() + .withModel("Different-Embedding-Model-Deployment-Name")) +); +---- + +== Sample Controller + +This will create an `EmbeddingModel` implementation that you can inject into your class. +Here is an example of a simple `@Controller` class that uses the `EmbeddingModel` implementation. + +[source,java] +---- +@RestController +public class EmbeddingController { + + private final EmbeddingModel embeddingModel; + + @Autowired + public EmbeddingController(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + } + + @GetMapping("/ai/embedding") + public ResponseEntity embedding(@RequestParam String text) { + EmbeddingResponse response = this.embeddingModel.embedForResponse(List.of(text)); + return ResponseEntity.ok(response.getResult()); + } +} +---- diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java index fd62de3cebb..8cb42373786 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfiguration.java @@ -16,6 +16,7 @@ package org.springframework.ai.autoconfigure.watsonxai; import org.springframework.ai.watsonx.WatsonxAiChatModel; +import org.springframework.ai.watsonx.WatsonxAiEmbeddingModel; import org.springframework.ai.watsonx.api.WatsonxAiApi; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -36,7 +37,8 @@ */ @AutoConfiguration(after =
RestClientAutoConfiguration.class) @ConditionalOnClass(WatsonxAiApi.class) -@EnableConfigurationProperties({ WatsonxAiConnectionProperties.class, WatsonxAiChatProperties.class }) +@EnableConfigurationProperties({ WatsonxAiConnectionProperties.class, WatsonxAiChatProperties.class, + WatsonxAiEmbeddingProperties.class }) @ConditionalOnProperty(prefix = WatsonxAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public class WatsonxAiAutoConfiguration { @@ -45,13 +47,25 @@ public class WatsonxAiAutoConfiguration { @ConditionalOnMissingBean public WatsonxAiApi watsonxApi(WatsonxAiConnectionProperties properties, RestClient.Builder restClientBuilder) { return new WatsonxAiApi(properties.getBaseUrl(), properties.getStreamEndpoint(), properties.getTextEndpoint(), - properties.getProjectId(), properties.getIAMToken(), restClientBuilder); + properties.getEmbeddingEndpoint(), properties.getProjectId(), properties.getIAMToken(), + restClientBuilder); } @Bean @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = WatsonxAiChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) public WatsonxAiChatModel watsonxChatModel(WatsonxAiApi watsonxApi, WatsonxAiChatProperties chatProperties) { return new WatsonxAiChatModel(watsonxApi, chatProperties.getOptions()); } + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = WatsonxAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", + matchIfMissing = true) + public WatsonxAiEmbeddingModel watsonxAiEmbeddingModel(WatsonxAiApi watsonxApi, + WatsonxAiEmbeddingProperties properties) { + return new WatsonxAiEmbeddingModel(watsonxApi, properties.getOptions()); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java 
b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java index 3e6357c6332..0ffc3656d18 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiConnectionProperties.java @@ -31,9 +31,11 @@ public class WatsonxAiConnectionProperties { private String baseUrl = "https://us-south.ml.cloud.ibm.com/"; - private String streamEndpoint = "generation/stream?version=2023-05-29"; + private String streamEndpoint = "ml/v1/text/generation_stream?version=2023-05-29"; - private String textEndpoint = "generation/text?version=2023-05-29"; + private String textEndpoint = "ml/v1/text/generation?version=2023-05-29"; + + private String embeddingEndpoint = "ml/v1/text/embeddings?version=2023-05-29"; private String projectId; @@ -63,6 +65,14 @@ public void setTextEndpoint(String textEndpoint) { this.textEndpoint = textEndpoint; } + public String getEmbeddingEndpoint() { + return embeddingEndpoint; + } + + public void setEmbeddingEndpoint(String embeddingEndpoint) { + this.embeddingEndpoint = embeddingEndpoint; + } + public String getProjectId() { return projectId; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java new file mode 100644 index 00000000000..42425a265a7 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiEmbeddingProperties.java @@ -0,0 +1,51 @@ +package org.springframework.ai.autoconfigure.watsonxai; + +import org.springframework.ai.watsonx.WatsonxAiEmbeddingOptions; +import 
org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Watsonx.ai Embedding autoconfiguration properties. + * + * @author Pablo Sanchidrian Herrera + * @since 1.0.0 + */ +@ConfigurationProperties(WatsonxAiEmbeddingProperties.CONFIG_PREFIX) +public class WatsonxAiEmbeddingProperties { + + public static final String CONFIG_PREFIX = "spring.ai.watsonx.ai.embedding"; + + /** + * Enable Watsonx.ai embedding model. + */ + private boolean enabled = true; + + /** + * Client level Watsonx.ai embedding options. Use this property to configure the + * model. Null values are ignored, falling back to the defaults. + */ + @NestedConfigurationProperty + private WatsonxAiEmbeddingOptions options = WatsonxAiEmbeddingOptions.create() + .withModel(WatsonxAiEmbeddingOptions.DEFAULT_MODEL); + + public String getModel() { + return this.options.getModel(); + } + + public void setModel(String model) { + this.options.setModel(model); + } + + public WatsonxAiEmbeddingOptions getOptions() { + return this.options; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public boolean isEnabled() { + return this.enabled; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java index 6a8831cfe5c..1637b204ebf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiAutoConfigurationTests.java @@ -29,8 +29,9 @@ public void propertiesTest() { new ApplicationContextRunner().withPropertyValues( // @formatter:off
"spring.ai.watsonx.ai.base-url=TEST_BASE_URL", - "spring.ai.watsonx.ai.stream-endpoint=generation/stream?version=2023-05-29", - "spring.ai.watsonx.ai.text-endpoint=generation/text?version=2023-05-29", + "spring.ai.watsonx.ai.stream-endpoint=ml/v1/text/generation_stream?version=2023-05-29", + "spring.ai.watsonx.ai.text-endpoint=ml/v1/text/generation?version=2023-05-29", + "spring.ai.watsonx.ai.embedding-endpoint=ml/v1/text/embeddings?version=2023-05-29", "spring.ai.watsonx.ai.projectId=1", "spring.ai.watsonx.ai.IAMToken=123456") // @formatter:on @@ -39,8 +40,12 @@ public void propertiesTest() { .run(context -> { var connectionProperties = context.getBean(WatsonxAiConnectionProperties.class); assertThat(connectionProperties.getBaseUrl()).isEqualTo("TEST_BASE_URL"); - assertThat(connectionProperties.getStreamEndpoint()).isEqualTo("generation/stream?version=2023-05-29"); - assertThat(connectionProperties.getTextEndpoint()).isEqualTo("generation/text?version=2023-05-29"); + assertThat(connectionProperties.getStreamEndpoint()) + .isEqualTo("ml/v1/text/generation_stream?version=2023-05-29"); + assertThat(connectionProperties.getTextEndpoint()) + .isEqualTo("ml/v1/text/generation?version=2023-05-29"); + assertThat(connectionProperties.getEmbeddingEndpoint()) + .isEqualTo("ml/v1/text/embeddings?version=2023-05-29"); assertThat(connectionProperties.getProjectId()).isEqualTo("1"); assertThat(connectionProperties.getIAMToken()).isEqualTo("123456"); });