Skip to content

Commit d9eef6f

Browse files
authored
Merge pull request #226 from andreadimaio/main
Add functionality to BAM module
2 parents 5de41ca + 01cd8da commit d9eef6f

23 files changed

+1543
-25
lines changed

bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/BamProcessor.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.quarkiverse.langchain4j.bam.deployment;
22

33
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL;
4+
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL;
45

56
import java.util.Optional;
67

@@ -9,7 +10,9 @@
910
import io.quarkiverse.langchain4j.bam.runtime.BamRecorder;
1011
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
1112
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
13+
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
1214
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
15+
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
1316
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
1417
import io.quarkus.deployment.annotations.BuildProducer;
1518
import io.quarkus.deployment.annotations.BuildStep;
@@ -30,19 +33,27 @@ FeatureBuildItem feature() {
3033

3134
@BuildStep
3235
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
36+
BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
3337
Langchain4jBamBuildConfig config) {
38+
3439
if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
3540
chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
3641
}
42+
43+
if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
44+
embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
45+
}
3746
}
3847

3948
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
4049
@BuildStep
4150
@Record(ExecutionTime.RUNTIME_INIT)
4251
void generateBeans(BamRecorder recorder,
4352
Optional<SelectedChatModelProviderBuildItem> selectedChatItem,
53+
Optional<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
4454
Langchain4jBamConfig config,
4555
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
56+
4657
if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) {
4758
beanProducer.produce(SyntheticBeanBuildItem
4859
.configure(CHAT_MODEL)
@@ -52,5 +63,17 @@ void generateBeans(BamRecorder recorder,
5263
.supplier(recorder.chatModel(config))
5364
.done());
5465
}
66+
67+
if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
68+
beanProducer.produce(
69+
SyntheticBeanBuildItem
70+
.configure(EMBEDDING_MODEL)
71+
.setRuntimeInit()
72+
.defaultBean()
73+
.scope(ApplicationScoped.class)
74+
.supplier(recorder.embeddingModel(config))
75+
.unremovable()
76+
.done());
77+
}
5578
}
5679
}

bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/ChatModelBuildConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ public interface ChatModelBuildConfig {
1313
*/
1414
@ConfigDocDefault("true")
1515
Optional<Boolean> enabled();
16+
17+
/**
18+
* Embedding model related settings
19+
*/
20+
EmbeddingModelBuildConfig embeddingModel();
1621
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.bam.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface EmbeddingModelBuildConfig {
10+
11+
/**
12+
* Whether the model should be enabled
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}

bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/Langchain4jBamBuildConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ public interface Langchain4jBamBuildConfig {
1313
* Chat model related settings
1414
*/
1515
ChatModelBuildConfig chatModel();
16+
17+
/**
18+
* Embedding model related settings
19+
*/
20+
EmbeddingModelBuildConfig embeddingModel();
1621
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package io.quarkiverse.langchain4j.bam.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
6+
import java.util.List;
7+
8+
import jakarta.inject.Inject;
9+
import jakarta.inject.Singleton;
10+
11+
import org.jboss.shrinkwrap.api.ShrinkWrap;
12+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
13+
import org.junit.jupiter.api.AfterAll;
14+
import org.junit.jupiter.api.BeforeAll;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import com.fasterxml.jackson.databind.ObjectMapper;
19+
import com.github.tomakehurst.wiremock.WireMockServer;
20+
21+
import dev.langchain4j.service.SystemMessage;
22+
import dev.langchain4j.service.UserMessage;
23+
import io.quarkiverse.langchain4j.RegisterAiService;
24+
import io.quarkiverse.langchain4j.bam.BamRestApi;
25+
import io.quarkiverse.langchain4j.bam.Message;
26+
import io.quarkiverse.langchain4j.bam.Parameters;
27+
import io.quarkiverse.langchain4j.bam.TextGenerationRequest;
28+
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
29+
import io.quarkus.test.QuarkusUnitTest;
30+
31+
public class AiServiceTest {
32+
33+
static WireMockServer wireMockServer;
34+
static ObjectMapper mapper;
35+
static WireMockUtil mockServers;
36+
37+
@RegisterExtension
38+
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
39+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", WireMockUtil.URL)
40+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", WireMockUtil.API_KEY)
41+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));
42+
43+
@BeforeAll
44+
static void beforeAll() {
45+
wireMockServer = new WireMockServer(options().port(WireMockUtil.PORT));
46+
wireMockServer.start();
47+
mapper = BamRestApi.objectMapper(new ObjectMapper());
48+
mockServers = new WireMockUtil(wireMockServer);
49+
}
50+
51+
@AfterAll
52+
static void afterAll() {
53+
wireMockServer.stop();
54+
}
55+
56+
@RegisterAiService
57+
@Singleton
58+
interface NewAIService {
59+
60+
@SystemMessage("This is a systemMessage")
61+
@UserMessage("This is a userMessage {text}")
62+
String chat(String text);
63+
}
64+
65+
@Inject
66+
NewAIService service;
67+
68+
@Inject
69+
Langchain4jBamConfig config;
70+
71+
@Test
72+
void chat() throws Exception {
73+
74+
var modelId = config.chatModel().modelId();
75+
76+
var parameters = Parameters.builder()
77+
.decodingMethod(config.chatModel().decodingMethod())
78+
.temperature(config.chatModel().temperature())
79+
.minNewTokens(config.chatModel().minNewTokens())
80+
.maxNewTokens(config.chatModel().maxNewTokens())
81+
.build();
82+
83+
List<Message> messages = List.of(
84+
new Message("system", "This is a systemMessage"),
85+
new Message("user", "This is a userMessage Hello"));
86+
87+
var body = new TextGenerationRequest(modelId, messages, parameters);
88+
89+
mockServers.mockBuilder(200)
90+
.body(mapper.writeValueAsString(body))
91+
.response("""
92+
{
93+
"results": [
94+
{
95+
"generated_token_count": 20,
96+
"input_token_count": 146,
97+
"stop_reason": "max_tokens",
98+
"seed": 40268626,
99+
"generated_text": "AI Response"
100+
}
101+
]
102+
}
103+
""")
104+
.build();
105+
106+
assertEquals("AI Response", service.chat("Hello"));
107+
}
108+
}
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package io.quarkiverse.langchain4j.bam.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
6+
import java.time.Duration;
7+
import java.util.List;
8+
9+
import jakarta.inject.Inject;
10+
11+
import org.jboss.shrinkwrap.api.ShrinkWrap;
12+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
13+
import org.junit.jupiter.api.AfterAll;
14+
import org.junit.jupiter.api.BeforeAll;
15+
import org.junit.jupiter.api.Test;
16+
import org.junit.jupiter.api.extension.RegisterExtension;
17+
18+
import com.fasterxml.jackson.databind.ObjectMapper;
19+
import com.github.tomakehurst.wiremock.WireMockServer;
20+
21+
import dev.langchain4j.model.chat.ChatLanguageModel;
22+
import io.quarkiverse.langchain4j.bam.BamRestApi;
23+
import io.quarkiverse.langchain4j.bam.Message;
24+
import io.quarkiverse.langchain4j.bam.Parameters;
25+
import io.quarkiverse.langchain4j.bam.TextGenerationRequest;
26+
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
27+
import io.quarkus.test.QuarkusUnitTest;
28+
29+
public class AllPropertiesTest {
30+
31+
static WireMockServer wireMockServer;
32+
static ObjectMapper mapper;
33+
static WireMockUtil mockServers;
34+
35+
@RegisterExtension
36+
static QuarkusUnitTest unitTest = new QuarkusUnitTest()
37+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.base-url", WireMockUtil.URL)
38+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.api-key", WireMockUtil.API_KEY)
39+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.timeout", "60s")
40+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.log-requests", "true")
41+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.log-responses", "true")
42+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.timeout", "60s")
43+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.version", "aaaa-mm-dd")
44+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.model-id", "my_super_model")
45+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.decoding-method", "greedy")
46+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.include-stop-sequence", "true")
47+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.max-new-tokens", "200")
48+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.min-new-tokens", "10")
49+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.random-seed", "2")
50+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.stop-sequences", "\n,\n\n")
51+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.temperature", "1.5")
52+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.time-limit", "1500")
53+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.top-k", "90")
54+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.top-p", "0.5")
55+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.typical-p", "0.5")
56+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.repetition-penalty", "2.0")
57+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.truncate-input-tokens", "0")
58+
.overrideRuntimeConfigKey("quarkus.langchain4j.bam.chat-model.beam-width", "2")
59+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));
60+
61+
@Inject
62+
Langchain4jBamConfig config;
63+
64+
@Inject
65+
ChatLanguageModel model;
66+
67+
@BeforeAll
68+
static void beforeAll() {
69+
wireMockServer = new WireMockServer(options().port(WireMockUtil.PORT));
70+
wireMockServer.start();
71+
mapper = BamRestApi.objectMapper(new ObjectMapper());
72+
mockServers = new WireMockUtil(wireMockServer);
73+
}
74+
75+
@AfterAll
76+
static void afterAll() {
77+
wireMockServer.stop();
78+
}
79+
80+
@Test
81+
void generate() throws Exception {
82+
83+
assertEquals(WireMockUtil.URL, config.baseUrl().get().toString());
84+
assertEquals(WireMockUtil.API_KEY, config.apiKey());
85+
assertEquals(Duration.ofSeconds(60), config.timeout());
86+
assertEquals(true, config.logRequests());
87+
assertEquals(true, config.logResponses());
88+
assertEquals("aaaa-mm-dd", config.version());
89+
assertEquals("my_super_model", config.chatModel().modelId());
90+
assertEquals("greedy", config.chatModel().decodingMethod());
91+
assertEquals(true, config.chatModel().includeStopSequence().get());
92+
assertEquals(200, config.chatModel().maxNewTokens());
93+
assertEquals(10, config.chatModel().minNewTokens());
94+
assertEquals(2, config.chatModel().randomSeed().get());
95+
assertEquals(List.of("\n", "\n\n"), config.chatModel().stopSequences().get());
96+
assertEquals(1.5, config.chatModel().temperature());
97+
assertEquals(1500, config.chatModel().timeLimit().get());
98+
assertEquals(90, config.chatModel().topK().get());
99+
assertEquals(0.5, config.chatModel().topP().get());
100+
assertEquals(0.5, config.chatModel().typicalP().get());
101+
assertEquals(2.0, config.chatModel().repetitionPenalty().get());
102+
assertEquals(0, config.chatModel().truncateInputTokens().get());
103+
assertEquals(2, config.chatModel().beamWidth().get());
104+
105+
var modelId = config.chatModel().modelId();
106+
107+
var parameters = Parameters.builder()
108+
.minNewTokens(10)
109+
.maxNewTokens(200)
110+
.decodingMethod("greedy")
111+
.includeStopSequence(true)
112+
.randomSeed(2)
113+
.stopSequences(List.of("\n", "\n\n"))
114+
.temperature(1.5)
115+
.timeLimit(1500)
116+
.topK(90)
117+
.topP(0.5)
118+
.typicalP(0.5)
119+
.repetitionPenalty(2.0)
120+
.truncateInputTokens(0)
121+
.beamWidth(2)
122+
.build();
123+
124+
List<Message> messages = List.of(
125+
new Message("user", "Hello how are you?"));
126+
127+
var body = new TextGenerationRequest(modelId, messages, parameters);
128+
129+
mockServers.mockBuilder(200, config.version())
130+
.body(mapper.writeValueAsString(body))
131+
.response("""
132+
{
133+
"id": "05a245ad-1da7-4b9d-9807-ae1733177c1d",
134+
"model_id": "meta-llama/llama-2-70b-chat",
135+
"created_at": "2023-09-01T09:28:29.378Z",
136+
"results": [
137+
{
138+
"generated_token_count": 20,
139+
"input_token_count": 146,
140+
"stop_reason": "max_tokens",
141+
"seed": 40268626,
142+
"generated_text": "Hello! I'm doing well, thanks for asking. I'm here to assist you"
143+
}
144+
],
145+
"conversation_id": "cd3a9bca-b88e-41e4-9d62-bab33098fe39"
146+
}
147+
""")
148+
.build();
149+
150+
assertEquals("Hello! I'm doing well, thanks for asking. I'm here to assist you", model.generate("Hello how are you?"));
151+
}
152+
}

0 commit comments

Comments
 (0)