Skip to content

Commit 8ddd34f

Browse files
andreadimaiojmartisk
authored andcommitted
Enable ScoringModel in watsonx.ai
1 parent aa9ac70 commit 8ddd34f

File tree

17 files changed

+511
-1
lines changed

17 files changed

+511
-1
lines changed

integration-tests/multiple-providers/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@
5151
<artifactId>quarkus-langchain4j-watsonx</artifactId>
5252
<version>${project.version}</version>
5353
</dependency>
54+
<dependency>
55+
<groupId>io.quarkiverse.langchain4j</groupId>
56+
<artifactId>quarkus-langchain4j-cohere</artifactId>
57+
<version>${project.version}</version>
58+
</dependency>
5459
<dependency>
5560
<groupId>io.quarkus</groupId>
5661
<artifactId>quarkus-junit5</artifactId>

integration-tests/multiple-providers/src/main/resources/application.properties

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,10 @@ quarkus.langchain4j.e1.embedding-model.provider=openai
4343
quarkus.langchain4j.openai.e1.api-key=test5
4444
quarkus.langchain4j.e2.embedding-model.provider=ollama
4545
quarkus.langchain4j.e3.embedding-model.provider=dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel
46+
47+
quarkus.langchain4j.s1.scoring-model.provider=watsonx
48+
quarkus.langchain4j.watsonx.s1.base-url=https://somecluster.somedomain.ai:443/api
49+
quarkus.langchain4j.watsonx.s1.api-key=test
50+
quarkus.langchain4j.watsonx.s1.project-id=proj
51+
quarkus.langchain4j.s2.scoring-model.provider=cohere
52+
quarkus.langchain4j.cohere.s2.api-key=test
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package org.acme.example.multiple;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import jakarta.inject.Inject;
6+
7+
import org.junit.jupiter.api.Test;
8+
9+
import dev.langchain4j.model.scoring.ScoringModel;
10+
import io.quarkiverse.langchain4j.ModelName;
11+
import io.quarkiverse.langchain4j.cohere.runtime.QuarkusCohereScoringModel;
12+
import io.quarkiverse.langchain4j.watsonx.WatsonxRerankModel;
13+
import io.quarkus.arc.ClientProxy;
14+
import io.quarkus.test.junit.QuarkusTest;
15+
16+
@QuarkusTest
17+
public class MultipleScoringModelsTest {
18+
19+
@Inject
20+
@ModelName("s1")
21+
ScoringModel firstNamedModel;
22+
23+
@Inject
24+
@ModelName("s2")
25+
ScoringModel secondNamedModel;
26+
27+
@Test
28+
void firstNamedModel() {
29+
assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(WatsonxRerankModel.class);
30+
}
31+
32+
@Test
33+
void secondNamedModel() {
34+
assertThat(ClientProxy.unwrap(secondNamedModel)).isInstanceOf(QuarkusCohereScoringModel.class);
35+
}
36+
}

