Skip to content

Commit 3767ba1

Browse files
authored
Merge pull request #1600 from quarkiverse/#1595
Make in-process embedding models respect select model provider
2 parents caa174c + 922c443 commit 3767ba1

File tree

6 files changed

+139
-4
lines changed

6 files changed

+139
-4
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/InProcessEmbeddingProcessor.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ record LocalEmbeddingModel(String classname, String modelName, String onnxModelP
3636

3737
private static final Logger LOGGER = Logger.getLogger(InProcessEmbeddingProcessor.class);
3838

39-
private static List<LocalEmbeddingModel> MODELS = List.of(
39+
private final static List<LocalEmbeddingModel> MODELS = List.of(
4040
new LocalEmbeddingModel("dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel",
4141
"all-minilm-l6-v2-q", "all-minilm-l6-v2-q.onnx", "all-minilm-l6-v2-q-tokenizer.json"),
4242
new LocalEmbeddingModel("dev.langchain4j.model.embedding.onnx.allminilml6v2.AllMiniLmL6V2EmbeddingModel",
@@ -101,10 +101,13 @@ void exposeInProcessEmbeddingBeans(InProcessEmbeddingRecorder recorder,
101101
BuildProducer<SyntheticBeanBuildItem> beanProducer) {
102102

103103
for (InProcessEmbeddingBuildItem embedding : embeddings) {
104-
Optional<String> modelName = selectedEmbedding.stream()
104+
Optional<SelectedEmbeddingModelCandidateBuildItem> matchingSelected = selectedEmbedding.stream()
105105
.filter(se -> se.getProvider().equals(embedding.getProvider()))
106-
.map(SelectedEmbeddingModelCandidateBuildItem::getConfigName)
107106
.findFirst();
107+
if (matchingSelected.isEmpty() && !selectedEmbedding.isEmpty()) {
108+
continue;
109+
}
110+
108111
var builder = SyntheticBeanBuildItem
109112
.configure(DotName.createSimple(embedding.className()))
110113
.types(EmbeddingModel.class)
@@ -113,7 +116,7 @@ void exposeInProcessEmbeddingBeans(InProcessEmbeddingRecorder recorder,
113116
.unremovable()
114117
.scope(ApplicationScoped.class)
115118
.supplier(recorder.instantiate(embedding.className()));
116-
modelName.ifPresent(m -> addQualifierIfNecessary(builder, m));
119+
matchingSelected.ifPresent(m -> addQualifierIfNecessary(builder, m.getConfigName()));
117120
beanProducer.produce(builder.done());
118121
}
119122
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3+
<modelVersion>4.0.0</modelVersion>
4+
5+
<parent>
6+
<groupId>io.quarkiverse.langchain4j</groupId>
7+
<artifactId>quarkus-langchain4j-integration-tests-in-process-embedding-models</artifactId>
8+
<version>999-SNAPSHOT</version>
9+
</parent>
10+
11+
<artifactId>quarkus-langchain4j-integration-test-embed-bge-small-en-v15-and-ollama</artifactId>
12+
<name>Quarkus LangChain4j - Integration Tests - embeddings-bge-small-en-v15-and-ollama</name>
13+
14+
<dependencies>
15+
<dependency>
16+
<groupId>dev.langchain4j</groupId>
17+
<artifactId>langchain4j-embeddings-bge-small-en-v15</artifactId>
18+
<version>${langchain4j-embeddings.version}</version>
19+
</dependency>
20+
21+
<dependency>
22+
<groupId>io.quarkiverse.langchain4j</groupId>
23+
<artifactId>quarkus-langchain4j-ollama</artifactId>
24+
<version>${project.version}</version>
25+
</dependency>
26+
27+
<dependency>
28+
<groupId>io.quarkus</groupId>
29+
<artifactId>quarkus-rest-jackson</artifactId>
30+
</dependency>
31+
32+
<dependency>
33+
<groupId>io.quarkus</groupId>
34+
<artifactId>quarkus-junit5</artifactId>
35+
<scope>test</scope>
36+
</dependency>
37+
<dependency>
38+
<groupId>io.rest-assured</groupId>
39+
<artifactId>rest-assured</artifactId>
40+
<scope>test</scope>
41+
</dependency>
42+
<dependency>
43+
<groupId>org.assertj</groupId>
44+
<artifactId>assertj-core</artifactId>
45+
<version>${assertj.version}</version>
46+
<scope>test</scope>
47+
</dependency>
48+
<dependency>
49+
<groupId>io.quarkus</groupId>
50+
<artifactId>quarkus-devtools-testing</artifactId>
51+
<scope>test</scope>
52+
</dependency>
53+
</dependencies>
54+
<build>
55+
<plugins>
56+
<plugin>
57+
<groupId>io.quarkus</groupId>
58+
<artifactId>quarkus-maven-plugin</artifactId>
59+
<executions>
60+
<execution>
61+
<goals>
62+
<goal>build</goal>
63+
</goals>
64+
</execution>
65+
</executions>
66+
</plugin>
67+
<plugin>
68+
<artifactId>maven-failsafe-plugin</artifactId>
69+
<executions>
70+
<execution>
71+
<goals>
72+
<goal>integration-test</goal>
73+
<goal>verify</goal>
74+
</goals>
75+
<configuration>
76+
<systemPropertyVariables>
77+
<native.image.path>${project.build.directory}/${project.build.finalName}-runner</native.image.path>
78+
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
79+
<maven.home>${maven.home}</maven.home>
80+
</systemPropertyVariables>
81+
</configuration>
82+
</execution>
83+
</executions>
84+
</plugin>
85+
</plugins>
86+
</build>
87+
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package org.acme.test;
2+
3+
import jakarta.ws.rs.GET;
4+
import jakarta.ws.rs.Path;
5+
6+
import dev.langchain4j.model.embedding.EmbeddingModel;
7+
import io.quarkus.arc.ClientProxy;
8+
9+
@Path("embedding")
10+
public class Resource {
11+
12+
private final EmbeddingModel embeddingModel;
13+
14+
public Resource(EmbeddingModel embeddingModel) {
15+
this.embeddingModel = embeddingModel;
16+
}
17+
18+
@GET
19+
public String get() {
20+
return ClientProxy.unwrap(embeddingModel).getClass().getSimpleName();
21+
}
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
quarkus.langchain4j.embedding-model.provider=ollama
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.acme.test;
2+
3+
import static io.restassured.RestAssured.when;
4+
import static org.hamcrest.Matchers.containsString;
5+
6+
import org.junit.jupiter.api.Test;
7+
8+
import io.quarkus.test.junit.QuarkusTest;
9+
10+
@QuarkusTest
11+
class ResourceTest {
12+
13+
@Test
14+
void test() {
15+
when().get("/embedding")
16+
.then()
17+
.statusCode(200)
18+
.body(containsString("Ollama"));
19+
}
20+
21+
}

integration-tests/in-process-embedding-models/pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
<module>embed-bge-small-en-q</module>
1717
<module>embed-bge-small-en</module>
1818
<module>embed-bge-small-en-v15</module>
19+
<module>embed-bge-small-en-v15-and-ollama</module>
1920
<module>embed-bge-small-en-v15-q</module>
2021
<module>embed-e5-small-v2-q</module>
2122
<module>embed-e5-small-v2</module>

0 commit comments

Comments
 (0)