Skip to content

Commit 01704ed

Browse files
authored
Merge pull request #1366 from holomekc/main
feat: Add Amazon Bedrock
2 parents 02fe41e + 2aa2766 commit 01704ed

38 files changed

+11604
-0
lines changed

docs/modules/ROOT/pages/includes/quarkus-langchain4j-bedrock.adoc

Lines changed: 4007 additions & 0 deletions
Large diffs are not rendered by default.

docs/modules/ROOT/pages/includes/quarkus-langchain4j-bedrock_quarkus.langchain4j.adoc

Lines changed: 4007 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3+
<modelVersion>4.0.0</modelVersion>
4+
<parent>
5+
<groupId>io.quarkiverse.langchain4j</groupId>
6+
<artifactId>quarkus-langchain4j-bedrock-parent</artifactId>
7+
<version>999-SNAPSHOT</version>
8+
</parent>
9+
<artifactId>quarkus-langchain4j-bedrock-deployment</artifactId>
10+
<name>Quarkus LangChain4j - Bedrock - Deployment</name>
11+
<dependencies>
12+
<dependency>
13+
<groupId>io.quarkiverse.langchain4j</groupId>
14+
<artifactId>quarkus-langchain4j-bedrock</artifactId>
15+
<version>${project.version}</version>
16+
</dependency>
17+
<dependency>
18+
<groupId>io.quarkus</groupId>
19+
<artifactId>quarkus-rest-client-jackson-deployment</artifactId>
20+
</dependency>
21+
<dependency>
22+
<groupId>io.quarkiverse.langchain4j</groupId>
23+
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
24+
<version>${project.version}</version>
25+
</dependency>
26+
<dependency>
27+
<groupId>org.wiremock</groupId>
28+
<artifactId>wiremock-standalone</artifactId>
29+
<version>${wiremock.version}</version>
30+
<scope>test</scope>
31+
</dependency>
32+
<dependency>
33+
<groupId>io.quarkus</groupId>
34+
<artifactId>quarkus-junit5-internal</artifactId>
35+
<scope>test</scope>
36+
</dependency>
37+
<dependency>
38+
<groupId>org.assertj</groupId>
39+
<artifactId>assertj-core</artifactId>
40+
<version>${assertj.version}</version>
41+
<scope>test</scope>
42+
</dependency>
43+
</dependencies>
44+
<build>
45+
<plugins>
46+
<plugin>
47+
<artifactId>maven-compiler-plugin</artifactId>
48+
<configuration>
49+
<annotationProcessorPaths>
50+
<path>
51+
<groupId>io.quarkus</groupId>
52+
<artifactId>quarkus-extension-processor</artifactId>
53+
</path>
54+
</annotationProcessorPaths>
55+
</configuration>
56+
</plugin>
57+
</plugins>
58+
</build>
59+
</project>
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;
4+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.EMBEDDING_MODEL;
5+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL;
6+
7+
import java.util.List;
8+
9+
import jakarta.enterprise.context.ApplicationScoped;
10+
11+
import org.jboss.jandex.AnnotationInstance;
12+
13+
import io.quarkiverse.langchain4j.ModelName;
14+
import io.quarkiverse.langchain4j.bedrock.runtime.BedrockRecorder;
15+
import io.quarkiverse.langchain4j.bedrock.runtime.config.LangChain4jBedrockConfig;
16+
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
17+
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
18+
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
19+
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
20+
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
21+
import io.quarkiverse.langchain4j.runtime.config.LangChain4jConfig;
22+
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
23+
import io.quarkus.deployment.Capabilities;
24+
import io.quarkus.deployment.Capability;
25+
import io.quarkus.deployment.annotations.BuildProducer;
26+
import io.quarkus.deployment.annotations.BuildStep;
27+
import io.quarkus.deployment.annotations.ExecutionTime;
28+
import io.quarkus.deployment.annotations.Record;
29+
import io.quarkus.deployment.builditem.FeatureBuildItem;
30+
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
31+
import io.quarkus.resteasy.reactive.spi.MessageBodyReaderOverrideBuildItem;
32+
import io.quarkus.resteasy.reactive.spi.MessageBodyWriterOverrideBuildItem;
33+
import io.smallrye.config.ConfigSourceInterceptor;
34+
import io.smallrye.config.Priorities;
35+
36+
public class BedrockProcessor {
37+
38+
private static final String FEATURE = "langchain4j-bedrock";
39+
private static final String PROVIDER = "bedrock";
40+
41+
@BuildStep
42+
FeatureBuildItem feature() {
43+
return new FeatureBuildItem(FEATURE);
44+
}
45+
46+
@BuildStep
47+
void nativeSupport(BuildProducer<ServiceProviderBuildItem> serviceProviderProducer) {
48+
serviceProviderProducer
49+
.produce(ServiceProviderBuildItem.allProvidersFromClassPath(ConfigSourceInterceptor.class.getName()));
50+
}
51+
52+
@BuildStep
53+
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
54+
BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
55+
LangChain4jBedrockBuildConfig config) {
56+
if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
57+
chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
58+
}
59+
if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
60+
embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
61+
}
62+
}
63+
64+
@BuildStep
65+
@Record(ExecutionTime.RUNTIME_INIT)
66+
void generateBeans(BedrockRecorder recorder,
67+
List<SelectedChatModelProviderBuildItem> selectedChatItem,
68+
List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
69+
LangChain4jBedrockConfig config,
70+
LangChain4jConfig rootConfig,
71+
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
72+
73+
for (var selected : selectedChatItem) {
74+
if (PROVIDER.equals(selected.getProvider())) {
75+
String configName = selected.getConfigName();
76+
var builder = SyntheticBeanBuildItem
77+
.configure(CHAT_MODEL)
78+
.setRuntimeInit()
79+
.defaultBean()
80+
.scope(ApplicationScoped.class)
81+
.supplier(recorder.chatModel(config, configName, rootConfig));
82+
83+
addQualifierIfNecessary(builder, configName);
84+
beanProducer.produce(builder.done());
85+
86+
var streamingBuilder = SyntheticBeanBuildItem
87+
.configure(STREAMING_CHAT_MODEL)
88+
.setRuntimeInit()
89+
.defaultBean()
90+
.scope(ApplicationScoped.class)
91+
.supplier(recorder.streamingChatModel(config, configName, rootConfig));
92+
addQualifierIfNecessary(streamingBuilder, configName);
93+
beanProducer.produce(streamingBuilder.done());
94+
}
95+
}
96+
97+
for (var selected : selectedEmbedding) {
98+
if (PROVIDER.equals(selected.getProvider())) {
99+
String configName = selected.getConfigName();
100+
var builder = SyntheticBeanBuildItem
101+
.configure(EMBEDDING_MODEL)
102+
.setRuntimeInit()
103+
.defaultBean()
104+
.unremovable()
105+
.scope(ApplicationScoped.class)
106+
.supplier(recorder.embeddingModel(config, configName, rootConfig));
107+
addQualifierIfNecessary(builder, configName);
108+
beanProducer.produce(builder.done());
109+
}
110+
}
111+
}
112+
113+
private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String configName) {
114+
if (!NamedConfigUtil.isDefault(configName)) {
115+
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", configName).build());
116+
}
117+
}
118+
119+
/**
120+
* When both {@code rest-client-jackson} and {@code rest-client-jsonb} are present on the classpath we need to make sure
121+
* that Jackson is used.
122+
* This is not a proper solution as it affects all clients, but it's better than the having the reader/writers be selected
123+
* at random.
124+
*/
125+
@BuildStep
126+
public void deprioritizeJsonb(Capabilities capabilities,
127+
BuildProducer<MessageBodyReaderOverrideBuildItem> readerOverrideProducer,
128+
BuildProducer<MessageBodyWriterOverrideBuildItem> writerOverrideProducer) {
129+
if (capabilities.isPresent(Capability.REST_CLIENT_REACTIVE_JSONB)) {
130+
readerOverrideProducer.produce(
131+
new MessageBodyReaderOverrideBuildItem("org.jboss.resteasy.reactive.server.jsonb.JsonbMessageBodyReader",
132+
Priorities.APPLICATION + 1, true));
133+
writerOverrideProducer.produce(new MessageBodyWriterOverrideBuildItem(
134+
"org.jboss.resteasy.reactive.server.jsonb.JsonbMessageBodyWriter", Priorities.APPLICATION + 1, true));
135+
}
136+
}
137+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface ChatModelBuildConfig {
10+
11+
/**
12+
* Whether the model should be enabled
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface EmbeddingModelBuildConfig {
10+
11+
/**
12+
* Whether the model should be enabled
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;
4+
5+
import io.quarkus.runtime.annotations.ConfigRoot;
6+
import io.smallrye.config.ConfigMapping;
7+
8+
@ConfigRoot(phase = BUILD_TIME)
9+
@ConfigMapping(prefix = "quarkus.langchain4j.bedrock")
10+
public interface LangChain4jBedrockBuildConfig {
11+
12+
/**
13+
* Chat model related settings
14+
*/
15+
ChatModelBuildConfig chatModel();
16+
17+
/**
18+
* Embedding model related settings
19+
*/
20+
EmbeddingModelBuildConfig embeddingModel();
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
7+
import static io.quarkiverse.langchain4j.bedrock.deployment.BedrockStreamHelper.createCompletion;
8+
import static io.quarkiverse.langchain4j.bedrock.deployment.BedrockStreamHelper.decode;
9+
import static org.assertj.core.api.Assertions.assertThat;
10+
11+
import java.util.List;
12+
13+
import jakarta.inject.Inject;
14+
15+
import org.jboss.shrinkwrap.api.ShrinkWrap;
16+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
17+
import org.junit.jupiter.api.DisplayNameGeneration;
18+
import org.junit.jupiter.api.DisplayNameGenerator;
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.api.extension.RegisterExtension;
21+
22+
import dev.langchain4j.model.bedrock.BedrockAnthropicStreamingChatModel;
23+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
24+
import io.quarkus.arc.ClientProxy;
25+
import io.quarkus.test.QuarkusUnitTest;
26+
27+
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
28+
class BedrockAntrophicStreamingChatModelTest extends BedrockTestBase {
29+
30+
@RegisterExtension
31+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
32+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
33+
.addClass(TestCredentialsProvider.class)
34+
.addClass(BedrockStreamHelper.class))
35+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.chat-model.model-id", "anthropic.claude-v2")
36+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.chat-model.aws.region", "eu-central-1")
37+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.chat-model.aws.endpoint-override",
38+
"http://localhost:%d".formatted(WM_PORT))
39+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.chat-model.aws.credentials-provider",
40+
"TestCredentialsProvider")
41+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.log-requests", "true")
42+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.log-responses", "true");
43+
44+
@Inject
45+
StreamingChatLanguageModel streamingChatModel;
46+
47+
@Test
48+
void should_create_bedrock_model() {
49+
// given
50+
51+
// when
52+
53+
// then
54+
assertThat(ClientProxy.unwrap(streamingChatModel)).isInstanceOf(BedrockAnthropicStreamingChatModel.class);
55+
}
56+
57+
@Test
58+
void should_answer_a_chat_message() throws Throwable {
59+
// given
60+
var helper = BedrockStreamHelper.create();
61+
var expected = List.of(
62+
createCompletion("Hello, how are you today?"));
63+
stubFor(post(anyUrl())
64+
.willReturn(aResponse()
65+
.withStatus(200)
66+
.withHeader("Content-Type", "application/vnd.amazon.eventstream") //
67+
.withHeader("Transfer-Encoding", "chunked") //
68+
.withHeader("Connection", "keep-alive")
69+
.withBody(decode(expected))));
70+
71+
// when
72+
streamingChatModel.chat("Hello, how are you today?", helper);
73+
var response = helper.awaitResponse();
74+
75+
// then
76+
assertThat(response).isNotNull();
77+
assertThat(response.aiMessage()).isNotNull();
78+
assertThat(response.aiMessage().text()).isEqualTo("Hello, how are you today?");
79+
}
80+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4+
5+
import jakarta.inject.Inject;
6+
7+
import org.assertj.core.api.ThrowableAssert;
8+
import org.jboss.shrinkwrap.api.ShrinkWrap;
9+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
10+
import org.junit.jupiter.api.DisplayNameGeneration;
11+
import org.junit.jupiter.api.DisplayNameGenerator;
12+
import org.junit.jupiter.api.Test;
13+
import org.junit.jupiter.api.extension.RegisterExtension;
14+
15+
import dev.langchain4j.model.ModelDisabledException;
16+
import dev.langchain4j.model.chat.ChatLanguageModel;
17+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
18+
import io.quarkus.test.QuarkusUnitTest;
19+
20+
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
21+
class BedrockChatModelDisabledTest extends BedrockTestBase {
22+
23+
@RegisterExtension
24+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
25+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
26+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.enable-integration", "false");
27+
28+
@Inject
29+
ChatLanguageModel chatModel;
30+
31+
@Inject
32+
StreamingChatLanguageModel streamingChatModel;
33+
34+
@Test
35+
void should_disable_chat_model() {
36+
// given
37+
38+
// when
39+
ThrowableAssert.ThrowingCallable callable = () -> chatModel.chat("Hello, how are you today?");
40+
41+
// then
42+
assertThatThrownBy(callable).isInstanceOf(ModelDisabledException.class);
43+
}
44+
45+
@Test
46+
void should_disable_streaming_chat_model() {
47+
// given
48+
49+
// when
50+
ThrowableAssert.ThrowingCallable callable = () -> streamingChatModel.chat("Hello, how are you today?", null);
51+
52+
// then
53+
assertThatThrownBy(callable).isInstanceOf(ModelDisabledException.class);
54+
}
55+
}

0 commit comments

Comments
 (0)