Skip to content

Commit aa9ac70

Browse files
andreadimaiojmartisk
authored andcommitted
Update Cohere provider
1 parent 4b321bb commit aa9ac70

File tree

8 files changed

+243
-82
lines changed

8 files changed

+243
-82
lines changed
Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
package io.quarkiverse.langchain4j.cohere;
22

3+
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SCORING_MODEL;
4+
5+
import java.util.List;
6+
37
import jakarta.enterprise.context.ApplicationScoped;
48

9+
import org.jboss.jandex.AnnotationInstance;
510
import org.jboss.jandex.ClassType;
6-
import org.jboss.jandex.DotName;
711

812
import dev.langchain4j.model.scoring.ScoringModel;
13+
import io.quarkiverse.langchain4j.ModelName;
914
import io.quarkiverse.langchain4j.cohere.runtime.CohereRecorder;
1015
import io.quarkiverse.langchain4j.cohere.runtime.QuarkusCohereScoringModel;
11-
import io.quarkiverse.langchain4j.cohere.runtime.config.CohereConfig;
16+
import io.quarkiverse.langchain4j.cohere.runtime.config.Langchain4jCohereConfig;
17+
import io.quarkiverse.langchain4j.deployment.items.ScoringModelProviderCandidateBuildItem;
18+
import io.quarkiverse.langchain4j.deployment.items.SelectedScoringModelProviderBuildItem;
19+
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
1220
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
1321
import io.quarkus.deployment.annotations.BuildProducer;
1422
import io.quarkus.deployment.annotations.BuildStep;
@@ -18,32 +26,52 @@
1826

