Skip to content

Commit fb8bf24

Browse files
authored
Merge pull request #1006 from quarkiverse/#993
Add support for TLS configuration name
2 parents 4bbcc72 + ba6187d commit fb8bf24

File tree

24 files changed

+483
-35
lines changed

24 files changed

+483
-35
lines changed

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ public class OllamaChatLanguageModel implements ChatLanguageModel {
4343

4444
private OllamaChatLanguageModel(Builder builder) {
4545
client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses,
46-
builder.configName);
46+
builder.configName, builder.tlsConfigurationName);
4747
model = builder.model;
4848
format = builder.format;
4949
options = builder.options;
50-
this.listeners = builder.listeners;
50+
listeners = builder.listeners;
5151
}
5252

5353
public static Builder builder() {
@@ -194,6 +194,7 @@ private ChatModelResponse createModelListenerResponse(String responseId,
194194

195195
public static final class Builder {
196196
private String baseUrl = "http://localhost:11434";
197+
private String tlsConfigurationName;
197198
private Duration timeout = Duration.ofSeconds(10);
198199
private String model;
199200
private String format;
@@ -212,6 +213,11 @@ public Builder baseUrl(String val) {
212213
return this;
213214
}
214215

216+
public Builder tlsConfigurationName(String tlsConfigurationName) {
217+
this.tlsConfigurationName = tlsConfigurationName;
218+
return this;
219+
}
220+
215221
public Builder timeout(Duration val) {
216222
this.timeout = val;
217223
return this;

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaClient.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,23 @@
66
import java.util.Optional;
77
import java.util.concurrent.TimeUnit;
88

9+
import jakarta.enterprise.inject.Instance;
10+
import jakarta.enterprise.inject.spi.CDI;
11+
912
import org.jboss.resteasy.reactive.client.api.LoggingScope;
1013

1114
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
1215
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
16+
import io.quarkus.tls.TlsConfiguration;
17+
import io.quarkus.tls.TlsConfigurationRegistry;
1318
import io.smallrye.mutiny.Multi;
1419

1520
public class OllamaClient {
1621

1722
private final OllamaRestApi restApi;
1823

19-
public OllamaClient(String baseUrl, Duration timeout, boolean logRequests, boolean logResponses, String configName) {
24+
public OllamaClient(String baseUrl, Duration timeout, boolean logRequests, boolean logResponses, String configName,
25+
String tlsConfigurationName) {
2026
try {
2127
// TODO: cache?
2228
QuarkusRestClientBuilder builder = QuarkusRestClientBuilder.newBuilder()
@@ -32,6 +38,11 @@ public OllamaClient(String baseUrl, Duration timeout, boolean logRequests, boole
3238
if (maybeModelAuthProvider.isPresent()) {
3339
builder.register(new OllamaRestApi.OllamaRestAPIFilter(maybeModelAuthProvider.get()));
3440
}
41+
Instance<TlsConfigurationRegistry> tlsConfigurationRegistry = CDI.current().select(TlsConfigurationRegistry.class);
42+
if (tlsConfigurationRegistry.isResolvable()) {
43+
TlsConfiguration.from(tlsConfigurationRegistry.get(), Optional.ofNullable(tlsConfigurationName))
44+
.ifPresent(builder::tlsConfiguration);
45+
}
3546

3647
restApi = builder.build(OllamaRestApi.class);
3748
} catch (URISyntaxException e) {

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaEmbeddingModel.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class OllamaEmbeddingModel implements EmbeddingModel {
1616

1717
private OllamaEmbeddingModel(Builder builder) {
1818
client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses,
19-
builder.configName);
19+
builder.configName, builder.tlsConfigurationName);
2020
model = builder.model;
2121
}
2222

@@ -44,6 +44,7 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
4444

4545
public static final class Builder {
4646
private String baseUrl = "http://localhost:11434";
47+
private String tlsConfigurationName;
4748
private Duration timeout = Duration.ofSeconds(10);
4849
private String model;
4950

@@ -59,6 +60,11 @@ public Builder baseUrl(String val) {
5960
return this;
6061
}
6162

63+
public Builder tlsConfigurationName(String tlsConfigurationName) {
64+
this.tlsConfigurationName = tlsConfigurationName;
65+
return this;
66+
}
67+
6268
public Builder timeout(Duration val) {
6369
this.timeout = val;
6470
return this;

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class OllamaStreamingChatLanguageModel implements StreamingChatLanguageMo
2626

2727
private OllamaStreamingChatLanguageModel(OllamaStreamingChatLanguageModel.Builder builder) {
2828
client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses,
29-
builder.configName);
29+
builder.configName, builder.tlsConfigurationName);
3030
model = builder.model;
3131
format = builder.format;
3232
options = builder.options;
@@ -101,6 +101,7 @@ private Builder() {
101101
}
102102

103103
private String baseUrl = "http://localhost:11434";
104+
private String tlsConfigurationName;
104105
private Duration timeout = Duration.ofSeconds(10);
105106
private String model;
106107
private String format;
@@ -115,6 +116,11 @@ public Builder baseUrl(String val) {
115116
return this;
116117
}
117118

119+
public Builder tlsConfigurationName(String tlsConfigurationName) {
120+
this.tlsConfigurationName = tlsConfigurationName;
121+
return this;
122+
}
123+
118124
public Builder timeout(Duration val) {
119125
this.timeout = val;
120126
return this;

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ public Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel
6363
}
6464
var builder = OllamaChatLanguageModel.builder()
6565
.baseUrl(ollamaConfig.baseUrl().orElse(DEFAULT_BASE_URL))
66+
.tlsConfigurationName(ollamaConfig.tlsConfigurationName().orElse(null))
6667
.timeout(ollamaConfig.timeout().orElse(Duration.ofSeconds(10)))
6768
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), ollamaConfig.logRequests()))
6869
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), ollamaConfig.logResponses()))
@@ -109,6 +110,7 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jOllamaConfig runtimeCo
109110

110111
var builder = OllamaEmbeddingModel.builder()
111112
.baseUrl(ollamaConfig.baseUrl().orElse(DEFAULT_BASE_URL))
113+
.tlsConfigurationName(ollamaConfig.tlsConfigurationName().orElse(null))
112114
.timeout(ollamaConfig.timeout().orElse(Duration.ofSeconds(10)))
113115
.model(ollamaFixedConfig.embeddingModel().modelId())
114116
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), ollamaConfig.logRequests()))
@@ -156,6 +158,7 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jOllama
156158
}
157159
var builder = OllamaStreamingChatLanguageModel.builder()
158160
.baseUrl(ollamaConfig.baseUrl().orElse(DEFAULT_BASE_URL))
161+
.tlsConfigurationName(ollamaConfig.tlsConfigurationName().orElse(null))
159162
.timeout(ollamaConfig.timeout().orElse(Duration.ofSeconds(10)))
160163
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), ollamaConfig.logRequests()))
161164
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), ollamaConfig.logResponses()))