model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/LangChain4jWatsonBuildConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,9 @@ public interface LangChain4jWatsonBuildConfig {
1818
* Embedding model related settings.
1919
*/
2020
EmbeddingModelBuildConfig embeddingModel();
21+
22+
/**
23+
* Scoring model related settings.
24+
*/
25+
ScoringModelBuildConfig scoringModel();
2126
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.watsonx.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface ScoringModelBuildConfig {
10+
11+
/**
12+
* Whether the scoring model should be enabled.
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}

model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;
44
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.EMBEDDING_MODEL;
5+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SCORING_MODEL;
56
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL;
67
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.TOKEN_COUNT_ESTIMATOR;
78
import static io.quarkiverse.langchain4j.deployment.TemplateUtil.getTemplateFromAnnotationInstance;
@@ -20,8 +21,10 @@
2021
import io.quarkiverse.langchain4j.deployment.LangChain4jDotNames;
2122
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
2223
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
24+
import io.quarkiverse.langchain4j.deployment.items.ScoringModelProviderCandidateBuildItem;
2325
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
2426
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
27+
import io.quarkiverse.langchain4j.deployment.items.SelectedScoringModelProviderBuildItem;
2528
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
2629
import io.quarkiverse.langchain4j.watsonx.deployment.items.WatsonxChatModelProviderBuildItem;
2730
import io.quarkiverse.langchain4j.watsonx.prompt.PromptFormatter;
@@ -51,6 +54,7 @@ FeatureBuildItem feature() {
5154
@BuildStep
5255
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
5356
BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
57+
BuildProducer<ScoringModelProviderCandidateBuildItem> scoringProducer,
5458
LangChain4jWatsonBuildConfig config) {
5559

5660
if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
@@ -60,6 +64,10 @@ public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem
6064
if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
6165
embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
6266
}
67+
68+
if (config.scoringModel().enabled().isEmpty() || config.scoringModel().enabled().get()) {
69+
scoringProducer.produce(new ScoringModelProviderCandidateBuildItem(PROVIDER));
70+
}
6371
}
6472

6573
@BuildStep
@@ -167,6 +175,7 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
167175
LangChain4jWatsonxFixedRuntimeConfig fixedRuntimeConfig,
168176
List<WatsonxChatModelProviderBuildItem> selectedChatItem,
169177
List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
178+
List<SelectedScoringModelProviderBuildItem> selectedScoring,
170179
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
171180

172181
for (var selected : selectedChatItem) {
@@ -232,6 +241,20 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
232241
}
233242
}
234243

244+
for (var selected : selectedScoring) {
245+
if (PROVIDER.equals(selected.getProvider())) {
246+
String configName = selected.getConfigName();
247+
var builder = SyntheticBeanBuildItem
248+
.configure(SCORING_MODEL)
249+
.setRuntimeInit()
250+
.defaultBean()
251+
.unremovable()
252+
.scope(ApplicationScoped.class)
253+
.supplier(recorder.scoringModel(runtimeConfig, configName));
254+
addQualifierIfNecessary(builder, configName);
255+
beanProducer.produce(builder.done());
256+
}
257+
}
235258
}
236259

237260
private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String configName) {

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationAllPropertiesTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,20 @@
2020

2121
import dev.langchain4j.data.embedding.Embedding;
2222
import dev.langchain4j.data.message.AiMessage;
23+
import dev.langchain4j.data.segment.TextSegment;
2324
import dev.langchain4j.model.chat.ChatLanguageModel;
2425
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
2526
import dev.langchain4j.model.chat.TokenCountEstimator;
2627
import dev.langchain4j.model.embedding.EmbeddingModel;
2728
import dev.langchain4j.model.output.Response;
29+
import dev.langchain4j.model.scoring.ScoringModel;
2830
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingParameters;
2931
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
3032
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters;
3133
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters.LengthPenalty;
3234
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
35+
import io.quarkiverse.langchain4j.watsonx.bean.TextRerankParameters;
36+
import io.quarkiverse.langchain4j.watsonx.bean.TextRerankRequest;
3337
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
3438
import io.quarkus.test.QuarkusUnitTest;
3539

@@ -66,6 +70,8 @@ public class GenerationAllPropertiesTest extends WireMockAbstract {
6670
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.chat-model.include-stop-sequence", "false")
6771
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.model-id", "my_super_embedding_model")
6872
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.embedding-model.truncate-input-tokens", "10")
73+
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.scoring-model.model-id", "my_super_scoring_model")
74+
.overrideRuntimeConfigKey("quarkus.langchain4j.watsonx.scoring-model.truncate-input-tokens", "10")
6975
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));
7076

7177
@Override
@@ -85,6 +91,9 @@ void handlerBeforeEach() {
8591
@Inject
8692
EmbeddingModel embeddingModel;
8793

94+
@Inject
95+
ScoringModel scoringModel;
96+
8897
@Inject
8998
TokenCountEstimator tokenCountEstimator;
9099

@@ -106,6 +115,8 @@ void handlerBeforeEach() {
106115

107116
static EmbeddingParameters embeddingParameters = new EmbeddingParameters(10);
108117

118+
static TextRerankParameters scoringParameters = new TextRerankParameters(10);
119+
109120
@Test
110121
void check_config() throws Exception {
111122
var runtimeConfig = langchain4jWatsonConfig.defaultConfig();
@@ -138,6 +149,8 @@ void check_config() throws Exception {
138149
assertEquals(true, fixedRuntimeConfig.chatModel().promptFormatter());
139150
assertEquals("my_super_embedding_model", runtimeConfig.embeddingModel().modelId());
140151
assertEquals(10, runtimeConfig.embeddingModel().truncateInputTokens().orElse(null));
152+
assertEquals("my_super_scoring_model", runtimeConfig.scoringModel().modelId());
153+
assertEquals(10, runtimeConfig.scoringModel().truncateInputTokens().orElse(null));
141154
}
142155

143156
@Test
@@ -175,6 +188,48 @@ void check_embedding_model() throws Exception {
175188
assertNotNull(response.content());
176189
}
177190

191+
@Test
192+
void check_scoring_model() throws Exception {
193+
var config = langchain4jWatsonConfig.defaultConfig();
194+
String modelId = config.scoringModel().modelId();
195+
String projectId = config.projectId();
196+
197+
var segments = List.of(
198+
TextSegment.from(
199+
"The novel 'Moby-Dick' was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.\""),
200+
TextSegment.from(
201+
"Harper Lee, an American novelist widely known for her novel 'To Kill a Mockingbird', was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961."),
202+
TextSegment.from(
203+
"Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century."),
204+
TextSegment.from(
205+
"The 'Harry Potter' series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era."),
206+
TextSegment.from(
207+
"'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."),
208+
TextSegment.from(
209+
"'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."),
210+
TextSegment.from(
211+
"To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature."));
212+
213+
TextRerankRequest request = TextRerankRequest.of(modelId, projectId, "Who wrote 'To Kill a Mockingbird'?",
214+
segments, scoringParameters);
215+
216+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_SCORING_API, 200, "aaaa-mm-dd")
217+
.body(mapper.writeValueAsString(request))
218+
.response(WireMockUtil.RESPONSE_WATSONX_SCORING_API.formatted(modelId))
219+
.build();
220+
221+
Response<List<Double>> response = scoringModel.scoreAll(segments, "Who wrote 'To Kill a Mockingbird'?");
222+
assertNotNull(response);
223+
assertEquals(6, response.content().size());
224+
assertEquals(318, response.tokenUsage().inputTokenCount());
225+
assertEquals(-2.5847978591918945, response.content().get(0));
226+
assertEquals(8.770895957946777, response.content().get(1));
227+
assertEquals(-4.939967155456543, response.content().get(2));
228+
assertEquals(-3.349348306655884, response.content().get(3));
229+
assertEquals(-3.920926570892334, response.content().get(4));
230+
assertEquals(9.720501899719238, response.content().get(5));
231+
}
232+
178233
@Test
179234
void check_token_count_estimator() throws Exception {
180235
var config = langchain4jWatsonConfig.defaultConfig();

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/GenerationDefaultPropertiesTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,17 @@
2222

2323
import dev.langchain4j.data.embedding.Embedding;
2424
import dev.langchain4j.data.message.AiMessage;
25+
import dev.langchain4j.data.segment.TextSegment;
2526
import dev.langchain4j.model.chat.ChatLanguageModel;
2627
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
2728
import dev.langchain4j.model.chat.TokenCountEstimator;
2829
import dev.langchain4j.model.embedding.EmbeddingModel;
2930
import dev.langchain4j.model.output.Response;
31+
import dev.langchain4j.model.scoring.ScoringModel;
3032
import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
3133
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters;
3234
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
35+
import io.quarkiverse.langchain4j.watsonx.bean.TextRerankRequest;
3336
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
3437
import io.quarkus.test.QuarkusUnitTest;
3538

@@ -68,6 +71,9 @@ void handlerBeforeEach() {
6871
@Inject
6972
EmbeddingModel embeddingModel;
7073

74+
@Inject
75+
ScoringModel scoringModel;
76+
7177
@Inject
7278
TokenCountEstimator tokenCountEstimator;
7379

@@ -99,6 +105,8 @@ void check_config() throws Exception {
99105
assertEquals("urn:ibm:params:oauth:grant-type:apikey", runtimeConfig.iam().grantType());
100106
assertEquals(WireMockUtil.DEFAULT_EMBEDDING_MODEL, runtimeConfig.embeddingModel().modelId());
101107
assertTrue(runtimeConfig.embeddingModel().truncateInputTokens().isEmpty());
108+
assertEquals(WireMockUtil.DEFAULT_SCORING_MODEL, runtimeConfig.scoringModel().modelId());
109+
assertTrue(runtimeConfig.scoringModel().truncateInputTokens().isEmpty());
102110
}
103111

104112
@Test
@@ -118,6 +126,48 @@ void check_chat_model_config() throws Exception {
118126
dev.langchain4j.data.message.UserMessage.from("UserMessage")).content().text());
119127
}
120128

129+
@Test
130+
void check_scoring_model() throws Exception {
131+
var config = langchain4jWatsonConfig.defaultConfig();
132+
String modelId = config.scoringModel().modelId();
133+
String projectId = config.projectId();
134+
135+
var segments = List.of(
136+
TextSegment.from(
137+
"The novel 'Moby-Dick' was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.\""),
138+
TextSegment.from(
139+
"Harper Lee, an American novelist widely known for her novel 'To Kill a Mockingbird', was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961."),
140+
TextSegment.from(
141+
"Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century."),
142+
TextSegment.from(
143+
"The 'Harry Potter' series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era."),
144+
TextSegment.from(
145+
"'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."),
146+
TextSegment.from(
147+
"'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."),
148+
TextSegment.from(
149+
"To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature."));
150+
151+
TextRerankRequest request = TextRerankRequest.of(modelId, projectId, "Who wrote 'To Kill a Mockingbird'?",
152+
segments, null);
153+
154+
mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_SCORING_API, 200)
155+
.body(mapper.writeValueAsString(request))
156+
.response(WireMockUtil.RESPONSE_WATSONX_SCORING_API.formatted(modelId))
157+
.build();
158+
159+
Response<List<Double>> response = scoringModel.scoreAll(segments, "Who wrote 'To Kill a Mockingbird'?");
160+
assertNotNull(response);
161+
assertEquals(6, response.content().size());
162+
assertEquals(318, response.tokenUsage().inputTokenCount());
163+
assertEquals(-2.5847978591918945, response.content().get(0));
164+
assertEquals(8.770895957946777, response.content().get(1));
165+
assertEquals(-4.939967155456543, response.content().get(2));
166+
assertEquals(-3.349348306655884, response.content().get(3));
167+
assertEquals(-3.920926570892334, response.content().get(4));
168+
assertEquals(9.720501899719238, response.content().get(5));
169+
}
170+
121171
@Test
122172
void check_embedding_model() throws Exception {
123173
var config = langchain4jWatsonConfig.defaultConfig();

0 commit comments

Comments
 (0)