Skip to content

Commit 9cd3448

Browse files
andreadimaio authored and jmartisk committed
Enable ScoringModel in watsonx.ai
1 parent 8ddd34f commit 9cd3448

File tree

17 files changed

+228
-432
lines changed

17 files changed

+228
-432
lines changed

integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleScoringModelsTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import dev.langchain4j.model.scoring.ScoringModel;
1010
import io.quarkiverse.langchain4j.ModelName;
1111
import io.quarkiverse.langchain4j.cohere.runtime.QuarkusCohereScoringModel;
12-
import io.quarkiverse.langchain4j.watsonx.WatsonxRerankModel;
12+
import io.quarkiverse.langchain4j.watsonx.WatsonxScoringModel;
1313
import io.quarkus.arc.ClientProxy;
1414
import io.quarkus.test.junit.QuarkusTest;
1515

@@ -26,7 +26,7 @@ public class MultipleScoringModelsTest {
2626

2727
@Test
2828
void firstNamedModel() {
29-
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(WatsonxRerankModel.class);
29+
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(WatsonxScoringModel.class);
3030
}
3131

3232
@Test

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
import dev.langchain4j.model.scoring.ScoringModel;
3030
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
3131
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
32+
import io.quarkiverse.langchain4j.watsonx.bean.ScoringParameters;
33+
import io.quarkiverse.langchain4j.watsonx.bean.ScoringRequest;
3234
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters;
3335
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters.LengthPenalty;
3436
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
35-
import io.quarkiverse.langchain4j.watsonx.bean.TextRerankParameters;
36-
import io.quarkiverse.langchain4j.watsonx.bean.TextRerankRequest;
3737
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
3838
import io.quarkus.test.QuarkusUnitTest;
3939

@@ -115,7 +115,7 @@ void handlerBeforeEach() {
115115

116116
static EmbeddingParameters embeddingParameters = new EmbeddingParameters(10);
117117

118-
static TextRerankParameters scoringParameters = new TextRerankParameters(10);
118+
static ScoringParameters scoringParameters = new ScoringParameters(10);
119119

120120
@Test
121121
void check_config() throws Exception {
@@ -210,7 +210,7 @@ void check_scoring_model() throws Exception {
210210
TextSegment.from(
211211
"To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature."));
212212

213-
TextRerankRequest request = TextRerankRequest.of(modelId, projectId, "Who wrote 'To Kill a Mockingbird'?",
213+
ScoringRequest request = ScoringRequest.of(modelId, projectId, "Who wrote 'To Kill a Mockingbird'?",
214214
segments, scoringParameters);
215215

216216
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_SCORING_API, 200, "aaaa-mm-dd")

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
import dev.langchain4j.model.output.Response;
3131
import dev.langchain4j.model.scoring.ScoringModel;
3232
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
33+
import io.quarkiverse.langchain4j.watsonx.bean.ScoringRequest;
3334
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters;
3435
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
35-
import io.quarkiverse.langchain4j.watsonx.bean.TextRerankRequest;
3636
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
3737
import io.quarkus.test.QuarkusUnitTest;
3838

@@ -148,7 +148,7 @@ void check_scoring_model() throws Exception {
148148
TextSegment.from(
149149
"To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature."));
150150

151-
TextRerankRequest request = TextRerankRequest.of(modelId, projectId, "Who wrote 'To Kill a Mockingbird'?",
151+
ScoringRequest request = ScoringRequest.of(modelId, projectId, "Who wrote 'To Kill a Mockingbird'?",
152152
segments, null);
153153

154154
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_SCORING_API, 200)

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/Watsonx.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package io.quarkiverse.langchain4j.watsonx;
2+
3+
import java.net.URL;
4+
import java.time.Duration;
5+
import java.util.concurrent.TimeUnit;
6+
7+
import org.jboss.resteasy.reactive.client.api.LoggingScope;
8+
9+
import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi;
10+
import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory;
11+
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
12+
13+
public abstract class Watsonx {
14+
15+
protected final String modelId, projectId, version;
16+
protected final WatsonxRestApi client;
17+
18+
public Watsonx(Builder<?> builder) {
19+
QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder()
20+
.baseUrl(builder.url)
21+
.clientHeadersFactory(new BearerTokenHeaderFactory(builder.tokenGenerator))
22+
.connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS)
23+
.readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS);
24+
25+
if (builder.logRequests || builder.logResponses) {
26+
restClientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE);
27+
restClientBuilder.clientLogger(new WatsonxRestApi.WatsonClientLogger(
28+
builder.logRequests,
29+
builder.logResponses));
30+
}
31+
32+
this.client = restClientBuilder.build(WatsonxRestApi.class);
33+
this.modelId = builder.modelId;
34+
this.projectId = builder.projectId;
35+
this.version = builder.version;
36+
}
37+
38+
public WatsonxRestApi getClient() {
39+
return client;
40+
}
41+
42+
@SuppressWarnings("unchecked")
43+
public static abstract class Builder<T extends Builder<T>> {
44+
45+
protected String modelId;
46+
protected String version;
47+
protected String projectId;
48+
protected Duration timeout;
49+
protected URL url;
50+
protected boolean logResponses;
51+
protected boolean logRequests;
52+
protected WatsonxTokenGenerator tokenGenerator;
53+
54+
public T modelId(String modelId) {
55+
this.modelId = modelId;
56+
return (T) this;
57+
}
58+
59+
public T version(String version) {
60+
this.version = version;
61+
return (T) this;
62+
}
63+
64+
public T projectId(String projectId) {
65+
this.projectId = projectId;
66+
return (T) this;
67+
}
68+
69+
public T url(URL url) {
70+
this.url = url;
71+
return (T) this;
72+
}
73+
74+
public T timeout(Duration timeout) {
75+
this.timeout = timeout;
76+
return (T) this;
77+
}
78+
79+
public T tokenGenerator(WatsonxTokenGenerator tokenGenerator) {
80+
this.tokenGenerator = tokenGenerator;
81+
return (T) this;
82+
}
83+
84+
public T logRequests(boolean logRequests) {
85+
this.logRequests = logRequests;
86+
return (T) this;
87+
}
88+
89+
public T logResponses(boolean logResponses) {
90+
this.logResponses = logResponses;
91+
return (T) this;
92+
}
93+
}
94+
}

