Skip to content

Commit 731064c

Browse files
committed
feat: Add Amazon Bedrock
- Add ChatModel via Converse Api - Add StreamChatModel via Converse Api except for Anthropic - Add EmbeddingModel for Titan and Cohere via Invoke Api
1 parent cb61158 commit 731064c

28 files changed

+4842
-0
lines changed

docs/modules/ROOT/pages/includes/quarkus-langchain4j-bedrock.adoc

Lines changed: 1433 additions & 0 deletions
Large diffs are not rendered by default.

docs/modules/ROOT/pages/includes/quarkus-langchain4j-bedrock_quarkus.langchain4j.adoc

Lines changed: 1433 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3+
<modelVersion>4.0.0</modelVersion>
4+
<parent>
5+
<groupId>io.quarkiverse.langchain4j</groupId>
6+
<artifactId>quarkus-langchain4j-bedrock-parent</artifactId>
7+
<version>999-SNAPSHOT</version>
8+
</parent>
9+
<artifactId>quarkus-langchain4j-bedrock-deployment</artifactId>
10+
<name>Quarkus LangChain4j - Bedrock - Deployment</name>
11+
<dependencies>
12+
<dependency>
13+
<groupId>io.quarkiverse.langchain4j</groupId>
14+
<artifactId>quarkus-langchain4j-bedrock</artifactId>
15+
<version>${project.version}</version>
16+
</dependency>
17+
<dependency>
18+
<groupId>io.quarkus</groupId>
19+
<artifactId>quarkus-rest-client-jackson-deployment</artifactId>
20+
</dependency>
21+
<dependency>
22+
<groupId>io.quarkiverse.langchain4j</groupId>
23+
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
24+
<version>${project.version}</version>
25+
</dependency>
26+
<dependency>
27+
<groupId>org.wiremock</groupId>
28+
<artifactId>wiremock-standalone</artifactId>
29+
<version>${wiremock.version}</version>
30+
<scope>test</scope>
31+
</dependency>
32+
<dependency>
33+
<groupId>io.quarkus</groupId>
34+
<artifactId>quarkus-junit5-internal</artifactId>
35+
<scope>test</scope>
36+
</dependency>
37+
<dependency>
38+
<groupId>org.assertj</groupId>
39+
<artifactId>assertj-core</artifactId>
40+
<version>${assertj.version}</version>
41+
<scope>test</scope>
42+
</dependency>
43+
</dependencies>
44+
<build>
45+
<plugins>
46+
<plugin>
47+
<artifactId>maven-compiler-plugin</artifactId>
48+
<configuration>
49+
<annotationProcessorPaths>
50+
<path>
51+
<groupId>io.quarkus</groupId>
52+
<artifactId>quarkus-extension-processor</artifactId>
53+
</path>
54+
</annotationProcessorPaths>
55+
</configuration>
56+
</plugin>
57+
</plugins>
58+
</build>
59+
</project>
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.CHAT_MODEL;
4+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.EMBEDDING_MODEL;
5+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL;
6+
7+
import java.util.List;
8+
9+
import jakarta.enterprise.context.ApplicationScoped;
10+
11+
import org.jboss.jandex.AnnotationInstance;
12+
13+
import io.quarkiverse.langchain4j.ModelName;
14+
import io.quarkiverse.langchain4j.bedrock.runtime.BedrockRecorder;
15+
import io.quarkiverse.langchain4j.bedrock.runtime.config.LangChain4jBedrockConfig;
16+
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
17+
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
18+
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
19+
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
20+
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
21+
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
22+
import io.quarkus.deployment.Capabilities;
23+
import io.quarkus.deployment.Capability;
24+
import io.quarkus.deployment.annotations.BuildProducer;
25+
import io.quarkus.deployment.annotations.BuildStep;
26+
import io.quarkus.deployment.annotations.ExecutionTime;
27+
import io.quarkus.deployment.annotations.Record;
28+
import io.quarkus.deployment.builditem.FeatureBuildItem;
29+
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
30+
import io.quarkus.resteasy.reactive.spi.MessageBodyReaderOverrideBuildItem;
31+
import io.quarkus.resteasy.reactive.spi.MessageBodyWriterOverrideBuildItem;
32+
import io.smallrye.config.ConfigSourceInterceptor;
33+
import io.smallrye.config.Priorities;
34+
35+
public class BedrockProcessor {
36+
37+
private static final String FEATURE = "langchain4j-bedrock";
38+
private static final String PROVIDER = "bedrock";
39+
40+
@BuildStep
41+
FeatureBuildItem feature() {
42+
return new FeatureBuildItem(FEATURE);
43+
}
44+
45+
@BuildStep
46+
void nativeSupport(BuildProducer<ServiceProviderBuildItem> serviceProviderProducer) {
47+
serviceProviderProducer
48+
.produce(ServiceProviderBuildItem.allProvidersFromClassPath(ConfigSourceInterceptor.class.getName()));
49+
}
50+
51+
@BuildStep
52+
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
53+
BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
54+
LangChain4jBedrockBuildConfig config) {
55+
if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
56+
chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
57+
}
58+
if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
59+
embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
60+
}
61+
}
62+
63+
@BuildStep
64+
@Record(ExecutionTime.RUNTIME_INIT)
65+
void generateBeans(BedrockRecorder recorder,
66+
List<SelectedChatModelProviderBuildItem> selectedChatItem,
67+
List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
68+
LangChain4jBedrockConfig config,
69+
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
70+
71+
for (var selected : selectedChatItem) {
72+
if (PROVIDER.equals(selected.getProvider())) {
73+
String configName = selected.getConfigName();
74+
var builder = SyntheticBeanBuildItem
75+
.configure(CHAT_MODEL)
76+
.setRuntimeInit()
77+
.defaultBean()
78+
.scope(ApplicationScoped.class)
79+
.supplier(recorder.chatModel(config, configName));
80+
81+
addQualifierIfNecessary(builder, configName);
82+
beanProducer.produce(builder.done());
83+
84+
var streamingBuilder = SyntheticBeanBuildItem
85+
.configure(STREAMING_CHAT_MODEL)
86+
.setRuntimeInit()
87+
.defaultBean()
88+
.scope(ApplicationScoped.class)
89+
.supplier(recorder.streamingChatModel(config, configName));
90+
addQualifierIfNecessary(streamingBuilder, configName);
91+
beanProducer.produce(streamingBuilder.done());
92+
}
93+
}
94+
95+
for (var selected : selectedEmbedding) {
96+
if (PROVIDER.equals(selected.getProvider())) {
97+
String configName = selected.getConfigName();
98+
var builder = SyntheticBeanBuildItem
99+
.configure(EMBEDDING_MODEL)
100+
.setRuntimeInit()
101+
.defaultBean()
102+
.unremovable()
103+
.scope(ApplicationScoped.class)
104+
.supplier(recorder.embeddingModel(config, configName));
105+
addQualifierIfNecessary(builder, configName);
106+
beanProducer.produce(builder.done());
107+
}
108+
}
109+
}
110+
111+
private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String configName) {
112+
if (!NamedConfigUtil.isDefault(configName)) {
113+
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", configName).build());
114+
}
115+
}
116+
117+
/**
118+
* When both {@code rest-client-jackson} and {@code rest-client-jsonb} are present on the classpath we need to make sure
119+
* that Jackson is used.
120+
* This is not a proper solution as it affects all clients, but it's better than the having the reader/writers be selected
121+
* at random.
122+
*/
123+
@BuildStep
124+
public void deprioritizeJsonb(Capabilities capabilities,
125+
BuildProducer<MessageBodyReaderOverrideBuildItem> readerOverrideProducer,
126+
BuildProducer<MessageBodyWriterOverrideBuildItem> writerOverrideProducer) {
127+
if (capabilities.isPresent(Capability.REST_CLIENT_REACTIVE_JSONB)) {
128+
readerOverrideProducer.produce(
129+
new MessageBodyReaderOverrideBuildItem("org.jboss.resteasy.reactive.server.jsonb.JsonbMessageBodyReader",
130+
Priorities.APPLICATION + 1, true));
131+
writerOverrideProducer.produce(new MessageBodyWriterOverrideBuildItem(
132+
"org.jboss.resteasy.reactive.server.jsonb.JsonbMessageBodyWriter", Priorities.APPLICATION + 1, true));
133+
}
134+
}
135+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface ChatModelBuildConfig {
10+
11+
/**
12+
* Whether the model should be enabled
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface EmbeddingModelBuildConfig {
10+
11+
/**
12+
* Whether the model should be enabled
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;
4+
5+
import io.quarkus.runtime.annotations.ConfigRoot;
6+
import io.smallrye.config.ConfigMapping;
7+
8+
@ConfigRoot(phase = BUILD_TIME)
9+
@ConfigMapping(prefix = "quarkus.langchain4j.bedrock")
10+
public interface LangChain4jBedrockBuildConfig {
11+
12+
/**
13+
* Chat model related settings
14+
*/
15+
ChatModelBuildConfig chatModel();
16+
17+
/**
18+
* Embedding model related settings
19+
*/
20+
EmbeddingModelBuildConfig embeddingModel();
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
4+
import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl;
5+
import static com.github.tomakehurst.wiremock.client.WireMock.post;
6+
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
7+
import static io.quarkiverse.langchain4j.bedrock.deployment.BedrockStreamHelper.createCompletion;
8+
import static io.quarkiverse.langchain4j.bedrock.deployment.BedrockStreamHelper.decode;
9+
import static org.assertj.core.api.Assertions.assertThat;
10+
11+
import java.util.List;
12+
13+
import jakarta.inject.Inject;
14+
15+
import org.jboss.shrinkwrap.api.ShrinkWrap;
16+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
17+
import org.junit.jupiter.api.DisplayNameGeneration;
18+
import org.junit.jupiter.api.DisplayNameGenerator;
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.api.extension.RegisterExtension;
21+
22+
import dev.langchain4j.model.bedrock.BedrockAnthropicStreamingChatModel;
23+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
24+
import io.quarkus.arc.ClientProxy;
25+
import io.quarkus.test.QuarkusUnitTest;
26+
27+
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
28+
class BedrockAntrophicStreamingChatModelTest extends BedrockTestBase {
29+
30+
@RegisterExtension
31+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
32+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
33+
.addClass(TestCredentialsProvider.class)
34+
.addClass(BedrockStreamHelper.class))
35+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.chat-model.model-id", "anthropic.claude-v2")
36+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.chat-model.client.region", "eu-central-1")
37+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.chat-model.client.endpoint-override",
38+
"http://localhost:%d".formatted(WM_PORT))
39+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.chat-model.client.credentials-provider",
40+
"TestCredentialsProvider")
41+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.log-requests", "true")
42+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.log-responses", "true");
43+
44+
@Inject
45+
StreamingChatLanguageModel streamingChatModel;
46+
47+
@Test
48+
void should_create_bedrock_model() {
49+
// given
50+
51+
// when
52+
53+
// then
54+
assertThat(ClientProxy.unwrap(streamingChatModel)).isInstanceOf(BedrockAnthropicStreamingChatModel.class);
55+
}
56+
57+
@Test
58+
void should_answer_a_chat_message() throws Throwable {
59+
// given
60+
var helper = BedrockStreamHelper.create();
61+
var expected = List.of(
62+
createCompletion("Hello, how are you today?"));
63+
stubFor(post(anyUrl())
64+
.willReturn(aResponse()
65+
.withStatus(200)
66+
.withHeader("Content-Type", "application/vnd.amazon.eventstream") //
67+
.withHeader("Transfer-Encoding", "chunked") //
68+
.withHeader("Connection", "keep-alive")
69+
.withBody(decode(expected))));
70+
71+
// when
72+
streamingChatModel.chat("Hello, how are you today?", helper);
73+
var response = helper.awaitResponse();
74+
75+
// then
76+
assertThat(response).isNotNull();
77+
assertThat(response.aiMessage()).isNotNull();
78+
assertThat(response.aiMessage().text()).isEqualTo("Hello, how are you today?");
79+
}
80+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package io.quarkiverse.langchain4j.bedrock.deployment;
2+
3+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4+
5+
import jakarta.inject.Inject;
6+
7+
import org.assertj.core.api.ThrowableAssert;
8+
import org.jboss.shrinkwrap.api.ShrinkWrap;
9+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
10+
import org.junit.jupiter.api.DisplayNameGeneration;
11+
import org.junit.jupiter.api.DisplayNameGenerator;
12+
import org.junit.jupiter.api.Test;
13+
import org.junit.jupiter.api.extension.RegisterExtension;
14+
15+
import dev.langchain4j.model.ModelDisabledException;
16+
import dev.langchain4j.model.chat.ChatLanguageModel;
17+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
18+
import io.quarkus.test.QuarkusUnitTest;
19+
20+
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
21+
class BedrockChatModelDisabledTest extends BedrockTestBase {
22+
23+
@RegisterExtension
24+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
25+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
26+
.overrideRuntimeConfigKey("quarkus.langchain4j.bedrock.enable-integration", "false");
27+
28+
@Inject
29+
ChatLanguageModel chatModel;
30+
31+
@Inject
32+
StreamingChatLanguageModel streamingChatModel;
33+
34+
@Test
35+
void should_disable_chat_model() {
36+
// given
37+
38+
// when
39+
ThrowableAssert.ThrowingCallable callable = () -> chatModel.chat("Hello, how are you today?");
40+
41+
// then
42+
assertThatThrownBy(callable).isInstanceOf(ModelDisabledException.class);
43+
}
44+
45+
@Test
46+
void should_disable_streaming_chat_model() {
47+
// given
48+
49+
// when
50+
ThrowableAssert.ThrowingCallable callable = () -> streamingChatModel.chat("Hello, how are you today?", null);
51+
52+
// then
53+
assertThatThrownBy(callable).isInstanceOf(ModelDisabledException.class);
54+
}
55+
}

0 commit comments

Comments
 (0)