Skip to content

Commit 4b321bb

Browse files
andreadimaiojmartisk
authored andcommitted
Manage ScoringModel in core module
1 parent fa11843 commit 4b321bb

File tree

6 files changed

+96
-1
lines changed

6 files changed

+96
-1
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.IMAGE_MODEL;
66
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.MODEL_NAME;
77
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.MODERATION_MODEL;
8+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SCORING_MODEL;
89
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL;
910
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.TOKEN_COUNT_ESTIMATOR;
1011

@@ -24,6 +25,7 @@
2425
import dev.langchain4j.model.embedding.EmbeddingModel;
2526
import dev.langchain4j.model.image.ImageModel;
2627
import dev.langchain4j.model.moderation.ModerationModel;
28+
import dev.langchain4j.model.scoring.ScoringModel;
2729
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
2830
import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
2931
import io.quarkiverse.langchain4j.deployment.items.AutoCreateEmbeddingModelBuildItem;
@@ -33,10 +35,12 @@
3335
import io.quarkiverse.langchain4j.deployment.items.InProcessEmbeddingBuildItem;
3436
import io.quarkiverse.langchain4j.deployment.items.ModerationModelProviderCandidateBuildItem;
3537
import io.quarkiverse.langchain4j.deployment.items.ProviderHolder;
38+
import io.quarkiverse.langchain4j.deployment.items.ScoringModelProviderCandidateBuildItem;
3639
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
3740
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
3841
import io.quarkiverse.langchain4j.deployment.items.SelectedImageModelProviderBuildItem;
3942
import io.quarkiverse.langchain4j.deployment.items.SelectedModerationModelProviderBuildItem;
43+
import io.quarkiverse.langchain4j.deployment.items.SelectedScoringModelProviderBuildItem;
4044
import io.quarkiverse.langchain4j.runtime.LangChain4jRecorder;
4145
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
4246
import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem;
@@ -70,6 +74,7 @@ void indexDependencies(BuildProducer<IndexDependencyBuildItem> producer) {
7074
@BuildStep
7175
public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished,
7276
List<ChatModelProviderCandidateBuildItem> chatCandidateItems,
77+
List<ScoringModelProviderCandidateBuildItem> scoringCandidateItems,
7378
List<EmbeddingModelProviderCandidateBuildItem> embeddingCandidateItems,
7479
List<ModerationModelProviderCandidateBuildItem> moderationCandidateItems,
7580
List<ImageModelProviderCandidateBuildItem> imageCandidateItems,
@@ -79,13 +84,15 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
7984
LangChain4jBuildConfig buildConfig,
8085
Optional<AutoCreateEmbeddingModelBuildItem> autoCreateEmbeddingModelBuildItem,
8186
BuildProducer<SelectedChatModelProviderBuildItem> selectedChatProducer,
87+
BuildProducer<SelectedScoringModelProviderBuildItem> selectedScoringProducer,
8288
BuildProducer<SelectedEmbeddingModelCandidateBuildItem> selectedEmbeddingProducer,
8389
BuildProducer<SelectedModerationModelProviderBuildItem> selectedModerationProducer,
8490
BuildProducer<SelectedImageModelProviderBuildItem> selectedImageProducer,
8591
List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems) {
8692

8793
Set<String> requestedChatModels = new HashSet<>();
8894
Set<String> requestedStreamingChatModels = new HashSet<>();
95+
Set<String> requestScoringModels = new HashSet<>();
8996
Set<String> requestEmbeddingModels = new HashSet<>();
9097
Set<String> requestedModerationModels = new HashSet<>();
9198
Set<String> requestedImageModels = new HashSet<>();
@@ -98,6 +105,8 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
98105
requestedChatModels.add(modelName);
99106
} else if (STREAMING_CHAT_MODEL.equals(requiredName)) {
100107
requestedStreamingChatModels.add(modelName);
108+
} else if (SCORING_MODEL.equals(requiredName)) {
109+
requestScoringModels.add(modelName);
101110
} else if (EMBEDDING_MODEL.equals(requiredName)) {
102111
requestEmbeddingModels.add(modelName);
103112
} else if (MODERATION_MODEL.equals(requiredName)) {
@@ -108,7 +117,9 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
108117
tokenCountEstimators.add(modelName);
109118
}
110119
}
111-
for (var bi : requestChatModelBeanItems) {
120+
for (
121+
122+
var bi : requestChatModelBeanItems) {
112123
requestedChatModels.add(bi.getConfigName());
113124
}
114125
for (var bi : requestModerationModelBeanBuildItems) {
@@ -150,6 +161,32 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
150161

151162
}
152163

164+
for (String modelName : requestScoringModels) {
165+
Optional<String> userSelectedProvider;
166+
String configNamespace;
167+
if (NamedConfigUtil.isDefault(modelName)) {
168+
userSelectedProvider = buildConfig.defaultConfig().scoringModel().provider();
169+
configNamespace = "scoring-model";
170+
} else {
171+
if (buildConfig.namedConfig().containsKey(modelName)) {
172+
userSelectedProvider = buildConfig.namedConfig().get(modelName).scoringModel().provider();
173+
} else {
174+
userSelectedProvider = Optional.empty();
175+
}
176+
configNamespace = modelName + ".scoring-model";
177+
}
178+
179+
String provider = selectProvider(
180+
scoringCandidateItems,
181+
beanDiscoveryFinished.beanStream().withBeanType(ScoringModel.class),
182+
userSelectedProvider,
183+
"ScoringModel",
184+
configNamespace);
185+
if (provider != null) {
186+
selectedScoringProducer.produce(new SelectedScoringModelProviderBuildItem(provider, modelName));
187+
}
188+
}
189+
153190
for (String modelName : requestEmbeddingModels) {
154191
Optional<String> userSelectedProvider;
155192
String configNamespace;

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
1515
import dev.langchain4j.model.moderation.ModerationModel;
1616
import dev.langchain4j.model.output.structured.Description;
17+
import dev.langchain4j.model.scoring.ScoringModel;
1718
import dev.langchain4j.rag.RetrievalAugmentor;
1819
import dev.langchain4j.retriever.Retriever;
1920
import dev.langchain4j.service.AiServices;
@@ -38,6 +39,7 @@
3839
public class LangChain4jDotNames {
3940
public static final DotName CHAT_MODEL = DotName.createSimple(ChatLanguageModel.class);
4041
public static final DotName STREAMING_CHAT_MODEL = DotName.createSimple(StreamingChatLanguageModel.class);
42+
public static final DotName SCORING_MODEL = DotName.createSimple(ScoringModel.class);
4143
public static final DotName EMBEDDING_MODEL = DotName.createSimple(EmbeddingModel.class);
4244
public static final DotName MODERATION_MODEL = DotName.createSimple(ModerationModel.class);
4345
public static final DotName IMAGE_MODEL = DotName.createSimple(ImageModel.class);

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ interface BaseConfig {
4949
*/
5050
ChatModelConfig chatModel();
5151

52+
/**
53+
* Rerank model
54+
*/
55+
ScoringModelConfig scoringModel();
56+
5257
/**
5358
* Embedding model
5459
*/
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package io.quarkiverse.langchain4j.deployment.config;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigGroup;
6+
7+
@ConfigGroup
8+
public interface ScoringModelConfig {
9+
/**
10+
* The model provider to use
11+
*/
12+
Optional<String> provider();
13+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.deployment.items;
2+
3+
import io.quarkus.builder.item.MultiBuildItem;
4+
5+
public final class ScoringModelProviderCandidateBuildItem extends MultiBuildItem implements ProviderHolder {
6+
7+
private final String provider;
8+
9+
public ScoringModelProviderCandidateBuildItem(String provider) {
10+
this.provider = provider;
11+
}
12+
13+
public String getProvider() {
14+
return provider;
15+
}
16+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package io.quarkiverse.langchain4j.deployment.items;
2+
3+
import io.quarkus.builder.item.MultiBuildItem;
4+
5+
public final class SelectedScoringModelProviderBuildItem extends MultiBuildItem {
6+
7+
private final String provider;
8+
private final String configName;
9+
10+
public SelectedScoringModelProviderBuildItem(String provider, String configName) {
11+
this.provider = provider;
12+
this.configName = configName;
13+
}
14+
15+
public String getProvider() {
16+
return provider;
17+
}
18+
19+
public String getConfigName() {
20+
return configName;
21+
}
22+
}

0 commit comments

Comments
 (0)