model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/LangChain4jOllamaConfig.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ interface OllamaConfig {
4040
*/
4141
Optional<String> baseUrl();
4242

43+
/**
44+
* If set, the named TLS configuration with the configured name will be applied to the REST Client
45+
*/
46+
Optional<String> tlsConfigurationName();
47+
4348
/**
4449
* Timeout for Ollama calls
4550
*/

model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import io.quarkiverse.langchain4j.azure.openai.runtime.config.LangChain4jAzureOpenAiConfig;
3434
import io.quarkiverse.langchain4j.azure.openai.runtime.config.LangChain4jAzureOpenAiConfig.AzureAiConfig.EndpointType;
3535
import io.quarkiverse.langchain4j.openai.common.QuarkusOpenAiClient;
36+
import io.quarkiverse.langchain4j.openai.common.runtime.AdditionalPropertiesHack;
3637
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
3738
import io.quarkus.arc.SyntheticCreationalContext;
3839
import io.quarkus.runtime.ShutdownContext;
@@ -317,6 +318,7 @@ private static ConfigValidationException.Problem createConfigProblem(String key,
317318
}
318319

319320
public void cleanUp(ShutdownContext shutdown) {
321+
AdditionalPropertiesHack.reset();
320322
shutdown.addShutdownTask(new Runnable() {
321323
@Override
322324
public void run() {

model-providers/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/common/QuarkusOpenAiClient.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77
import java.util.List;
88
import java.util.Map;
99
import java.util.Objects;
10+
import java.util.Optional;
1011
import java.util.concurrent.ConcurrentHashMap;
1112
import java.util.concurrent.TimeUnit;
1213
import java.util.concurrent.atomic.AtomicReference;
1314
import java.util.function.BiFunction;
1415
import java.util.function.Consumer;
1516
import java.util.function.Supplier;
1617

18+
import jakarta.enterprise.inject.Instance;
19+
import jakarta.enterprise.inject.spi.CDI;
1720
import jakarta.ws.rs.client.ClientRequestContext;
1821
import jakarta.ws.rs.client.ClientRequestFilter;
1922

@@ -42,7 +45,10 @@
4245
import dev.ai4j.openai4j.moderation.ModerationResult;
4346
import dev.ai4j.openai4j.spi.OpenAiClientBuilderFactory;
4447
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
48+
import io.quarkiverse.langchain4j.openai.common.runtime.AdditionalPropertiesHack;
4549
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
50+
import io.quarkus.tls.TlsConfiguration;
51+
import io.quarkus.tls.TlsConfigurationRegistry;
4652
import io.smallrye.mutiny.Multi;
4753
import io.smallrye.mutiny.Uni;
4854
import io.smallrye.mutiny.subscription.Cancellable;
@@ -119,6 +125,13 @@ public void filter(ClientRequestContext requestContext) {
119125
.ifPresent(modelAuthProvider -> restApiBuilder
120126
.register(new OpenAiRestApi.OpenAIRestAPIFilter(modelAuthProvider)));
121127

128+
Instance<TlsConfigurationRegistry> tlsConfigurationRegistry = CDI.current()
129+
.select(TlsConfigurationRegistry.class);
130+
if (tlsConfigurationRegistry.isResolvable()) {
131+
TlsConfiguration.from(tlsConfigurationRegistry.get(), Optional.ofNullable(builder.tlsConfigurationName))
132+
.ifPresent(restApiBuilder::tlsConfiguration);
133+
}
134+
122135
return restApiBuilder.build(OpenAiRestApi.class);
123136
} catch (URISyntaxException e) {
124137
throw new RuntimeException(e);
@@ -513,7 +526,9 @@ public static class QuarkusOpenAiClientBuilderFactory implements OpenAiClientBui
513526

514527
@Override
515528
public Builder get() {
516-
return new Builder();
529+
var result = new Builder();
530+
result.tlsConfigurationName(AdditionalPropertiesHack.getAndClearTlsConfigurationName());
531+
return result;
517532
}
518533
}
519534

@@ -522,6 +537,12 @@ public static class Builder extends OpenAiClient.Builder<QuarkusOpenAiClient, Bu
522537
private String userAgent;
523538
private String azureAdToken;
524539
private String configName;
540+
private String tlsConfigurationName;
541+
542+
public Builder tlsConfigurationName(String tlsConfigurationName) {
543+
this.tlsConfigurationName = tlsConfigurationName;
544+
return this;
545+
}
525546

526547
public Builder userAgent(String userAgent) {
527548
this.userAgent = userAgent;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package io.quarkiverse.langchain4j.openai.common.runtime;
2+
3+
import java.util.HashMap;
4+
import java.util.Map;
5+
6+
/**
7+
* This is done because we have no way of passing Quarkus specific properties from a model to a client.
8+
* This only works because:
9+
* <ul>
10+
* <li>The creation of beans does not happen in parallel</li>
11+
* <li>The creation of beans happens on the same thread</li>
12+
* <li>Setting up a model builder always precedes setting up a client builder</li>
13+
* </ul>
14+
*/
15+
public final class AdditionalPropertiesHack {
16+
17+
private AdditionalPropertiesHack() {
18+
}
19+
20+
static final ThreadLocal<Map<String, String>> PROPS = new ThreadLocal<>();
21+
static {
22+
reset();
23+
}
24+
25+
public static void reset() {
26+
PROPS.set(new HashMap<>());
27+
}
28+
29+
public static void setTlsConfigurationName(String tlsConfigurationName) {
30+
Map<String, String> map = PROPS.get();
31+
if (map == null) {
32+
// this should never happen
33+
return;
34+
}
35+
map.put("tlsConfigurationName", tlsConfigurationName);
36+
}
37+
38+
public static String getAndClearTlsConfigurationName() {
39+
Map<String, String> map = PROPS.get();
40+
if (map == null) {
41+
// this should never happen
42+
return null;
43+
}
44+
return map.remove("tlsConfigurationName");
45+
}
46+
}

model-providers/openai/openai-vanilla/deployment/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@
4646
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
4747
<scope>test</scope>
4848
</dependency>
49+
<dependency>
50+
<groupId>io.smallrye.certs</groupId>
51+
<artifactId>smallrye-certificate-generator-junit5</artifactId>
52+
<version>0.8.1</version>
53+
<scope>test</scope>
54+
</dependency>
4955
</dependencies>
5056
<build>
5157
<plugins>

0 commit comments

Comments
 (0)