model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java

Lines changed: 3 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,12 @@
22

33
import static io.quarkiverse.langchain4j.watsonx.WatsonxUtils.retryOn;
44

5-
import java.net.URL;
6-
import java.time.Duration;
75
import java.util.ArrayList;
86
import java.util.List;
97
import java.util.concurrent.Callable;
10-
import java.util.concurrent.TimeUnit;
118
import java.util.function.Consumer;
129
import java.util.stream.Collectors;
1310

14-
import org.jboss.resteasy.reactive.client.api.LoggingScope;
15-
1611
import dev.langchain4j.agent.tool.ToolExecutionRequest;
1712
import dev.langchain4j.agent.tool.ToolSpecification;
1813
import dev.langchain4j.data.message.AiMessage;
@@ -36,43 +31,21 @@
3631
import io.quarkiverse.langchain4j.watsonx.bean.TextChatResponse.TextChatUsage;
3732
import io.quarkiverse.langchain4j.watsonx.bean.TextStreamingChatResponse;
3833
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
39-
import io.quarkiverse.langchain4j.watsonx.client.WatsonxRestApi;
40-
import io.quarkiverse.langchain4j.watsonx.client.filter.BearerTokenHeaderFactory;
41-
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
4234
import io.smallrye.mutiny.Context;
4335
import io.smallrye.mutiny.infrastructure.Infrastructure;
4436

45-
public class WatsonxChatModel implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator {
37+
public class WatsonxChatModel extends Watsonx implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator {
4638

4739
private static final String USAGE_CONTEXT = "USAGE";
4840
private static final String FINISH_REASON_CONTEXT = "FINISH_REASON";
4941
private static final String ROLE_CONTEXT = "ROLE";
5042
private static final String TOOLS_CONTEXT = "TOOLS";
5143
private static final String COMPLETE_MESSAGE_CONTEXT = "COMPLETE_MESSAGE";
5244

53-
private final String modelId, projectId, version;
54-
private final WatsonxRestApi client;
5545
private final TextChatParameters parameters;
5646

5747
public WatsonxChatModel(Builder builder) {
58-
59-
QuarkusRestClientBuilder restClientBuilder = QuarkusRestClientBuilder.newBuilder()
60-
.baseUrl(builder.url)
61-
.clientHeadersFactory(new BearerTokenHeaderFactory(builder.tokenGenerator))
62-
.connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS)
63-
.readTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS);
64-
65-
if (builder.logRequests || builder.logResponses) {
66-
restClientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE);
67-
restClientBuilder.clientLogger(new WatsonxRestApi.WatsonClientLogger(
68-
builder.logRequests,
69-
builder.logResponses));
70-
}
71-
72-
this.client = restClientBuilder.build(WatsonxRestApi.class);
73-
this.modelId = builder.modelId;
74-
this.projectId = builder.projectId;
75-
this.version = builder.version;
48+
super(builder);
7649

7750
this.parameters = TextChatParameters.builder()
7851
.maxTokens(builder.maxTokens)
@@ -298,45 +271,12 @@ private FinishReason toFinishReason(String reason) {
298271
};
299272
}
300273

301-
public static final class Builder {
274+
public static final class Builder extends Watsonx.Builder<Builder> {
302275

303-
private String modelId;
304-
private String version;
305-
private String projectId;
306-
private Duration timeout;
307276
private Integer maxTokens;
308277
private Double temperature;
309278
private Double topP;
310279
private String responseFormat;
311-
private URL url;
312-
public boolean logResponses;
313-
public boolean logRequests;
314-
private WatsonxTokenGenerator tokenGenerator;
315-
316-
public Builder modelId(String modelId) {
317-
this.modelId = modelId;
318-
return this;
319-
}
320-
321-
public Builder version(String version) {
322-
this.version = version;
323-
return this;
324-
}
325-
326-
public Builder projectId(String projectId) {
327-
this.projectId = projectId;
328-
return this;
329-
}
330-
331-
public Builder url(URL url) {
332-
this.url = url;
333-
return this;
334-
}
335-
336-
public Builder timeout(Duration timeout) {
337-
this.timeout = timeout;
338-
return this;
339-
}
340280

341281
public Builder maxTokens(Integer maxTokens) {
342282
this.maxTokens = maxTokens;
@@ -358,21 +298,6 @@ public Builder responseFormat(String responseFormat) {
358298
return this;
359299
}
360300

361-
public Builder tokenGenerator(WatsonxTokenGenerator tokenGenerator) {
362-
this.tokenGenerator = tokenGenerator;
363-
return this;
364-
}
365-
366-
public Builder logRequests(boolean logRequests) {
367-
this.logRequests = logRequests;
368-
return this;
369-
}
370-
371-
public Builder logResponses(boolean logResponses) {
372-
this.logResponses = logResponses;
373-
return this;
374-
}
375-
376301
public WatsonxChatModel build() {
377302
return new WatsonxChatModel(this);
378303
}

0 commit comments

Comments
 (0)