Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.autoconfigure.vectorstore.azure;

import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.ClientOptions;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;

Expand Down Expand Up @@ -47,11 +48,16 @@
@ConditionalOnProperty(prefix = "spring.ai.vectorstore.azure", value = { "url", "api-key", "index-name" })
public class AzureVectorStoreAutoConfiguration {

private final static String APPLICATION_ID = "spring-ai";

@Bean
@ConditionalOnMissingBean
public SearchIndexClient searchIndexClient(AzureVectorStoreProperties properties) {
ClientOptions clientOptions = new ClientOptions();
clientOptions.setApplicationId(APPLICATION_ID);
return new SearchIndexClientBuilder().endpoint(properties.getUrl())
.credential(new AzureKeyCredential(properties.getApiKey()))
.clientOptions(clientOptions)
.buildClient();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,8 +33,20 @@
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.ReflectionUtils;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.implementation.OpenAIClientImpl;
import com.azure.core.http.HttpHeader;
import com.azure.core.http.HttpHeaderName;
import com.azure.core.http.HttpMethod;
import com.azure.core.http.HttpPipeline;
import com.azure.core.http.HttpRequest;
import com.azure.core.http.HttpResponse;
import reactor.core.publisher.Flux;

import java.lang.reflect.Field;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand All @@ -44,11 +56,12 @@
/**
* @author Christian Tzolov
* @author Piotr Olaszewski
* @author Soby Chacko
* @since 0.8.0
*/
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
public class AzureOpenAiAutoConfigurationIT {
class AzureOpenAiAutoConfigurationIT {

private static String CHAT_MODEL_NAME = "gpt-4o";

Expand Down Expand Up @@ -79,7 +92,7 @@ public class AzureOpenAiAutoConfigurationIT {
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");

@Test
public void chatCompletion() {
void chatCompletion() {
contextRunner.run(context -> {
AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage)));
Expand All @@ -88,7 +101,26 @@ public void chatCompletion() {
}

@Test
public void chatCompletionStreaming() {
void httpRequestContainsUserAgentHeader() {
contextRunner.run(context -> {
OpenAIClient openAIClient = context.getBean(OpenAIClient.class);
Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient");
assertThat(serviceClientField).isNotNull();
ReflectionUtils.makeAccessible(serviceClientField);
OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient);
assertThat(oaci).isNotNull();
HttpPipeline httpPipeline = oaci.getHttpPipeline();
HttpResponse httpResponse = httpPipeline
.send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL()))
.block();
assertThat(httpResponse).isNotNull();
HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT);
assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue();
});
}

@Test
void chatCompletionStreaming() {
contextRunner.run(context -> {

AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);
Expand Down Expand Up @@ -140,7 +172,7 @@ void transcribe() {
}

@Test
public void chatActivation() {
void chatActivation() {

// Disable the chat auto-configuration.
contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=false").run(context -> {
Expand All @@ -159,7 +191,7 @@ public void chatActivation() {
}

@Test
public void embeddingActivation() {
void embeddingActivation() {

// Disable the embedding auto-configuration.
contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=false").run(context -> {
Expand All @@ -178,7 +210,7 @@ public void embeddingActivation() {
}

@Test
public void audioTranscriptionActivation() {
void audioTranscriptionActivation() {

// Disable the transcription auto-configuration.
contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=false").run(context -> {
Expand Down