1927
public class CohereProcessor {
2028

21-
public static final DotName COHERE_SCORING_MODEL = DotName.createSimple(QuarkusCohereScoringModel.class);
22-
2329
static final String FEATURE = "langchain4j-cohere";
30+
private static final String PROVIDER = "cohere";
2431

2532
@BuildStep
2633
FeatureBuildItem feature() {
2734
return new FeatureBuildItem(FEATURE);
2835
}
2936

37+
@BuildStep
38+
public void providerCandidates(BuildProducer<ScoringModelProviderCandidateBuildItem> scoringProducer,
39+
LangChain4jCohereBuildConfig config) {
40+
41+
if (config.scoringModel().enabled().isEmpty() || config.scoringModel().enabled().get()) {
42+
scoringProducer.produce(new ScoringModelProviderCandidateBuildItem(PROVIDER));
43+
}
44+
}
45+
3046
@BuildStep
3147
@Record(ExecutionTime.RUNTIME_INIT)
3248
public void createScoringModelBean(
3349
BuildProducer<SyntheticBeanBuildItem> beanProducer,
50+
List<SelectedScoringModelProviderBuildItem> selectedScoring,
3451
CohereRecorder recorder,
35-
CohereConfig config) {
36-
// TODO: maybe add some kind of ScoringModelBuildItem class and produce it here
37-
beanProducer.produce(SyntheticBeanBuildItem
38-
.configure(COHERE_SCORING_MODEL)
39-
.types(ClassType.create(ScoringModel.class),
40-
ClassType.create(QuarkusCohereScoringModel.class))
41-
.defaultBean()
42-
.setRuntimeInit()
43-
.defaultBean()
44-
.scope(ApplicationScoped.class)
45-
.supplier(recorder.cohereScoringModelSupplier(config))
46-
.done());
52+
Langchain4jCohereConfig config) {
53+
54+
for (var selected : selectedScoring) {
55+
if (PROVIDER.equals(selected.getProvider())) {
56+
String configName = selected.getConfigName();
57+
var builder = SyntheticBeanBuildItem
58+
.configure(SCORING_MODEL)
59+
.types(ClassType.create(ScoringModel.class),
60+
ClassType.create(QuarkusCohereScoringModel.class))
61+
.setRuntimeInit()
62+
.defaultBean()
63+
.unremovable()
64+
.scope(ApplicationScoped.class)
65+
.supplier(recorder.cohereScoringModelSupplier(config, configName));
66+
addQualifierIfNecessary(builder, configName);
67+
beanProducer.produce(builder.done());
68+
}
69+
}
4770
}
4871

72+
private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String configName) {
73+
if (!NamedConfigUtil.isDefault(configName)) {
74+
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", configName).build());
75+
}
76+
}
4977
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.cohere;
2+
3+
import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;
4+
5+
import io.quarkus.runtime.annotations.ConfigRoot;
6+
import io.smallrye.config.ConfigMapping;
7+
8+
@ConfigRoot(phase = BUILD_TIME)
9+
@ConfigMapping(prefix = "quarkus.langchain4j.cohere")
10+
public interface LangChain4jCohereBuildConfig {
11+
12+
/**
13+
* Scoring model related settings.
14+
*/
15+
ScoringModelBuildConfig scoringModel();
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package io.quarkiverse.langchain4j.cohere;
2+
3+
import java.util.Optional;
4+
5+
import io.quarkus.runtime.annotations.ConfigDocDefault;
6+
import io.quarkus.runtime.annotations.ConfigGroup;
7+
8+
@ConfigGroup
9+
public interface ScoringModelBuildConfig {
10+
11+
/**
12+
* Whether the scoring model should be enabled.
13+
*/
14+
@ConfigDocDefault("true")
15+
Optional<Boolean> enabled();
16+
}
Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,74 @@
11
package io.quarkiverse.langchain4j.cohere.runtime;
22

3+
import java.util.ArrayList;
4+
import java.util.List;
35
import java.util.function.Supplier;
46

57
import dev.langchain4j.model.scoring.ScoringModel;
6-
import io.quarkiverse.langchain4j.cohere.runtime.config.CohereConfig;
8+
import io.quarkiverse.langchain4j.cohere.runtime.config.Langchain4jCohereConfig;
9+
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
710
import io.quarkus.runtime.annotations.Recorder;
11+
import io.smallrye.config.ConfigValidationException;
812

913
@Recorder
1014
public class CohereRecorder {
1115

12-
public Supplier<ScoringModel> cohereScoringModelSupplier(CohereConfig config) {
16+
private static final String DUMMY_API_KEY = "dummy";
17+
private static final ConfigValidationException.Problem[] EMPTY_PROBLEMS = new ConfigValidationException.Problem[0];
18+
19+
public Supplier<ScoringModel> cohereScoringModelSupplier(Langchain4jCohereConfig runtimeConfig, String configName) {
20+
Langchain4jCohereConfig.CohereConfig cohereConfig = correspondingCohereConfig(runtimeConfig, configName);
21+
22+
var configProblems = checkConfigurations(cohereConfig, configName);
23+
24+
if (!configProblems.isEmpty()) {
25+
throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS));
26+
}
27+
1328
return new Supplier<>() {
1429
@Override
1530
public ScoringModel get() {
1631
return new QuarkusCohereScoringModel(
17-
config.baseUrl(),
18-
config.apiKey(),
19-
config.scoringModel().modelName(),
20-
config.scoringModel().timeout(),
21-
config.scoringModel().maxRetries());
32+
cohereConfig.baseUrl(),
33+
cohereConfig.apiKey(),
34+
cohereConfig.scoringModel().modelName(),
35+
cohereConfig.scoringModel().timeout(),
36+
cohereConfig.scoringModel().maxRetries());
2237
}
2338
};
2439
}
2540

41+
private List<ConfigValidationException.Problem> checkConfigurations(Langchain4jCohereConfig.CohereConfig cohereConfig,
42+
String configName) {
43+
List<ConfigValidationException.Problem> configProblems = new ArrayList<>();
44+
45+
String apiKey = cohereConfig.apiKey();
46+
if (DUMMY_API_KEY.equals(apiKey)) {
47+
configProblems.add(createApiKeyConfigProblem(configName));
48+
}
49+
50+
return configProblems;
51+
}
52+
53+
private ConfigValidationException.Problem createApiKeyConfigProblem(String configName) {
54+
return createConfigProblem("api-key", configName);
55+
}
56+
57+
private static ConfigValidationException.Problem createConfigProblem(String key, String configName) {
58+
return new ConfigValidationException.Problem(String.format(
59+
"SRCFG00014: The config property quarkus.langchain4j.cohere%s%s is required but it could not be found in any config source",
60+
NamedConfigUtil.isDefault(configName) ? "." : ("." + configName + "."), key));
61+
}
62+
63+
private Langchain4jCohereConfig.CohereConfig correspondingCohereConfig(
64+
Langchain4jCohereConfig runtimeConfig,
65+
String configName) {
66+
Langchain4jCohereConfig.CohereConfig cohereConfig;
67+
if (NamedConfigUtil.isDefault(configName)) {
68+
cohereConfig = runtimeConfig.defaultConfig();
69+
} else {
70+
cohereConfig = runtimeConfig.namedConfig().get(configName);
71+
}
72+
return cohereConfig;
73+
}
2674
}

