Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9f18727
feat: update chat endpoints default values
PabloSanchi Jul 11, 2024
118f282
feat: add embedding endpoint to watsonx connection properties
PabloSanchi Jul 11, 2024
546b724
feat: add getter and setter for the embedding endpoint parameter
PabloSanchi Jul 11, 2024
1e4535e
feat: add embedding endpoint into watsonx api
PabloSanchi Jul 11, 2024
6c4f282
feat: add embedding endpoint info
PabloSanchi Jul 11, 2024
1d978e8
feat: add embedding api call & request and response classes
PabloSanchi Jul 12, 2024
fa50a55
feat: add embedding result record
PabloSanchi Jul 12, 2024
70fc2ea
feat: rename chat classes/types for better understanding and clarity
PabloSanchi Jul 12, 2024
9d48f83
fix: linter
PabloSanchi Jul 12, 2024
7d2ad06
feat: add watsonx embedding options class and tests
PabloSanchi Jul 12, 2024
aeeb444
feat: add watsonx embedding model
PabloSanchi Jul 12, 2024
4b039e0
feat: add watsonx embedding model autoconfiguration and properties
PabloSanchi Jul 12, 2024
319893f
fix: remove unused import
PabloSanchi Jul 12, 2024
3b1b54d
fix: remove unused method
PabloSanchi Jul 12, 2024
d159d57
fix: remove embedding info in watsonx chat docs
PabloSanchi Jul 12, 2024
69ec0ac
feat: add watsonx embedding model tests
PabloSanchi Jul 12, 2024
3f5b1bd
feat: add watsonx embedding model docs
PabloSanchi Jul 12, 2024
1965b64
Merge branch 'spring-projects:main' into main
PabloSanchi Jul 12, 2024
3621a8a
fix: config prefix, fix wrong path
PabloSanchi Jul 12, 2024
9930d1f
Merge branch 'spring-projects:main' into main
PabloSanchi Jul 12, 2024
6a8afeb
Merge branch 'spring-projects:main' into main
PabloSanchi Jul 12, 2024
81254f3
feat: add class description and sign
PabloSanchi Jul 12, 2024
666822f
fix: issue with property name, 'ai' missing
PabloSanchi Jul 13, 2024
3a34653
fix: update embedding controller example
PabloSanchi Jul 13, 2024
280b16b
fix: wrong default embedding endpoint
PabloSanchi Jul 13, 2024
20ae89f
Merge branch 'spring-projects:main' into main
PabloSanchi Jul 16, 2024
f45104e
Merge branch 'spring-projects:main' into main
PabloSanchi Aug 23, 2024
1ffdfa3
fix: update code to met interfaces requirements
PabloSanchi Aug 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.watsonx.api.WatsonxAiApi;
import org.springframework.ai.watsonx.api.WatsonxAiRequest;
import org.springframework.ai.watsonx.api.WatsonxAiResponse;
import org.springframework.ai.watsonx.api.WatsonxAiChatRequest;
import org.springframework.ai.watsonx.api.WatsonxAiChatResponse;
import org.springframework.ai.watsonx.utils.MessageToPromptConverter;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -78,9 +78,9 @@ public WatsonxAiChatModel(WatsonxAiApi watsonxAiApi, WatsonxAiChatOptions defaul
@Override
public ChatResponse call(Prompt prompt) {

WatsonxAiRequest request = request(prompt);
WatsonxAiChatRequest request = request(prompt);

WatsonxAiResponse response = this.watsonxAiApi.generate(request).getBody();
WatsonxAiChatResponse response = this.watsonxAiApi.generate(request).getBody();
var generator = new Generation(response.results().get(0).generatedText());

generator = generator.withGenerationMetadata(
Expand All @@ -92,9 +92,9 @@ public ChatResponse call(Prompt prompt) {
@Override
public Flux<ChatResponse> stream(Prompt prompt) {

WatsonxAiRequest request = request(prompt);
WatsonxAiChatRequest request = request(prompt);

Flux<WatsonxAiResponse> response = this.watsonxAiApi.generateStreaming(request);
Flux<WatsonxAiChatResponse> response = this.watsonxAiApi.generateStreaming(request);

return response.map(chunk -> {
Generation generation = new Generation(chunk.results().get(0).generatedText());
Expand All @@ -106,7 +106,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
});
}

public WatsonxAiRequest request(Prompt prompt) {
public WatsonxAiChatRequest request(Prompt prompt) {

WatsonxAiChatOptions options = WatsonxAiChatOptions.builder().build();

Expand All @@ -133,7 +133,7 @@ public WatsonxAiRequest request(Prompt prompt) {
.withHumanPrompt("")
.toPrompt(prompt.getInstructions());

return WatsonxAiRequest.builder(convertedPrompt).withParameters(parameters).build();
return WatsonxAiChatRequest.builder(convertedPrompt).withParameters(parameters).build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package org.springframework.ai.watsonx;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.*;
import org.springframework.ai.watsonx.api.WatsonxAiApi;
import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingRequest;
import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResponse;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
* {@link EmbeddingModel} implementation for {@literal Watsonx.ai}.
*
* Watsonx.ai allows developers to run large language models and generate embeddings. It
* supports open-source models available on [Watsonx.ai
* models](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx).
*
* Please refer to the <a href="https://www.ibm.com/products/watsonx-ai/">official
* Watsonx.ai website</a> for the most up-to-date information on available models.
*
* @author Pablo Sanchidrian Herrera
* @since 1.0.0
*/
public class WatsonxAiEmbeddingModel extends AbstractEmbeddingModel {

private final Logger logger = LoggerFactory.getLogger(getClass());

private final WatsonxAiApi watsonxAiApi;

/**
* Default options to be used for all embedding requests.
*/
private WatsonxAiEmbeddingOptions defaultOptions = WatsonxAiEmbeddingOptions.create()
.withModel(WatsonxAiEmbeddingOptions.DEFAULT_MODEL);

public WatsonxAiEmbeddingModel(WatsonxAiApi watsonxAiApi) {
this.watsonxAiApi = watsonxAiApi;
}

public WatsonxAiEmbeddingModel(WatsonxAiApi watsonxAiApi, WatsonxAiEmbeddingOptions defaultOptions) {
this.watsonxAiApi = watsonxAiApi;
this.defaultOptions = defaultOptions;
}

@Override
public float[] embed(Document document) {
return embed(document.getContent());
}

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");

WatsonxAiEmbeddingRequest embeddingRequest = watsonxAiEmbeddingRequest(request.getInstructions(),
request.getOptions());
WatsonxAiEmbeddingResponse response = this.watsonxAiApi.embeddings(embeddingRequest).getBody();

AtomicInteger indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = response.results()
.stream()
.map(e -> new Embedding(e.embedding(), indexCounter.getAndIncrement()))
.toList();

return new EmbeddingResponse(embeddings);
}

WatsonxAiEmbeddingRequest watsonxAiEmbeddingRequest(List<String> inputs, EmbeddingOptions options) {

WatsonxAiEmbeddingOptions runtimeOptions = (options instanceof WatsonxAiEmbeddingOptions)
? (WatsonxAiEmbeddingOptions) options : this.defaultOptions;

if (!StringUtils.hasText(runtimeOptions.getModel())) {
this.logger.warn("The model cannot be null, using default model instead");
runtimeOptions = this.defaultOptions;
}

return WatsonxAiEmbeddingRequest.builder(inputs).withModel(runtimeOptions.getModel()).build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.springframework.ai.watsonx;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.embedding.EmbeddingOptions;

/**
* The configuration information for the embedding requests.
*
* @author Pablo Sanchidrian Herrera
* @since 1.0.0
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class WatsonxAiEmbeddingOptions implements EmbeddingOptions {

public static final String DEFAULT_MODEL = "ibm/slate-30m-english-rtrvr";

/**
* The embedding model identifier
*/
@JsonProperty("model_id")
private String model;

public WatsonxAiEmbeddingOptions withModel(String model) {
this.model = model;
return this;
}

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

@Override
@JsonIgnore
public Integer getDimensions() {
return null;
}

/**
* Helper factory method to create a new {@link WatsonxAiEmbeddingOptions} instance.
* @return A new {@link WatsonxAiEmbeddingOptions} instance.
*/
public static WatsonxAiEmbeddingOptions create() {
return new WatsonxAiEmbeddingOptions();
}

public static WatsonxAiEmbeddingOptions fromOptions(WatsonxAiEmbeddingOptions fromOptions) {
return new WatsonxAiEmbeddingOptions().withModel(fromOptions.getModel());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class WatsonxAiApi {
private final IamAuthenticator iamAuthenticator;
private final String streamEndpoint;
private final String textEndpoint;
private final String embeddingEndpoint;
private final String projectId;
private IamToken token;

Expand All @@ -60,6 +61,7 @@ public class WatsonxAiApi {
* @param baseUrl api base URL.
* @param streamEndpoint streaming generation.
* @param textEndpoint text generation.
* @param embeddingEndpoint embedding generation
* @param projectId watsonx.ai project identifier.
* @param IAMToken IBM Cloud IAM token.
* @param restClientBuilder rest client builder.
Expand All @@ -68,12 +70,14 @@ public WatsonxAiApi(
String baseUrl,
String streamEndpoint,
String textEndpoint,
String embeddingEndpoint,
String projectId,
String IAMToken,
RestClient.Builder restClientBuilder
) {
this.streamEndpoint = streamEndpoint;
this.textEndpoint = textEndpoint;
this.embeddingEndpoint = embeddingEndpoint;
this.projectId = projectId;
this.iamAuthenticator = IamAuthenticator.fromConfiguration(Map.of("APIKEY", IAMToken));
this.token = this.iamAuthenticator.requestToken();
Expand All @@ -94,8 +98,8 @@ public WatsonxAiApi(
}

@Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5))
public ResponseEntity<WatsonxAiResponse> generate(WatsonxAiRequest watsonxAiRequest) {
Assert.notNull(watsonxAiRequest, WATSONX_REQUEST_CANNOT_BE_NULL);
public ResponseEntity<WatsonxAiChatResponse> generate(WatsonxAiChatRequest watsonxAiChatRequest) {
Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL);

if(this.token.needsRefresh()) {
this.token = this.iamAuthenticator.requestToken();
Expand All @@ -104,14 +108,14 @@ public ResponseEntity<WatsonxAiResponse> generate(WatsonxAiRequest watsonxAiRequ
return this.restClient.post()
.uri(this.textEndpoint)
.header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken())
.body(watsonxAiRequest.withProjectId(projectId))
.body(watsonxAiChatRequest.withProjectId(projectId))
.retrieve()
.toEntity(WatsonxAiResponse.class);
.toEntity(WatsonxAiChatResponse.class);
}

@Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5))
public Flux<WatsonxAiResponse> generateStreaming(WatsonxAiRequest watsonxAiRequest) {
Assert.notNull(watsonxAiRequest, WATSONX_REQUEST_CANNOT_BE_NULL);
public Flux<WatsonxAiChatResponse> generateStreaming(WatsonxAiChatRequest watsonxAiChatRequest) {
Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL);

if(this.token.needsRefresh()) {
this.token = this.iamAuthenticator.requestToken();
Expand All @@ -120,9 +124,9 @@ public Flux<WatsonxAiResponse> generateStreaming(WatsonxAiRequest watsonxAiReque
return this.webClient.post()
.uri(this.streamEndpoint)
.header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken())
.bodyValue(watsonxAiRequest.withProjectId(this.projectId))
.bodyValue(watsonxAiChatRequest.withProjectId(this.projectId))
.retrieve()
.bodyToFlux(WatsonxAiResponse.class)
.bodyToFlux(WatsonxAiChatResponse.class)
.handle((data, sink) -> {
if (logger.isTraceEnabled()) {
logger.trace(data);
Expand All @@ -131,4 +135,21 @@ public Flux<WatsonxAiResponse> generateStreaming(WatsonxAiRequest watsonxAiReque
});
}

@Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5))
public ResponseEntity<WatsonxAiEmbeddingResponse> embeddings(WatsonxAiEmbeddingRequest request) {
Assert.notNull(request, WATSONX_REQUEST_CANNOT_BE_NULL);

if(this.token.needsRefresh()) {
this.token = this.iamAuthenticator.requestToken();
}

return this.restClient.post()
.uri(this.embeddingEndpoint)
.header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken())
.body(request.withProjectId(projectId))
.retrieve()
.toEntity(WatsonxAiEmbeddingResponse.class);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@
import org.springframework.ai.watsonx.WatsonxAiChatOptions;
import org.springframework.util.Assert;

/**
* Java class for Watsonx.ai Chat Request object.
*
* @author Pablo Sanchidrian Herrera
* @since 1.0.0
*/
// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
public class WatsonxAiRequest {
public class WatsonxAiChatRequest {

@JsonProperty("input")
private String input;
Expand All @@ -36,19 +42,14 @@ public class WatsonxAiRequest {
@JsonProperty("project_id")
private String projectId = "";

private WatsonxAiRequest(String input, Map<String, Object> parameters, String modelId, String projectId) {
private WatsonxAiChatRequest(String input, Map<String, Object> parameters, String modelId, String projectId) {
this.input = input;
this.parameters = parameters;
this.modelId = modelId;
this.projectId = projectId;
}

public WatsonxAiRequest withModelId(String modelId) {
this.modelId = modelId;
return this;
}

public WatsonxAiRequest withProjectId(String projectId) {
public WatsonxAiChatRequest withProjectId(String projectId) {
this.projectId = projectId;
return this;
}
Expand Down Expand Up @@ -79,8 +80,8 @@ public Builder withParameters(Map<String, Object> parameters) {
return this;
}

public WatsonxAiRequest build() {
return new WatsonxAiRequest(input, parameters, model, "");
public WatsonxAiChatRequest build() {
return new WatsonxAiChatRequest(input, parameters, model, "");
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@
import java.util.List;
import java.util.Map;

/**
* Java class for Watsonx.ai Chat Response object.
*
* @author Pablo Sanchidrian Herrera
* @since 1.0.0
*/
// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
public record WatsonxAiResponse(
public record WatsonxAiChatResponse(
@JsonProperty("model_id") String modelId,
@JsonProperty("created_at") Date createdAt,
@JsonProperty("results") List<WatsonxAiResults> results,
@JsonProperty("results") List<WatsonxAiChatResults> results,
@JsonProperty("system") Map<String, Object> system
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

/**
* Java class for Watsonx.ai Chat Results object.
*
* @author Pablo Sanchidrian Herrera
* @since 1.0.0
*/
// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
public record WatsonxAiResults(
public record WatsonxAiChatResults(
@JsonProperty("generated_text") String generatedText,
@JsonProperty("generated_token_count") Integer generatedTokenCount,
@JsonProperty("input_token_count") Integer inputTokenCount,
Expand Down
Loading