@@ -39,6 +39,7 @@
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.model.MessageAggregator;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.observation.ChatModelObservationContext;
 import org.springframework.ai.chat.observation.ChatModelObservationConvention;
@@ -72,7 +73,9 @@
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 
+import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
@@ -271,64 +274,90 @@ public ChatResponse call(Prompt prompt) {
 
     @Override
     public Flux<ChatResponse> stream(Prompt prompt) {
-
-        ChatCompletionRequest request = createRequest(prompt, true);
-
-        Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
-            .execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt)));
-
-        // For chunked responses, only the first chunk contains the choice role.
-        // The rest of the chunks with same ID share the same role.
-        ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
-
-        // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
-        // the function call handling logic.
-        Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
-            .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
-                try {
-                    @SuppressWarnings("null")
-                    String id = chatCompletion2.id();
-
-                    // @formatter:off
-                    List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
-                        if (choice.message().role() != null) {
-                            roleMap.putIfAbsent(id, choice.message().role().name());
-                        }
-                        Map<String, Object> metadata = Map.of(
-                                "id", chatCompletion2.id(),
-                                "role", roleMap.getOrDefault(id, ""),
-                                "index", choice.index(),
-                                "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
-                        return buildGeneration(choice, metadata);
-                    }).toList();
-                    // @formatter:on
-
-                    if (chatCompletion2.usage() != null) {
-                        return new ChatResponse(generations, from(chatCompletion2, null));
-                    }
-                    else {
-                        return new ChatResponse(generations);
-                    }
-                }
-                catch (Exception e) {
-                    logger.error("Error processing chat completion", e);
-                    return new ChatResponse(List.of());
-                }
-
-            }));
-
-        return chatResponse.flatMap(response -> {
-            if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
-                    OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
-                var toolCallConversation = handleToolCalls(prompt, response);
-                // Recursively call the stream method with the tool call message
-                // conversation that contains the call responses.
-                return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
-            }
-            else {
-                return Flux.just(response);
-            }
-        });
+        return Flux.deferContextual(contextView -> {
+            ChatCompletionRequest request = createRequest(prompt, true);
+
+            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
+                    getAdditionalHttpHeaders(prompt));
+
+            // For chunked responses, only the first chunk contains the choice role.
+            // The rest of the chunks with same ID share the same role.
+            ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
+
+            final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+                .prompt(prompt)
+                .operationMetadata(buildOperationMetadata())
+                .requestOptions(buildRequestOptions(request))
+                .build();
+
+            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+                    this.observationRegistry);
+
+            observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
+
+            // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
+            // the function call handling logic.
+            Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
+                .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
+                    try {
+                        @SuppressWarnings("null")
+                        String id = chatCompletion2.id();
+
+                        List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {// @formatter:off
+                            if (choice.message().role() != null) {
+                                roleMap.putIfAbsent(id, choice.message().role().name());
+                            }
+                            Map<String, Object> metadata = Map.of(
+                                    "id", chatCompletion2.id(),
+                                    "role", roleMap.getOrDefault(id, ""),
+                                    "index", choice.index(),
+                                    "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
+                            return buildGeneration(choice, metadata);
+                        }).toList();
+                        // @formatter:on
+
+                        if (chatCompletion2.usage() != null) {
+                            return new ChatResponse(generations, from(chatCompletion2, null));
+                        }
+                        else {
+                            return new ChatResponse(generations);
+                        }
+                    }
+                    catch (Exception e) {
+                        logger.error("Error processing chat completion", e);
+                        return new ChatResponse(List.of());
+                    }
+
+                }));
+
+            // @formatter:off
+            Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
+
+                if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
+                        OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
+                    var toolCallConversation = handleToolCalls(prompt, response);
+                    // Recursively call the stream method with the tool call message
+                    // conversation that contains the call responses.
+                    return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+                }
+                else {
+                    return Flux.just(response);
+                }
+            })
+            .doOnError(observation::error)
+            .doFinally(s -> {
+                // TODO: Consider a custom ObservationContext and
+                // include additional metadata
+                // if (s == SignalType.CANCEL) {
+                //     observationContext.setAborted(true);
+                // }
+                observation.stop();
+            })
+            .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
+            // @formatter:on
+
+            return new MessageAggregator().aggregate(flux, mergedChatResponse -> {
+                observationContext.setResponse(mergedChatResponse);
+            });
+        });
     }
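Note on the hunk above: stream() is now assembled inside Flux.deferContextual, so the observation is created and started once per subscription, can attach to a parent observation found in the Reactor context, and is stopped in doFinally even on cancellation; the MessageAggregator hands the merged response to the observation context before the stream completes. The same propagation pattern, reduced to a minimal standalone sketch (the operation name "demo.streaming" and the class are illustrative, not part of this change):

```java
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;

class ObservedStreamSketch {

    static Flux<String> observed(ObservationRegistry registry, Flux<String> source) {
        return Flux.deferContextual(contextView -> {
            // Created per subscription; linked to a parent observation if a caller
            // placed one in the Reactor context.
            Observation observation = Observation.createNotStarted("demo.streaming", registry)
                .parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
                .start();

            return source
                .doOnError(observation::error)           // record failures on the observation
                .doFinally(signal -> observation.stop()) // stop on complete, error, or cancel
                // Expose this observation to downstream operators.
                .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
        });
    }
}
```

Conceptually, the aggregation step observes the merged value of the stream without changing what subscribers receive; a hedged approximation of that idea (not the Spring AI MessageAggregator implementation) is:

```java
import java.util.function.Consumer;
import reactor.core.publisher.Flux;

class AggregatorSketch {

    static Flux<String> aggregate(Flux<String> stream, Consumer<String> onMerged) {
        StringBuilder merged = new StringBuilder(); // per-call state; assumes a single subscription
        return stream
            .doOnNext(merged::append)                                // accumulate chunks as they pass through
            .doOnComplete(() -> onMerged.accept(merged.toString())); // hand off the merged result once
    }
}
```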

@@ -104,7 +104,7 @@ public void streamUserMessageSimpleContentType() {
 
 		when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse);
 
-		chatModel.stream(new Prompt(List.of(new UserMessage("test message"))));
+		chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe();
 
 		validateStringContent(pomptCaptor.getValue());
 		assertThat(headersCaptor.getValue()).isEmpty();
@@ -137,8 +137,10 @@ public void streamUserMessageWithMediaType() throws MalformedURLException {
 		when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse);
 
 		URL mediaUrl = new URL("http://test");
-		chatModel.stream(new Prompt(
-				List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl))))));
+		chatModel
+			.stream(new Prompt(
+					List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl))))))
+			.subscribe();
 
 		validateComplexContent(pomptCaptor.getValue());
 	}
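These two mocked tests previously worked because stream() built the request eagerly. Now that everything is deferred behind Flux.deferContextual, no request is created (and the argument captors are never populated) until something subscribes, hence the added .subscribe() calls. A minimal illustration of this laziness, independent of Spring AI:

```java
import reactor.core.publisher.Flux;

class DeferredFluxSketch {

    public static void main(String[] args) {
        StringBuilder sideEffect = new StringBuilder();

        Flux<String> deferred = Flux.defer(() -> {
            sideEffect.append("request built"); // runs only once a subscriber arrives
            return Flux.just("chunk-1", "chunk-2");
        });

        System.out.println(sideEffect.length() == 0); // true: assembly did no work
        deferred.subscribe();
        System.out.println(sideEffect);               // "request built"
    }
}
```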
@@ -18,6 +18,9 @@
 import io.micrometer.common.KeyValue;
 import io.micrometer.observation.tck.TestObservationRegistry;
 import io.micrometer.observation.tck.TestObservationRegistryAssert;
+import reactor.core.publisher.Flux;
+
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@@ -37,6 +40,7 @@
 import org.springframework.retry.support.RetryTemplate;
 
 import java.util.List;
+import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames;
@@ -57,8 +61,14 @@ public class OpenAiChatModelObservationIT {
 	@Autowired
 	OpenAiChatModel chatModel;
 
+	@BeforeEach
+	void beforeEach() {
+		observationRegistry.clear();
+	}
+
 	@Test
-	void observationForEmbeddingOperation() {
+	void observationForChatOperation() {
+
 		var options = OpenAiChatOptions.builder()
 			.withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
 			.withFrequencyPenalty(0f)
@@ -77,6 +87,45 @@ void observationForEmbeddingOperation() {
 		ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
 		assertThat(responseMetadata).isNotNull();
 
+		validate(responseMetadata);
+	}
+
+	@Test
+	void observationForStreamingChatOperation() {
+		var options = OpenAiChatOptions.builder()
+			.withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
+			.withFrequencyPenalty(0f)
+			.withMaxTokens(2048)
+			.withPresencePenalty(0f)
+			.withStop(List.of("this-is-the-end"))
+			.withTemperature(0.7f)
+			.withTopP(1f)
+			.withStreamUsage(true)
+			.build();
+
+		Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
+
+		Flux<ChatResponse> chatResponseFlux = chatModel.stream(prompt);
+
+		List<ChatResponse> responses = chatResponseFlux.collectList().block();
+		assertThat(responses).isNotEmpty();
+		assertThat(responses).hasSizeGreaterThan(10);
+
+		String aggregatedResponse = responses.subList(0, responses.size() - 1)
+			.stream()
+			.map(r -> r.getResult().getOutput().getContent())
+			.collect(Collectors.joining());
+		assertThat(aggregatedResponse).isNotEmpty();
+
+		ChatResponse lastChatResponse = responses.get(responses.size() - 1);
+
+		ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata();
+		assertThat(responseMetadata).isNotNull();
+
+		validate(responseMetadata);
+	}
+
+	private void validate(ChatResponseMetadata responseMetadata) {
 		TestObservationRegistryAssert.assertThat(observationRegistry)
 			.doesNotHaveAnyRemainingCurrentObservation()
 			.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
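The refactored IT clears the TestObservationRegistry before each test and funnels both the blocking and the new streaming path through a shared validate(...) helper, asserting on the last streamed response, whose metadata carries the usage data requested via withStreamUsage(true). For reference, the underlying Micrometer TCK pattern looks roughly like this (the observation name here is illustrative; the real one comes from DefaultChatModelObservationConvention.DEFAULT_NAME):

```java
import io.micrometer.observation.Observation;
import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;

class ObservationTckSketch {

    void verifyObservation() {
        TestObservationRegistry registry = TestObservationRegistry.create();

        // Run some instrumented work; observe() starts and stops the observation.
        Observation.createNotStarted("gen_ai.client.operation", registry)
            .observe(() -> { /* code under observation */ });

        TestObservationRegistryAssert.assertThat(registry)
            .doesNotHaveAnyRemainingCurrentObservation()
            .hasObservationWithNameEqualTo("gen_ai.client.operation");
    }
}
```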
@@ -53,8 +53,7 @@
  * @author Christian Tzolov
  */
 @SpringBootTest
-@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
-@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
+@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
 public class OpenAiPaymentTransactionIT {
 
 	private final static Logger logger = LoggerFactory.getLogger(OpenAiPaymentTransactionIT.class);
@@ -19,6 +19,7 @@
 import java.util.Optional;
 
 import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
@@ -163,6 +164,7 @@ public void openAiChatNonTransientError() {
 	}
 
 	@Test
+	@Disabled("Currently stream() does not implement retry")
 	public void openAiChatStreamTransientError() {
 
 		var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
@@ -184,10 +186,11 @@ public void openAiChatStreamTransientError() {
 	}
 
 	@Test
+	@Disabled("Currently stream() does not implement retry")
 	public void openAiChatStreamNonTransientError() {
 		when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any()))
 			.thenThrow(new RuntimeException("Non Transient Error"));
-		assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
+		assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).subscribe());
 	}
 
 	@Test
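Both stream-retry tests are disabled because the rewrite dropped the retryTemplate.execute(...) wrapper from stream(): the blocking RetryTemplate cannot usefully wrap a lazily assembled Flux. If retry is later reinstated for streaming, it would presumably move into the reactive chain itself; a hedged sketch of that direction (an assumption about a possible follow-up, not part of this change):

```java
import java.time.Duration;

import reactor.core.publisher.Flux;
import reactor.util.retry.Retry;

class ReactiveRetrySketch {

    static <T> Flux<T> withRetry(Flux<T> upstream) {
        return upstream.retryWhen(
                Retry.backoff(3, Duration.ofSeconds(2))             // up to 3 retries, exponential backoff
                    .filter(ex -> ex instanceof RuntimeException)); // only retry errors deemed transient
    }
}
```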
@@ -27,7 +27,6 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.client.DefaultChatClient;
 import org.springframework.ai.openai.OpenAiTestConfiguration;
 import org.springframework.ai.openai.api.tool.MockWeatherService;
 import org.springframework.ai.openai.testutils.AbstractIT;