model-providers/cohere/runtime/src/main/java/io/quarkiverse/langchain4j/cohere/runtime/config/CohereConfig.java

Lines changed: 0 additions & 28 deletions
This file was deleted.

model-providers/cohere/runtime/src/main/java/io/quarkiverse/langchain4j/cohere/runtime/config/CohereScoringModelConfig.java

Lines changed: 0 additions & 31 deletions
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package io.quarkiverse.langchain4j.cohere.runtime.config;
2+
3+
import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME;
4+
5+
import java.time.Duration;
6+
import java.util.Map;
7+
import java.util.Optional;
8+
9+
import io.quarkus.runtime.annotations.ConfigDocDefault;
10+
import io.quarkus.runtime.annotations.ConfigDocMapKey;
11+
import io.quarkus.runtime.annotations.ConfigDocSection;
12+
import io.quarkus.runtime.annotations.ConfigGroup;
13+
import io.quarkus.runtime.annotations.ConfigRoot;
14+
import io.smallrye.config.ConfigMapping;
15+
import io.smallrye.config.WithDefault;
16+
import io.smallrye.config.WithDefaults;
17+
import io.smallrye.config.WithParentName;
18+
19+
@ConfigRoot(phase = RUN_TIME)
20+
@ConfigMapping(prefix = "quarkus.langchain4j.cohere")
21+
public interface Langchain4jCohereConfig {
22+
23+
/**
24+
* Default model config.
25+
*/
26+
@WithParentName
27+
CohereConfig defaultConfig();
28+
29+
/**
30+
* Named model config.
31+
*/
32+
@ConfigDocSection
33+
@ConfigDocMapKey("model-name")
34+
@WithParentName
35+
@WithDefaults
36+
Map<String, CohereConfig> namedConfig();
37+
38+
@ConfigGroup
39+
interface CohereConfig {
40+
41+
/**
42+
* Base URL of the Cohere API.
43+
*/
44+
@WithDefault("https://api.cohere.ai/")
45+
String baseUrl();
46+
47+
/**
48+
* Cohere API key.
49+
*/
50+
@WithDefault("dummy")
51+
String apiKey();
52+
53+
/**
54+
* Timeout for Cohere calls.
55+
*/
56+
@ConfigDocDefault("30s")
57+
@WithDefault("${quarkus.langchain4j.timeout}")
58+
Optional<Duration> timeout();
59+
60+
/**
61+
* Scoring model config.
62+
*/
63+
ScoringModelConfig scoringModel();
64+
}
65+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package io.quarkiverse.langchain4j.cohere.runtime.config;
2+
3+
import java.time.Duration;
4+
import java.util.Optional;
5+
6+
import io.quarkus.runtime.annotations.ConfigDocDefault;
7+
import io.quarkus.runtime.annotations.ConfigGroup;
8+
import io.smallrye.config.WithDefault;
9+
10+
@ConfigGroup
11+
public interface ScoringModelConfig {
12+
13+
/**
14+
* Reranking model to use. The current list of supported models can be found in the
15+
* <a href="https://docs.cohere.com/docs/models">Cohere docs</a>
16+
*/
17+
@WithDefault("rerank-multilingual-v2.0")
18+
String modelName();
19+
20+
/**
21+
* Timeout for Cohere calls
22+
*/
23+
@WithDefault("30s")
24+
Duration timeout();
25+
26+
/**
27+
* Whether embedding model requests should be logged.
28+
*/
29+
@ConfigDocDefault("false")
30+
Optional<Boolean> logRequests();
31+
32+
/**
33+
* Whether embedding model responses should be logged.
34+
*/
35+
@ConfigDocDefault("false")
36+
Optional<Boolean> logResponses();
37+
38+
/**
39+
* The maximum number of times to retry. 1 means exactly one attempt, with retrying disabled.
40+
*
41+
* @deprecated Using the fault tolerance mechanisms built in Langchain4j is not recommended. If possible, use MicroProfile
42+
* Fault
43+
* Tolerance instead.
44+
*/
45+
@WithDefault("1")
46+
Integer maxRetries();
47+
}

0 commit comments

Comments
 (0)