
Commit f87f112

[watsonx.ai] Refactor WatsonxChatModel/WatsonxStreamingChatModel to WatsonxGenerationModel
1 parent 03099f7 commit f87f112

File tree

20 files changed (+611, -761 lines)


integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 import io.quarkiverse.langchain4j.huggingface.QuarkusHuggingFaceChatModel;
 import io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel;
 import io.quarkiverse.langchain4j.openshiftai.OpenshiftAiChatModel;
-import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel;
+import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel;
 import io.quarkus.arc.ClientProxy;
 import io.quarkus.test.junit.QuarkusTest;

@@ -79,6 +79,6 @@ void sixthNamedModel() {

     @Test
     void seventhNamedModel() {
-        assertThat(ClientProxy.unwrap(seventhNamedModel)).isInstanceOf(WatsonxChatModel.class);
+        assertThat(ClientProxy.unwrap(seventhNamedModel)).isInstanceOf(WatsonxGenerationModel.class);
     }
 }

integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleTokenCountEstimatorProvidersTest.java

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 import dev.langchain4j.model.chat.TokenCountEstimator;
 import io.quarkiverse.langchain4j.ModelName;
 import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel;
-import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel;
+import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel;
 import io.quarkus.arc.ClientProxy;
 import io.quarkus.test.junit.QuarkusTest;

@@ -41,7 +41,7 @@ void azureOpenAiTest() {

     @Test
     void watsonxTest() {
-        assertThat(ClientProxy.unwrap(watsonxChat)).isInstanceOf(WatsonxChatModel.class);
-        assertThat(ClientProxy.unwrap(watsonxTokenizer)).isInstanceOf(WatsonxChatModel.class);
+        assertThat(ClientProxy.unwrap(watsonxChat)).isInstanceOf(WatsonxGenerationModel.class);
+        assertThat(ClientProxy.unwrap(watsonxTokenizer)).isInstanceOf(WatsonxGenerationModel.class);
     }
 }
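
Note on the change above: both the named watsonx chat bean and the watsonx tokenizer bean now unwrap to the same WatsonxGenerationModel class. Below is a minimal, hypothetical sketch of the injection side of such a test; the @ModelName value "watsonx" and the ChatLanguageModel field type are assumptions, not shown in this hunk.

import static org.assertj.core.api.Assertions.assertThat;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationModel;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.junit.QuarkusTest;

// Hypothetical standalone sketch; the real test also covers other providers.
@QuarkusTest
class WatsonxGenerationModelSketchTest {

    @Inject
    @ModelName("watsonx") // assumed qualifier value
    ChatLanguageModel watsonxChat;

    @Inject
    @ModelName("watsonx") // assumed qualifier value
    TokenCountEstimator watsonxTokenizer;

    @Test
    void watsonxTest() {
        // Both CDI client proxies should unwrap to the single generation-model implementation.
        assertThat(ClientProxy.unwrap(watsonxChat)).isInstanceOf(WatsonxGenerationModel.class);
        assertThat(ClientProxy.unwrap(watsonxTokenizer)).isInstanceOf(WatsonxGenerationModel.class);
    }
}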

model-providers/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonxProcessor.java

Lines changed: 10 additions & 5 deletions
@@ -154,13 +154,17 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
         String configName = selected.getConfigName();
         PromptFormatter promptFormatter = selected.getPromptFormatter();

-        var chatModel = recorder.chatModel(runtimeConfig, fixedRuntimeConfig, configName, promptFormatter);
+        var chatLanguageModel = recorder.generationModel(runtimeConfig, fixedRuntimeConfig, configName, promptFormatter);
+        var streamingChatLanguageModel = recorder.generationStreamingModel(runtimeConfig, fixedRuntimeConfig, configName,
+                promptFormatter);
+
         var chatBuilder = SyntheticBeanBuildItem
                 .configure(CHAT_MODEL)
                 .setRuntimeInit()
                 .defaultBean()
                 .scope(ApplicationScoped.class)
-                .supplier(chatModel);
+                .supplier(chatLanguageModel);
+
         addQualifierIfNecessary(chatBuilder, configName);
         beanProducer.produce(chatBuilder.done());

@@ -169,7 +173,8 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
                 .setRuntimeInit()
                 .defaultBean()
                 .scope(ApplicationScoped.class)
-                .supplier(chatModel);
+                .supplier(chatLanguageModel);
+
         addQualifierIfNecessary(tokenizerBuilder, configName);
         beanProducer.produce(tokenizerBuilder.done());

@@ -178,8 +183,8 @@ void generateBeans(WatsonxRecorder recorder, LangChain4jWatsonxConfig runtimeCon
                 .setRuntimeInit()
                 .defaultBean()
                 .scope(ApplicationScoped.class)
-                .supplier(recorder.streamingChatModel(runtimeConfig, fixedRuntimeConfig, configName,
-                        promptFormatter));
+                .supplier(streamingChatLanguageModel);
+
         addQualifierIfNecessary(streamingBuilder, configName);
         beanProducer.produce(streamingBuilder.done());
     }
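
Taken together, these hunks restructure generateBeans so the recorder supplies both the generation model and its streaming counterpart up front, and the chat-model and tokenizer beans reuse the same supplier. A condensed sketch of the resulting wiring for the chat-model bean follows, assuming the rest of the method is unchanged; the tokenizer and streaming builders follow the same pattern with their own configure(...) targets, which are not shown in this hunk.

// Condensed sketch of the post-refactor wiring; not the full generateBeans method.
var chatLanguageModel = recorder.generationModel(runtimeConfig, fixedRuntimeConfig, configName, promptFormatter);
var streamingChatLanguageModel = recorder.generationStreamingModel(runtimeConfig, fixedRuntimeConfig, configName,
        promptFormatter);

// The same supplier backs both the chat-model and the tokenizer synthetic beans.
var chatBuilder = SyntheticBeanBuildItem
        .configure(CHAT_MODEL)
        .setRuntimeInit()
        .defaultBean()
        .scope(ApplicationScoped.class)
        .supplier(chatLanguageModel);
addQualifierIfNecessary(chatBuilder, configName);
beanProducer.produce(chatBuilder.done());

// The streaming bean now reuses the precomputed supplier instead of calling the
// recorder inline: .supplier(streamingChatLanguageModel);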

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ void chat() throws Exception {
                 .temperature(chatModelConfig.temperature())
                 .minNewTokens(chatModelConfig.minNewTokens())
                 .maxNewTokens(chatModelConfig.maxNewTokens())
+                .timeLimit(10000L)
                 .build();

         TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, input, parameters);
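
The only functional change in this test is that the expected request body now carries a time limit. A minimal sketch of how the expected body is assembled, assuming modelId, projectId, input and chatModelConfig are prepared earlier in the test as before; treating the 10000L value as milliseconds is an assumption.

// Expected request parameters for the mocked watsonx.ai text-generation call.
Parameters parameters = Parameters.builder()
        .decodingMethod(chatModelConfig.decodingMethod())
        .temperature(chatModelConfig.temperature())
        .minNewTokens(chatModelConfig.minNewTokens())
        .maxNewTokens(chatModelConfig.maxNewTokens())
        .timeLimit(10000L) // newly asserted request-level time limit
        .build();

TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, input, parameters);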

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AllPropertiesTest.java

Lines changed: 0 additions & 14 deletions
@@ -4,7 +4,6 @@
 import static org.awaitility.Awaitility.await;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
-import static org.junit.jupiter.api.Assertions.assertTrue;

 import java.time.Duration;
 import java.util.Date;
@@ -26,15 +25,11 @@
 import dev.langchain4j.model.chat.TokenCountEstimator;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.output.Response;
-import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel;
-import io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel;
 import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
 import io.quarkiverse.langchain4j.watsonx.bean.Parameters;
 import io.quarkiverse.langchain4j.watsonx.bean.Parameters.LengthPenalty;
 import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
 import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
-import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter;
-import io.quarkus.arc.ClientProxy;
 import io.quarkus.test.QuarkusUnitTest;

 public class AllPropertiesTest extends WireMockAbstract {
@@ -106,15 +101,6 @@ void handlerBeforeEach() {
             .includeStopSequence(false)
             .build();

-    @Test
-    void prompt_formatter() {
-        var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(chatModel);
-        assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter);
-
-        var unwrapStreamingChatModel = (WatsonxStreamingChatModel) ClientProxy.unwrap(streamingChatModel);
-        assertTrue(unwrapStreamingChatModel.getPromptFormatter() instanceof NoopPromptFormatter);
-    }
-
     @Test
     void check_config() throws Exception {
         var runtimeConfig = langchain4jWatsonConfig.defaultConfig();

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/ChatMemoryPlaceholderTest.java

Lines changed: 30 additions & 54 deletions
@@ -97,26 +97,16 @@ public interface NoMemoryAiService {
     @Test
     void extract_dialogue_test() throws Exception {

-        LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig();
-        ChatModelConfig chatModelConfig = watsonConfig.chatModel();
-        String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId();
         String chatMemoryId = "userId";
-        String projectId = watsonConfig.projectId();
-        Parameters parameters = Parameters.builder()
-                .decodingMethod(chatModelConfig.decodingMethod())
-                .temperature(chatModelConfig.temperature())
-                .minNewTokens(chatModelConfig.minNewTokens())
-                .maxNewTokens(chatModelConfig.maxNewTokens())
-                .build();

         var input = """
                 You are a helpful assistant
                 Context:

                 Hello""";
-        var body = new TextGenerationRequest(modelId, projectId, input, parameters);
+
         mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
-                .body(mapper.writeValueAsString(body))
+                .body(mapper.writeValueAsString(createRequest(input)))
                 .response("""
                         {
                             "results": [
@@ -140,9 +130,9 @@ void extract_dialogue_test() throws Exception {
                 Hello
                 Hi!
                 What is your name?""";
-        body = new TextGenerationRequest(modelId, projectId, input, parameters);
+
         mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
-                .body(mapper.writeValueAsString(body))
+                .body(mapper.writeValueAsString(createRequest(input)))
                 .response("""
                         {
                             "results": [
@@ -162,26 +152,16 @@
     @Test
     void extract_dialogue_with_delimiter_test() throws Exception {

-        LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig();
-        ChatModelConfig chatModelConfig = watsonConfig.chatModel();
-        String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId();
         String chatMemoryId = "userId_with_delimiter";
-        String projectId = watsonConfig.projectId();
-        Parameters parameters = Parameters.builder()
-                .decodingMethod(chatModelConfig.decodingMethod())
-                .temperature(chatModelConfig.temperature())
-                .minNewTokens(chatModelConfig.minNewTokens())
-                .maxNewTokens(chatModelConfig.maxNewTokens())
-                .build();

         var input = """
                 You are a helpful assistant
                 Context:

                 Hello""";
-        var body = new TextGenerationRequest(modelId, projectId, input, parameters);
+
         mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
-                .body(mapper.writeValueAsString(body))
+                .body(mapper.writeValueAsString(createRequest(input)))
                 .response("""
                         {
                             "results": [
@@ -204,9 +184,9 @@ void extract_dialogue_with_delimiter_test() throws Exception {
                 Hello
                 Hi!
                 What is your name?""";
-        body = new TextGenerationRequest(modelId, projectId, input, parameters);
+
         mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
-                .body(mapper.writeValueAsString(body))
+                .body(mapper.writeValueAsString(createRequest(input)))
                 .response("""
                         {
                             "results": [
@@ -226,26 +206,16 @@
     @Test
     void extract_dialogue_with_all_params_test() throws Exception {

-        LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig();
-        ChatModelConfig chatModelConfig = watsonConfig.chatModel();
-        String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId();
         String chatMemoryId = "userId_with_all_params";
-        String projectId = watsonConfig.projectId();
-        Parameters parameters = Parameters.builder()
-                .decodingMethod(chatModelConfig.decodingMethod())
-                .temperature(chatModelConfig.temperature())
-                .minNewTokens(chatModelConfig.minNewTokens())
-                .maxNewTokens(chatModelConfig.maxNewTokens())
-                .build();

         var input = """
                 You are a helpful assistant
                 Context:

                 Hello""";
-        var body = new TextGenerationRequest(modelId, projectId, input, parameters);
+
         mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
-                .body(mapper.writeValueAsString(body))
+                .body(mapper.writeValueAsString(createRequest(input)))
                 .response("""
                         {
                             "results": [
@@ -268,9 +238,9 @@ void extract_dialogue_with_all_params_test() throws Exception {
                 Hello
                 Hi!
                 What is your name?""";
-        body = new TextGenerationRequest(modelId, projectId, input, parameters);
+
         mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
-                .body(mapper.writeValueAsString(body))
+                .body(mapper.writeValueAsString(createRequest(input)))
                 .response("""
                         {
                             "results": [
@@ -290,17 +260,7 @@
     @Test
     void extract_dialogue_no_memory_test() throws Exception {

-        LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig();
-        ChatModelConfig chatModelConfig = watsonConfig.chatModel();
-        String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId();
         String chatMemoryId = "userId_with_all_params";
-        String projectId = watsonConfig.projectId();
-        Parameters parameters = Parameters.builder()
-                .decodingMethod(chatModelConfig.decodingMethod())
-                .temperature(chatModelConfig.temperature())
-                .minNewTokens(chatModelConfig.minNewTokens())
-                .maxNewTokens(chatModelConfig.maxNewTokens())
-                .build();

         var input = """
                 Context:
@@ -309,9 +269,9 @@ void extract_dialogue_no_memory_test() throws Exception {
                 User: What is your name?
                 Assistant: My name is AiBot
                 Hello""";
-        var body = new TextGenerationRequest(modelId, projectId, input, parameters);
+
         mockServers.mockWatsonxBuilder(WireMockUtil.URL_WATSONX_CHAT_API, 200)
-                .body(mapper.writeValueAsString(body))
+                .body(mapper.writeValueAsString(createRequest(input)))
                 .response("""
                         {
                             "results": [
@@ -327,4 +287,20 @@

         noMemoryAiService.rephrase(chatMemoryStore.getMessages(chatMemoryId), "Hello");
     }
+
+    private TextGenerationRequest createRequest(String input) {
+        LangChain4jWatsonxConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig();
+        ChatModelConfig chatModelConfig = watsonConfig.chatModel();
+        String modelId = langchain4jWatsonFixedRuntimeConfig.defaultConfig().chatModel().modelId();
+        String projectId = watsonConfig.projectId();
+        Parameters parameters = Parameters.builder()
+                .decodingMethod(chatModelConfig.decodingMethod())
+                .temperature(chatModelConfig.temperature())
+                .minNewTokens(chatModelConfig.minNewTokens())
+                .maxNewTokens(chatModelConfig.maxNewTokens())
+                .timeLimit(10000L)
+                .build();
+
+        return new TextGenerationRequest(modelId, projectId, input, parameters);
+    }
 }

model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/DefaultPropertiesTest.java

Lines changed: 0 additions & 9 deletions
@@ -27,13 +27,10 @@
 import dev.langchain4j.model.chat.TokenCountEstimator;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.output.Response;
-import io.quarkiverse.langchain4j.watsonx.WatsonxChatModel;
 import io.quarkiverse.langchain4j.watsonx.bean.EmbeddingRequest;
 import io.quarkiverse.langchain4j.watsonx.bean.Parameters;
 import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
 import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
-import io.quarkiverse.langchain4j.watsonx.prompt.impl.NoopPromptFormatter;
-import io.quarkus.arc.ClientProxy;
 import io.quarkus.test.QuarkusUnitTest;

 public class DefaultPropertiesTest extends WireMockAbstract {
@@ -73,12 +70,6 @@ void handlerBeforeEach() {
     @Inject
     TokenCountEstimator tokenCountEstimator;

-    @Test
-    void prompt_formatter() {
-        var unwrapChatModel = (WatsonxChatModel) ClientProxy.unwrap(chatModel);
-        assertTrue(unwrapChatModel.getPromptFormatter() instanceof NoopPromptFormatter);
-    }
-
     @Test
     void check_config() throws Exception {
         var runtimeConfig = langchain4jWatsonConfig.defaultConfig();
