Skip to content

Commit 616d02e

Browse files
committed
Update Azure OpenAI to check ModelAuthProvider during Embedding and Image model creation
1 parent b53410f commit 616d02e

File tree

3 files changed

+28
-18
lines changed

3 files changed

+28
-18
lines changed

model-providers/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,17 @@ void generateBeans(AzureOpenAiRecorder recorder,
121121
for (var selected : selectedEmbedding) {
122122
if (PROVIDER.equals(selected.getProvider())) {
123123
String configName = selected.getConfigName();
124+
125+
var embeddingModel = recorder.embeddingModel(config, configName);
124126
var builder = SyntheticBeanBuildItem
125127
.configure(EMBEDDING_MODEL)
126128
.setRuntimeInit()
127129
.unremovable()
128130
.defaultBean()
129131
.scope(ApplicationScoped.class)
130-
.supplier(recorder.embeddingModel(config, configName));
132+
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
133+
new Type[] { ClassType.create(DotNames.MODEL_AUTH_PROVIDER) }, null))
134+
.createWith(embeddingModel);
131135
addQualifierIfNecessary(builder, configName);
132136
beanProducer.produce(builder.done());
133137
}
@@ -136,12 +140,16 @@ void generateBeans(AzureOpenAiRecorder recorder,
136140
for (var selected : selectedImage) {
137141
if (PROVIDER.equals(selected.getProvider())) {
138142
String configName = selected.getConfigName();
143+
144+
var imageModel = recorder.imageModel(config, configName);
139145
var builder = SyntheticBeanBuildItem
140146
.configure(IMAGE_MODEL)
141147
.setRuntimeInit()
142148
.defaultBean()
143149
.scope(ApplicationScoped.class)
144-
.supplier(recorder.imageModel(config, configName));
150+
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
151+
new Type[] { ClassType.create(DotNames.MODEL_AUTH_PROVIDER) }, null))
152+
.createWith(imageModel);
145153
addQualifierIfNecessary(builder, configName);
146154
beanProducer.produce(builder.done());
147155
}

model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,14 @@ public StreamingChatLanguageModel apply(SyntheticCreationalContext<StreamingChat
153153
}
154154
}
155155

156-
public Supplier<EmbeddingModel> embeddingModel(LangChain4jAzureOpenAiConfig runtimeConfig, String configName) {
156+
public Function<SyntheticCreationalContext<EmbeddingModel>, EmbeddingModel> embeddingModel(
157+
LangChain4jAzureOpenAiConfig runtimeConfig, String configName) {
157158
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);
158159

159160
if (azureAiConfig.enableIntegration()) {
160161
EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel();
161162
String apiKey = azureAiConfig.apiKey().orElse(null);
162163
String adToken = azureAiConfig.adToken().orElse(null);
163-
if (apiKey == null && adToken == null) {
164-
throw new ConfigValidationException(createKeyMisconfigurationProblem(configName));
165-
}
166164
var builder = AzureOpenAiEmbeddingModel.builder()
167165
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING))
168166
.apiKey(apiKey)
@@ -174,29 +172,31 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jAzureOpenAiConfig runt
174172
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), azureAiConfig.logRequests()))
175173
.logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), azureAiConfig.logResponses()));
176174

177-
return new Supplier<>() {
175+
return new Function<>() {
178176
@Override
179-
public EmbeddingModel get() {
177+
public EmbeddingModel apply(SyntheticCreationalContext<EmbeddingModel> context) {
178+
throwIfApiKeysNotConfigured(apiKey, adToken, isAuthProviderAvailable(context, configName),
179+
configName);
180180
return builder.build();
181181
}
182182
};
183183
} else {
184-
return new Supplier<>() {
184+
return new Function<>() {
185185
@Override
186-
public EmbeddingModel get() {
186+
public EmbeddingModel apply(SyntheticCreationalContext<EmbeddingModel> context) {
187187
return new DisabledEmbeddingModel();
188188
}
189189
};
190190
}
191191
}
192192

193-
public Supplier<ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfig, String configName) {
193+
public Function<SyntheticCreationalContext<ImageModel>, ImageModel> imageModel(LangChain4jAzureOpenAiConfig runtimeConfig,
194+
String configName) {
194195
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);
195196

196197
if (azureAiConfig.enableIntegration()) {
197198
var apiKey = azureAiConfig.apiKey().orElse(null);
198199
String adToken = azureAiConfig.adToken().orElse(null);
199-
throwIfApiKeysNotConfigured(apiKey, adToken, false, configName);
200200

201201
var imageModelConfig = azureAiConfig.imageModel();
202202
var builder = AzureOpenAiImageModel.builder()
@@ -236,16 +236,18 @@ public Optional<? extends Path> get() {
236236

237237
builder.persistDirectory(persistDirectory);
238238

239-
return new Supplier<>() {
239+
return new Function<>() {
240240
@Override
241-
public ImageModel get() {
241+
public ImageModel apply(SyntheticCreationalContext<ImageModel> context) {
242+
throwIfApiKeysNotConfigured(apiKey, adToken, isAuthProviderAvailable(context, configName),
243+
configName);
242244
return builder.build();
243245
}
244246
};
245247
} else {
246-
return new Supplier<>() {
248+
return new Function<>() {
247249
@Override
248-
public ImageModel get() {
250+
public ImageModel apply(SyntheticCreationalContext<ImageModel> context) {
249251
return new DisabledImageModel();
250252
}
251253
};

model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/DisabledModelsAzureOpenAiRecorderTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ void disabledStreamingChatModel() {
4545

4646
@Test
4747
void disabledEmbeddingModel() {
48-
assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).get())
48+
assertThat(recorder.embeddingModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null))
4949
.isNotNull()
5050
.isExactlyInstanceOf(DisabledEmbeddingModel.class);
5151
}
5252

5353
@Test
5454
void disabledImageModel() {
55-
assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).get())
55+
assertThat(recorder.imageModel(config, NamedConfigUtil.DEFAULT_NAME).apply(null))
5656
.isNotNull()
5757
.isExactlyInstanceOf(DisabledImageModel.class);
5858
}

0 commit comments

Comments
 (0)