Skip to content

Commit 9b0f287

Browse files
authored
Merge pull request #1131 from quarkiverse/#1127
Fix ModelAuthProvider for streaming chat model in Azure
2 parents 9934f65 + 170ff3c commit 9b0f287

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

model-providers/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ void generateBeans(AzureOpenAiRecorder recorder,
110110
.scope(ApplicationScoped.class)
111111
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
112112
new Type[] { ClassType.create(DotNames.CHAT_MODEL_LISTENER) }, null))
113+
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
114+
new Type[] { ClassType.create(DotNames.MODEL_AUTH_PROVIDER) }, null))
113115
.createWith(streamingChatModel);
114116
addQualifierIfNecessary(streamingBuilder, configName);
115117
beanProducer.produce(streamingBuilder.done());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package io.quarkiverse.langchain4j.azure.openai.test;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
5+
import jakarta.enterprise.context.ApplicationScoped;
6+
import jakarta.inject.Inject;
7+
8+
import org.jboss.shrinkwrap.api.ShrinkWrap;
9+
import org.jboss.shrinkwrap.api.spec.JavaArchive;
10+
import org.junit.jupiter.api.Test;
11+
import org.junit.jupiter.api.extension.RegisterExtension;
12+
13+
import dev.langchain4j.model.chat.ChatLanguageModel;
14+
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
15+
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
16+
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel;
17+
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiStreamingChatModel;
18+
import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest;
19+
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
20+
import io.quarkus.arc.ClientProxy;
21+
import io.quarkus.test.QuarkusUnitTest;
22+
23+
public class ModelAuthProviderSmokeTest extends OpenAiBaseTest {
24+
25+
@RegisterExtension
26+
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
27+
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
28+
.overrideRuntimeConfigKey("quarkus.langchain4j.azure-openai.endpoint", WiremockAware.wiremockUrlForConfig("/v1"));
29+
30+
@Inject
31+
ChatLanguageModel chatLanguageModel;
32+
33+
@Inject
34+
StreamingChatLanguageModel streamingChatLanguageModel;
35+
36+
@Test
37+
void test() {
38+
assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(AzureOpenAiChatModel.class);
39+
assertThat(ClientProxy.unwrap(streamingChatLanguageModel)).isInstanceOf(AzureOpenAiStreamingChatModel.class);
40+
}
41+
42+
@ApplicationScoped
43+
public static class DummyModelAuthProvider implements ModelAuthProvider {
44+
45+
@Override
46+
public String getAuthorization(Input input) {
47+
return "dummy";
48+
}
49+
}
50+
}

0 commit comments

Comments
 (0)