Skip to content

Commit b91f56d

Browse files
andreadimaiogeoand
authored andcommitted
Add embedding model
1 parent 011ab0f commit b91f56d

File tree

11 files changed

+237
-0
lines changed

11 files changed

+237
-0
lines changed

bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/BamProcessor.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.quarkiverse.langchain4j.bam.deployment;
22

33
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL;
4+
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL;
45

56
import java.util.Optional;
67

@@ -9,7 +10,9 @@
910
import io.quarkiverse.langchain4j.bam.runtime.BamRecorder;
1011
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
1112
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
13+
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
1214
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
15+
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
1316
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
1417
import io.quarkus.deployment.annotations.BuildProducer;
1518
import io.quarkus.deployment.annotations.BuildStep;
@@ -30,19 +33,27 @@ FeatureBuildItem feature() {
3033

3134
@BuildStep
3235
public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem> chatProducer,
36+
BuildProducer<EmbeddingModelProviderCandidateBuildItem> embeddingProducer,
3337
Langchain4jBamBuildConfig config) {
38+
3439
if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) {
3540
chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER));
3641
}
42+
43+
if (config.embeddingModel().enabled().isEmpty() || config.embeddingModel().enabled().get()) {
44+
embeddingProducer.produce(new EmbeddingModelProviderCandidateBuildItem(PROVIDER));
45+
}
3746
}
3847

3948
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
4049
@BuildStep
4150
@Record(ExecutionTime.RUNTIME_INIT)
4251
void generateBeans(BamRecorder recorder,
4352
Optional<SelectedChatModelProviderBuildItem> selectedChatItem,
53+
Optional<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
4454
Langchain4jBamConfig config,
4555
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
56+
4657
if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) {
4758
beanProducer.produce(SyntheticBeanBuildItem
4859
.configure(CHAT_MODEL)
@@ -52,5 +63,17 @@ void generateBeans(BamRecorder recorder,
5263
.supplier(recorder.chatModel(config))
5364
.done());
5465
}
66+
67+
if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
68+
beanProducer.produce(
69+
SyntheticBeanBuildItem
70+
.configure(EMBEDDING_MODEL)
71+
.setRuntimeInit()
72+
.defaultBean()
73+
.scope(ApplicationScoped.class)
74+
.supplier(recorder.embeddingModel(config))
75+
.unremovable()
76+
.done());
77+
}
5578
}
5679
}

bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/ChatModelBuildConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ public interface ChatModelBuildConfig {
1313
*/
1414
@ConfigDocDefault("true")
1515
Optional<Boolean> enabled();
16+
17+
/**
18+
* Embedding model related settings
19+
*/
20+
EmbeddingModelBuildConfig embeddingModel();
1621
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.bam.deployment;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface EmbeddingModelBuildConfig {
10+
11+
/**
12+
* Whether the model should be enabled
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}

bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/Langchain4jBamBuildConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ public interface Langchain4jBamBuildConfig {
1313
* Chat model related settings
1414
*/
1515
ChatModelBuildConfig chatModel();
16+
17+
/**
18+
* Embedding model related settings
19+
*/
20+
EmbeddingModelBuildConfig embeddingModel();
1621
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package io.quarkiverse.langchain4j.bam;
2+
3+
import java.net.URI;
4+
import java.net.URISyntaxException;
5+
import java.net.URL;
6+
import java.time.Duration;
7+
import java.util.ArrayList;
8+
import java.util.List;
9+
import java.util.concurrent.TimeUnit;
10+
11+
import org.jboss.resteasy.reactive.client.api.LoggingScope;
12+
13+
import dev.langchain4j.data.embedding.Embedding;
14+
import dev.langchain4j.data.segment.TextSegment;
15+
import dev.langchain4j.model.embedding.EmbeddingModel;
16+
import dev.langchain4j.model.embedding.TokenCountEstimator;
17+
import dev.langchain4j.model.output.Response;
18+
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
19+
20+
public class BamEmbeddingModel implements EmbeddingModel, TokenCountEstimator {
21+
22+
private final String token;
23+
private final String modelId;
24+
private final String version;
25+
public boolean logResponses;
26+
public boolean logRequests;
27+
private final BamRestApi client;
28+
29+
public BamEmbeddingModel(Builder config) {
30+
31+
QuarkusRestClientBuilder builder = QuarkusRestClientBuilder.newBuilder()
32+
.baseUri(config.url)
33+
.connectTimeout(config.timeout.toSeconds(), TimeUnit.SECONDS)
34+
.readTimeout(config.timeout.toSeconds(), TimeUnit.SECONDS);
35+
36+
if (config.logRequests || config.logResponses) {
37+
builder.loggingScope(LoggingScope.REQUEST_RESPONSE);
38+
builder.clientLogger(new BamRestApi.WatsonClientLogger(
39+
config.logRequests,
40+
config.logResponses));
41+
}
42+
43+
this.client = builder.build(BamRestApi.class);
44+
this.token = config.accessToken;
45+
this.modelId = config.modelId;
46+
this.version = config.version;
47+
}
48+
49+
public static Builder builder() {
50+
return new Builder();
51+
}
52+
53+
@Override
54+
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
55+
56+
List<Embedding> result = new ArrayList<>();
57+
for (TextSegment textSegment : textSegments) {
58+
59+
var request = new EmbeddingRequest(modelId, textSegment.text());
60+
var response = client.embeddings(request, token, version);
61+
62+
var vector = response.results().get(0);
63+
result.add(Embedding.from(vector));
64+
}
65+
66+
return Response.from(result);
67+
}
68+
69+
@Override
70+
public int estimateTokenCount(String text) {
71+
72+
var request = new TokenizationRequest(modelId, text);
73+
return client.tokenization(request, token, version).tokenCount();
74+
}
75+
76+
public static final class Builder {
77+
78+
private String accessToken;
79+
private String version;
80+
private URI url = URI.create("https://bam-api.res.ibm.com");
81+
private Duration timeout = Duration.ofSeconds(15);
82+
private String modelId;
83+
public boolean logResponses;
84+
public boolean logRequests;
85+
86+
public Builder accessToken(String accessToken) {
87+
this.accessToken = accessToken;
88+
return this;
89+
}
90+
91+
public Builder version(String version) {
92+
this.version = version;
93+
return this;
94+
}
95+
96+
public Builder url(URL url) {
97+
try {
98+
this.url = url.toURI();
99+
} catch (URISyntaxException e) {
100+
throw new RuntimeException(e);
101+
}
102+
return this;
103+
}
104+
105+
public Builder timeout(Duration timeout) {
106+
this.timeout = timeout;
107+
return this;
108+
}
109+
110+
public Builder modelId(String modelId) {
111+
this.modelId = modelId;
112+
return this;
113+
}
114+
115+
public Builder logRequests(boolean logRequests) {
116+
this.logRequests = logRequests;
117+
return this;
118+
}
119+
120+
public Builder logResponses(boolean logResponses) {
121+
this.logResponses = logResponses;
122+
return this;
123+
}
124+
125+
public BamEmbeddingModel build() {
126+
return new BamEmbeddingModel(this);
127+
}
128+
}
129+
}

bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/BamRestApi.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static java.util.stream.Collectors.joining;
44
import static java.util.stream.StreamSupport.stream;
55

6+
import java.util.Optional;
67
import java.util.regex.Matcher;
78
import java.util.regex.Pattern;
89

@@ -46,6 +47,10 @@ public interface BamRestApi {
4647
@Path("text/chat")
4748
TextGenerationResponse chat(TextGenerationRequest request, @NotBody String token, @QueryParam("version") String version);
4849

50+
@POST
51+
@Path("/text/embeddings")
52+
EmbeddingResponse embeddings(EmbeddingRequest request, @NotBody String token, @QueryParam("version") String version);
53+
4954
@POST
5055
@Path("/text/tokenization")
5156
public TokenizationResponse tokenization(TokenizationRequest request, @NotBody String token,
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package io.quarkiverse.langchain4j.bam;
2+
3+
public record EmbeddingRequest(String modelId, String input) {
4+
5+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package io.quarkiverse.langchain4j.bam;
2+
3+
import java.util.List;
4+
5+
public record EmbeddingResponse(List<List<Float>> results) {
6+
7+
}

bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/BamRecorder.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault;
44
import java.util.function.Supplier;
55
import io.quarkiverse.langchain4j.bam.BamChatModel;
6+
import io.quarkiverse.langchain4j.bam.BamEmbeddingModel;
67
import io.quarkiverse.langchain4j.bam.runtime.config.ChatModelConfig;
78
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
89
import io.quarkus.runtime.annotations.Recorder;
@@ -46,4 +47,26 @@ public Object get() {
4647
}
4748
};
4849
}
50+
51+
public Supplier<?> embeddingModel(Langchain4jBamConfig runtimeConfig) {
52+
53+
var embeddingModelConfig = runtimeConfig.embeddingModel();
54+
55+
var builder = BamEmbeddingModel.builder()
56+
.accessToken(runtimeConfig.apiKey())
57+
.timeout(runtimeConfig.timeout())
58+
.version(runtimeConfig.version())
59+
.modelId(embeddingModelConfig.modelId());
60+
61+
if (runtimeConfig.baseUrl().isPresent()) {
62+
builder.url(runtimeConfig.baseUrl().get());
63+
}
64+
65+
return new Supplier<>() {
66+
@Override
67+
public Object get() {
68+
return builder.build();
69+
}
70+
};
71+
}
4972
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package io.quarkiverse.langchain4j.bam.runtime.config;
2+
3+
import io.quarkus.runtime.annotations.ConfigGroup;
4+
import io.smallrye.config.WithDefault;
5+
6+
@ConfigGroup
7+
public interface EmbeddingModelConfig {
8+
9+
/**
10+
* Model to use
11+
*/
12+
@WithDefault("ibm/slate.30m.english.rtrvr-26.10.2023")
13+
String modelId();
14+
}

0 commit comments

Comments
 (0)