From 336621b489f5a91ab8285415a38d32e60cc3dc6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?=
Date: Wed, 7 Aug 2024 13:15:06 +0200
Subject: [PATCH 1/5] WIP Observability for ChatModel

---
 .../ai/openai/OpenAiChatModel.java | 137 ++++++++++++------
 1 file changed, 92 insertions(+), 45 deletions(-)

diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index f798323d82d..1235363d288 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -25,6 +25,8 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
+import io.micrometer.observation.Observation;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.ai.chat.messages.AssistantMessage;
@@ -75,6 +77,7 @@
 import io.micrometer.observation.ObservationRegistry;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
+import reactor.core.publisher.SignalType;
 
 /**
  * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
@@ -271,25 +274,48 @@ public ChatResponse call(Prompt prompt) {
 
     @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
-
-        ChatCompletionRequest request = createRequest(prompt, true);
-
-        Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
-            .execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt)));
-
-        // For chunked responses, only the first chunk contains the choice role.
-        // The rest of the chunks with same ID share the same role.
-        ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
-
-        // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
-        // the function call handling logic.
-        Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
-            .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
-                try {
-                    @SuppressWarnings("null")
-                    String id = chatCompletion2.id();
-
-                    // @formatter:off
+        return Flux.deferContextual(contextView -> {
+            ChatCompletionRequest request = createRequest(prompt, true);
+
+            Flux<OpenAiApi.ChatCompletionChunk> completionChunks =
+                    this.retryTemplate.execute(ctx -> this.openAiApi.chatCompletionStream(
+                            request,
+                            getAdditionalHttpHeaders(prompt)));
+
+            // For chunked responses, only the first chunk contains the choice role.
+            // The rest of the chunks with same ID share the same role.
+            ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
+
+            final ChatModelObservationContext observationContext =
+                    ChatModelObservationContext.builder()
+                        .prompt(prompt)
+                        .operationMetadata(buildOperationMetadata())
+                        .requestOptions(buildRequestOptions(request))
+                        .build();
+
+            Observation observation =
+                    ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+                            this.observationConvention,
+                            DEFAULT_OBSERVATION_CONVENTION,
+                            () -> observationContext,
+                            this.observationRegistry);
+
+            observation
+                .parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
+                .start();
+
+            // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
+            // the function call handling logic.
+            Flux<ChatResponse> chatResponse =
+                    completionChunks.map(this::chunkToChatCompletion)
+                        .switchMap(chatCompletion -> Mono.just(chatCompletion)
+                            .map(chatCompletion2 -> {
+                                try {
+                                    @SuppressWarnings("null") String
+                                    id =
+                                            chatCompletion2.id();
+
+                                    // @formatter:off
 List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
     if (choice.message().role() != null) {
         roleMap.putIfAbsent(id, choice.message().role().name());
     }
@@ -303,32 +329,53 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 }).toList();
 // @formatter:on
-                if (chatCompletion2.usage() != null) {
-                    return new ChatResponse(generations, from(chatCompletion2, null));
-                }
-                else {
-                    return new ChatResponse(generations);
-                }
-            }
-            catch (Exception e) {
-                logger.error("Error processing chat completion", e);
-                return new ChatResponse(List.of());
-            }
-
-        }));
-
-        return chatResponse.flatMap(response -> {
-
-            if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
-                    OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
-                var toolCallConversation = handleToolCalls(prompt, response);
-                // Recursively call the stream method with the tool call message
-                // conversation that contains the call responses.
+                                if (chatCompletion2.usage() != null) {
+                                    return new ChatResponse(
+                                            generations,
+                                            from(chatCompletion2,
+                                                    null));
+                                }
+                                else {
+                                    return new ChatResponse(
+                                            generations);
+                                }
+                            }
+                            catch (Exception e) {
+                                logger.error(
+                                        "Error processing chat completion",
+                                        e);
+                                return new ChatResponse(
+                                        List.of());
+                            }
+
+                        }));
+
+            return chatResponse.flatMap(response -> {
+
+                if (isToolCall(response,
+                        Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
+                                OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
+                    var toolCallConversation = handleToolCalls(prompt, response);
+                    // Recursively call the stream method with the tool call message
+                    // conversation that contains the call responses.
+                    return this.stream(new Prompt(toolCallConversation,
+                            prompt.getOptions()));
+                }
+                else {
+                    return Flux.just(response);
+                }
+            })
+                .doOnError(observation::error)
+                .doFinally(s -> {
+                    // TODO: Consider a custom ObservationContext and
+                    // include additional metadata
+//                    if (s == SignalType.CANCEL) {
+//                        observationContext.setAborted(true);
+//                    }
+                    observation.stop();
+                })
+                .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY,
+                        observation));
         });
     }

From 62b486f5968342c1efc1f53b2f2cd8125d27c90c Mon Sep 17 00:00:00 2001
From: Christian Tzolov
Date: Wed, 7 Aug 2024 14:41:37 +0200
Subject: [PATCH 2/5] minor adjustments + ITs

---
 .../ai/openai/OpenAiChatModel.java            | 128 ++++++++----------
 .../chat/OpenAiChatModelObservationIT.java    |  59 +++++++-
 2 files changed, 113 insertions(+), 74 deletions(-)

diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index 1235363d288..40cf637c40e 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -277,45 +277,34 @@ public Flux<ChatResponse> stream(Prompt prompt) {
         return Flux.deferContextual(contextView -> {
             ChatCompletionRequest request = createRequest(prompt, true);
 
-            Flux<OpenAiApi.ChatCompletionChunk> completionChunks =
-                    this.retryTemplate.execute(ctx -> this.openAiApi.chatCompletionStream(
-                            request,
-                            getAdditionalHttpHeaders(prompt)));
+            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
+                .execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt)));
 
             // For chunked responses, only the first chunk contains the choice role.
             // The rest of the chunks with same ID share the same role.
             ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
 
-            final ChatModelObservationContext observationContext =
-                    ChatModelObservationContext.builder()
-                        .prompt(prompt)
-                        .operationMetadata(buildOperationMetadata())
-                        .requestOptions(buildRequestOptions(request))
-                        .build();
+            final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
+                .prompt(prompt)
+                .operationMetadata(buildOperationMetadata())
+                .requestOptions(buildRequestOptions(request))
+                .build();
 
-            Observation observation =
-                    ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
-                            this.observationConvention,
-                            DEFAULT_OBSERVATION_CONVENTION,
-                            () -> observationContext,
-                            this.observationRegistry);
+            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
+                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
+                    this.observationRegistry);
 
-            observation
-                .parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
-                .start();
+            observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
 
             // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
             // the function call handling logic.
-            Flux<ChatResponse> chatResponse =
-                    completionChunks.map(this::chunkToChatCompletion)
-                        .switchMap(chatCompletion -> Mono.just(chatCompletion)
-                            .map(chatCompletion2 -> {
-                                try {
-                                    @SuppressWarnings("null") String
-                                    id =
-                                            chatCompletion2.id();
-
-                                    // @formatter:off
+            Flux<ChatResponse> chatResponse = completionChunks.map(this::chunkToChatCompletion)
+                .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> {
+                    try {
+                        @SuppressWarnings("null")
+                        String id = chatCompletion2.id();
+
+                        // @formatter:off
 List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
     if (choice.message().role() != null) {
         roleMap.putIfAbsent(id, choice.message().role().name());
     }
@@ -329,53 +318,46 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 }).toList();
 // @formatter:on
-                                if (chatCompletion2.usage() != null) {
-                                    return new ChatResponse(
-                                            generations,
-                                            from(chatCompletion2,
-                                                    null));
-                                }
-                                else {
-                                    return new ChatResponse(
-                                            generations);
-                                }
-                            }
-                            catch (Exception e) {
-                                logger.error(
-                                        "Error processing chat completion",
-                                        e);
-                                return new ChatResponse(
-                                        List.of());
-                            }
-
-                        }));
+                        if (chatCompletion2.usage() != null) {
+                            return new ChatResponse(generations, from(chatCompletion2, null));
+                        }
+                        else {
+                            return new ChatResponse(generations);
+                        }
+                    }
+                    catch (Exception e) {
+                        logger.error("Error processing chat completion", e);
+                        return new ChatResponse(List.of());
+                    }
+                }));
+
+            // @formatter:off
             return chatResponse.flatMap(response -> {
 
-                if (isToolCall(response,
-                        Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
-                                OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
-                    var toolCallConversation = handleToolCalls(prompt, response);
-                    // Recursively call the stream method with the tool call message
-                    // conversation that contains the call responses.
+                if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
+                        OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
+                    var toolCallConversation = handleToolCalls(prompt, response);
+                    // Recursively call the stream method with the tool call message
+                    // conversation that contains the call responses.
+                    return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
+                }
+                else {
+                    return Flux.just(response);
+                }
+            })
+            .doOnNext(cr -> observationContext.setResponse(cr))
+            .doOnError(observation::error)
+            .doFinally(s -> {
+                // TODO: Consider a custom ObservationContext and
+                // include additional metadata
+                // if (s == SignalType.CANCEL) {
+                // observationContext.setAborted(true);
+                // }
+                observation.stop();
+            })
+            .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
+            // @formatter:on
         });
     }

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
index 8daf8e8b322..9ef19d2b5c7 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
@@ -18,6 +18,8 @@
 import io.micrometer.common.KeyValue;
 import io.micrometer.observation.tck.TestObservationRegistry;
 import io.micrometer.observation.tck.TestObservationRegistryAssert;
+import reactor.core.publisher.Flux;
+
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@@ -58,7 +60,7 @@ public class OpenAiChatModelObservationIT {
     OpenAiChatModel chatModel;
 
     @Test
-    void observationForEmbeddingOperation() {
+    void observationForChatOperation() {
         var options = OpenAiChatOptions.builder()
             .withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
             .withFrequencyPenalty(0f)
@@ -108,6 +110,61 @@ void observationForEmbeddingOperation() {
             .hasBeenStopped();
     }
 
+    @Test
+    void observationForStreamingChatOperation() {
+        var options = OpenAiChatOptions.builder()
+            .withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
+            .withFrequencyPenalty(0f)
+            .withMaxTokens(2048)
+            .withPresencePenalty(0f)
+            .withStop(List.of("this-is-the-end"))
+            .withTemperature(0.7f)
+            .withTopP(1f)
+            .build();
+
+        Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
+
+        Flux<ChatResponse> chatResponseFlux = chatModel.stream(prompt);
+        List<ChatResponse> response = chatResponseFlux.collectList().block();
+        assertThat(response).isNotEmpty();
+        // assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty();
+
+        // ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
+        // assertThat(responseMetadata).isNotNull();
+
+        TestObservationRegistryAssert.assertThat(observationRegistry)
+            .doesNotHaveAnyRemainingCurrentObservation()
+            .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
+            .that()
+            .hasContextualNameEqualTo("chat " + OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
+            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
+                    AiOperationType.CHAT.value())
+            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.OPENAI.value())
+            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
+                    OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
+            // .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
+            // responseMetadata.getModel())
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
+                    "[\"this-is-the-end\"]")
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_K.asString(), KeyValue.NONE_VALUE)
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0")
+            // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_ID.asString(),
+            // responseMetadata.getId())
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), "[\"STOP\"]")
+            // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
+            // String.valueOf(responseMetadata.getUsage().getPromptTokens()))
+            // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
+            // String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
+            // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
+            // String.valueOf(responseMetadata.getUsage().getTotalTokens()))
+            .hasBeenStarted()
+            .hasBeenStopped();
+    }
+
     @SpringBootConfiguration
     static class Config {

From 5f3aef56bfbd254ed0fece5740f57501e219ac0b Mon Sep 17 00:00:00 2001
From: Christian Tzolov
Date: Wed, 7 Aug 2024 15:31:35 +0200
Subject: [PATCH 3/5] work in progress

---
 .../ai/openai/OpenAiChatModel.java            |  7 +----
 .../chat/OpenAiChatModelObservationIT.java    | 28 +++++++++++++------
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index 40cf637c40e..6c0bcab2b00 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -318,12 +318,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 }).toList();
 // @formatter:on
-                        if (chatCompletion2.usage() != null) {
-                            return new ChatResponse(generations, from(chatCompletion2, null));
-                        }
-                        else {
-                            return new ChatResponse(generations);
-                        }
+                        return new ChatResponse(generations, from(chatCompletion2, null));
                     }
                     catch (Exception e) {
                         logger.error("Error processing chat completion", e);

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
index 9ef19d2b5c7..2593f1e49d1 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
@@ -39,6 +39,7 @@
 import org.springframework.retry.support.RetryTemplate;
 
 import java.util.List;
+import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames;
@@ -120,17 +121,28 @@ void observationForStreamingChatOperation() {
             .withStop(List.of("this-is-the-end"))
             .withTemperature(0.7f)
             .withTopP(1f)
+            .withStreamUsage(true)
             .build();
 
         Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
 
         Flux<ChatResponse> chatResponseFlux = chatModel.stream(prompt);
-        List<ChatResponse> response = chatResponseFlux.collectList().block();
-        assertThat(response).isNotEmpty();
-        // assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty();
+        List<ChatResponse> responses = chatResponseFlux.collectList().block();
+        assertThat(responses).isNotEmpty();
 
-        // ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
-        // assertThat(responseMetadata).isNotNull();
+
+        String aggregatedResponse = responses.stream().map(r -> r.getResult().getOutput().getContent()).collect(Collectors.joining());
+        for (int i = 0; i < responses.size() - 1; i++ ) {
+            ChatResponse chatResponse = responses.get(i);
+            System.out.println("I = " + i + " -> " + chatResponse.getResult().getOutput());
+            // String content = chatResponse.getResult().getOutput().getContent();
+            // assertThat(content).isNotEmpty().as("Response " + i + " has content: " + content);
+        }
+
+        ChatResponse lastChatResponse = responses.get(responses.size() - 1);
+
+        ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata();
+        assertThat(responseMetadata).isNotNull();
 
         TestObservationRegistryAssert.assertThat(observationRegistry)
             .doesNotHaveAnyRemainingCurrentObservation()
@@ -142,8 +154,7 @@ void observationForStreamingChatOperation() {
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.OPENAI.value())
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
                     OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
-            // .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
-            // responseMetadata.getModel())
+            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel())
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
@@ -152,8 +163,7 @@ void observationForStreamingChatOperation() {
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_K.asString(), KeyValue.NONE_VALUE)
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0")
-            // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_ID.asString(),
-            // responseMetadata.getId())
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_ID.asString(), responseMetadata.getId())
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), "[\"STOP\"]")
             // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
             // String.valueOf(responseMetadata.getUsage().getPromptTokens()))

From 9ef51d093036012f30ea6e34c12ef156741bf1d4 Mon Sep 17 00:00:00 2001
From: Christian Tzolov
Date: Thu, 8 Aug 2024 09:45:02 +0200
Subject: [PATCH 4/5] feat: Enhance OpenAiChatModel stream with observation and
 error handling

- Integrated Micrometer's Observation for tracing in the stream() method.
- Enhanced MessageAggregator to aggregate streaming responses and handle metadata.
- Improved error handling and logging for chat response processing.
- Updated unit tests to include the new observation logic and subscribe to Flux responses.
- Refined OpenAiChatModelObservationIT to validate observations in both normal and streaming chat operations.
- Updated test configurations and removed unnecessary imports.
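In effect, the streaming path now reports a single aggregated ChatResponse to the
observation context instead of per-chunk updates. A minimal sketch of the wiring this
patch converges on (paraphrased from the diff below; only the observation-related
operators are shown, and tool-call handling is elided):

    // Sketch of the final operator chain (names as in the diff below).
    Flux<ChatResponse> flux = chatResponse
        // record stream errors on the observation
        .doOnError(observation::error)
        // stop the observation on complete, error, or cancel
        .doFinally(signalType -> observation.stop())
        // expose the observation to upstream operators via the Reactor context
        .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
    // MessageAggregator merges the streamed chunks into one ChatResponse so the
    // observation context sees complete metadata (id, model, usage, finish reasons).
    return new MessageAggregator().aggregate(flux, observationContext::setResponse);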
---
 .../ai/openai/OpenAiChatModel.java            |  33 +++---
 .../openai/chat/MessageTypeContentTests.java  |   8 +-
 .../chat/OpenAiChatModelObservationIT.java    |  68 ++++-------
 .../chat/OpenAiPaymentTransactionIT.java      |   3 +-
 .../ai/openai/chat/OpenAiRetryTests.java      |   2 +-
 ...enAiChatClientMultipleFunctionCallsIT.java |   1 -
 .../ai/chat/model/MessageAggregator.java      | 108 +++++++++++++++++-
 7 files changed, 154 insertions(+), 69 deletions(-)

diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index 6c0bcab2b00..f85f99c1c0a 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -41,6 +41,7 @@
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
+import org.springframework.ai.chat.model.MessageAggregator;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.observation.ChatModelObservationContext;
 import org.springframework.ai.chat.observation.ChatModelObservationConvention;
@@ -304,19 +305,20 @@ public Flux<ChatResponse> stream(Prompt prompt) {
                         @SuppressWarnings("null")
                         String id = chatCompletion2.id();
 
-                        // @formatter:off
-                        List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
-                            if (choice.message().role() != null) {
-                                roleMap.putIfAbsent(id, choice.message().role().name());
-                            }
-                            Map<String, Object> metadata = Map.of(
-                                    "id", chatCompletion2.id(),
-                                    "role", roleMap.getOrDefault(id, ""),
-                                    "index", choice.index(),
-                                    "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
-                            return buildGeneration(choice, metadata);
+                        List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {// @formatter:off
+
+                            if (choice.message().role() != null) {
+                                roleMap.putIfAbsent(id, choice.message().role().name());
+                            }
+                            Map<String, Object> metadata = Map.of(
+                                    "id", chatCompletion2.id(),
+                                    "role", roleMap.getOrDefault(id, ""),
+                                    "index", choice.index(),
+                                    "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "");
+
+                            return buildGeneration(choice, metadata);
                         }).toList();
-                        // @formatter:on
+                    // @formatter:on
 
                         return new ChatResponse(generations, from(chatCompletion2, null));
                     }
@@ -328,7 +330,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
                 }));
 
             // @formatter:off
-            return chatResponse.flatMap(response -> {
+            Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
 
                 if (isToolCall(response, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
                         OpenAiApi.ChatCompletionFinishReason.STOP.name()))) {
                     var toolCallConversation = handleToolCalls(prompt, response);
                     // Recursively call the stream method with the tool call message
@@ -341,7 +343,6 @@ public Flux<ChatResponse> stream(Prompt prompt) {
                     return Flux.just(response);
                 }
             })
-            .doOnNext(cr -> observationContext.setResponse(cr))
             .doOnError(observation::error)
             .doFinally(s -> {
                 // TODO: Consider a custom ObservationContext and
@@ -353,6 +354,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
             })
             .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
             // @formatter:on
+            return new MessageAggregator().aggregate(flux, cr -> {
+                observationContext.setResponse(cr);
+            });
+
         });
     }

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java
index 778870c3ad4..8a393c8aafd 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java
@@ -104,7 +104,7 @@ public void streamUserMessageSimpleContentType() {
 
         when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse);
 
-        chatModel.stream(new Prompt(List.of(new UserMessage("test message"))));
+        chatModel.stream(new Prompt(List.of(new UserMessage("test message")))).subscribe();
 
         validateStringContent(pomptCaptor.getValue());
         assertThat(headersCaptor.getValue()).isEmpty();
@@ -137,8 +137,10 @@ public void streamUserMessageWithMediaType() throws MalformedURLException {
         when(openAiApi.chatCompletionStream(pomptCaptor.capture(), headersCaptor.capture())).thenReturn(fluxResponse);
 
         URL mediaUrl = new URL("http://test");
-        chatModel.stream(new Prompt(
-                List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl))))));
+        chatModel
+            .stream(new Prompt(
+                    List.of(new UserMessage("test message", List.of(new Media(MimeTypeUtils.IMAGE_JPEG, mediaUrl))))))
+            .subscribe();
 
         validateComplexContent(pomptCaptor.getValue());
     }

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
index 2593f1e49d1..d0e582adbeb 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java
@@ -20,6 +20,7 @@
 import io.micrometer.observation.tck.TestObservationRegistryAssert;
 import reactor.core.publisher.Flux;
 
+import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
@@ -60,8 +61,14 @@ public class OpenAiChatModelObservationIT {
     @Autowired
     OpenAiChatModel chatModel;
 
+    @BeforeEach
+    void beforeEach() {
+        observationRegistry.clear();
+    }
+
     @Test
     void observationForChatOperation() {
+
         var options = OpenAiChatOptions.builder()
             .withModel(OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
             .withFrequencyPenalty(0f)
@@ -80,35 +87,7 @@ void observationForChatOperation() {
         ChatResponseMetadata responseMetadata = chatResponse.getMetadata();
         assertThat(responseMetadata).isNotNull();
 
-        TestObservationRegistryAssert.assertThat(observationRegistry)
-            .doesNotHaveAnyRemainingCurrentObservation()
-            .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
-            .that()
-            .hasContextualNameEqualTo("chat " + OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
-            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
-                    AiOperationType.CHAT.value())
-            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.OPENAI.value())
-            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
-                    OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
-            .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel())
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
-                    "[\"this-is-the-end\"]")
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_K.asString(), KeyValue.NONE_VALUE)
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0")
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_ID.asString(), responseMetadata.getId())
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), "[\"STOP\"]")
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
-                    String.valueOf(responseMetadata.getUsage().getPromptTokens()))
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
-                    String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
-            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
-                    String.valueOf(responseMetadata.getUsage().getTotalTokens()))
-            .hasBeenStarted()
-            .hasBeenStopped();
+        validate(responseMetadata);
     }
 
     @Test
@@ -127,23 +106,26 @@ void observationForStreamingChatOperation() {
         Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
 
         Flux<ChatResponse> chatResponseFlux = chatModel.stream(prompt);
+
         List<ChatResponse> responses = chatResponseFlux.collectList().block();
         assertThat(responses).isNotEmpty();
+        assertThat(responses).hasSizeGreaterThan(10);
 
-
-        String aggregatedResponse = responses.stream().map(r -> r.getResult().getOutput().getContent()).collect(Collectors.joining());
-        for (int i = 0; i < responses.size() - 1; i++ ) {
-            ChatResponse chatResponse = responses.get(i);
-            System.out.println("I = " + i + " -> " + chatResponse.getResult().getOutput());
-            // String content = chatResponse.getResult().getOutput().getContent();
-            // assertThat(content).isNotEmpty().as("Response " + i + " has content: " + content);
-        }
+        String aggregatedResponse = responses.subList(0, responses.size() - 1)
+            .stream()
+            .map(r -> r.getResult().getOutput().getContent())
+            .collect(Collectors.joining());
+        assertThat(aggregatedResponse).isNotEmpty();
 
         ChatResponse lastChatResponse = responses.get(responses.size() - 1);
 
         ChatResponseMetadata responseMetadata = lastChatResponse.getMetadata();
         assertThat(responseMetadata).isNotNull();
 
+        validate(responseMetadata);
+    }
+
+    private void validate(ChatResponseMetadata responseMetadata) {
         TestObservationRegistryAssert.assertThat(observationRegistry)
             .doesNotHaveAnyRemainingCurrentObservation()
             .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
             .that()
             .hasContextualNameEqualTo("chat " + OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
                     AiOperationType.CHAT.value())
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.OPENAI.value())
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
                     OpenAiApi.ChatModel.GPT_4_O_MINI.getValue())
             .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), responseMetadata.getModel())
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
                     "[\"this-is-the-end\"]")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_K.asString(), KeyValue.NONE_VALUE)
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TOP_P.asString(), "1.0")
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_ID.asString(), responseMetadata.getId())
             .hasHighCardinalityKeyValue(HighCardinalityKeyNames.RESPONSE_FINISH_REASONS.asString(), "[\"STOP\"]")
-            // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
-            // String.valueOf(responseMetadata.getUsage().getPromptTokens()))
-            // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
-            // String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
-            // .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
-            // String.valueOf(responseMetadata.getUsage().getTotalTokens()))
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
+                    String.valueOf(responseMetadata.getUsage().getPromptTokens()))
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_OUTPUT_TOKENS.asString(),
+                    String.valueOf(responseMetadata.getUsage().getGenerationTokens()))
+            .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
+                    String.valueOf(responseMetadata.getUsage().getTotalTokens()))
             .hasBeenStarted()
             .hasBeenStopped();
     }

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java
index 66e91c0ede4..46326fc8b12 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java
@@ -53,8 +53,7 @@
  * @author Christian Tzolov
  */
 @SpringBootTest
-@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
-@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
+@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
 public class OpenAiPaymentTransactionIT {
 
     private final static Logger logger = LoggerFactory.getLogger(OpenAiPaymentTransactionIT.class);

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java
index c7d510af5e8..c48ade0c682 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java
@@ -187,7 +187,7 @@ public void openAiChatStreamTransientError() {
     public void openAiChatStreamNonTransientError() {
         when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any()))
             .thenThrow(new RuntimeException("Non Transient Error"));
-        assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
+        assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")).subscribe());
     }
 
     @Test

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java
index 63e55cc2acd..812ba3fc08c 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java
@@ -27,7 +27,6 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.client.DefaultChatClient;
 import org.springframework.ai.openai.OpenAiTestConfiguration;
 import org.springframework.ai.openai.api.tool.MockWeatherService;
 import org.springframework.ai.openai.testutils.AbstractIT;

diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java
index b4f688f8773..ad02dddf291 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java
@@ -24,6 +24,15 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
+import org.springframework.ai.chat.metadata.ChatResponseMetadata;
+import org.springframework.ai.chat.metadata.EmptyRateLimit;
+import org.springframework.ai.chat.metadata.PromptMetadata;
+import org.springframework.ai.chat.metadata.RateLimit;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.util.StringUtils;
+
 import reactor.core.publisher.Flux;
 
 /**
@@ -43,13 +52,37 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
 
         AtomicReference<StringBuilder> stringBufferRef = new AtomicReference<>(new StringBuilder());
         AtomicReference<Map<String, Object>> mapRef = new AtomicReference<>();
 
+        AtomicReference<ChatGenerationMetadata> generationMetadataRef = new AtomicReference<>(
+                ChatGenerationMetadata.NULL);
+
+        AtomicReference<Long> metadataUsagePromptTokensRef = new AtomicReference<>(0L);
+        AtomicReference<Long> metadataUsageGenerationTokensRef = new AtomicReference<>(0L);
+        AtomicReference<Long> metadataUsageTotalTokensRef = new AtomicReference<>(0L);
+
+        AtomicReference<PromptMetadata> metadataPromptMetadataRef = new AtomicReference<>(PromptMetadata.empty());
+        AtomicReference<RateLimit> metadataRateLimitRef = new AtomicReference<>(new EmptyRateLimit());
+
+        AtomicReference<String> metadataIdRef = new AtomicReference<>("");
+        AtomicReference<String> metadataModelRef = new AtomicReference<>("");
+
         return fluxChatResponse.doOnSubscribe(subscription -> {
-            // logger.info("Aggregation Subscribe:" + subscription);
             stringBufferRef.set(new StringBuilder());
             mapRef.set(new HashMap<>());
+            metadataIdRef.set("");
+            metadataModelRef.set("");
+            metadataUsagePromptTokensRef.set(0L);
+            metadataUsageGenerationTokensRef.set(0L);
+            metadataUsageTotalTokensRef.set(0L);
+            metadataPromptMetadataRef.set(PromptMetadata.empty());
+            metadataRateLimitRef.set(new EmptyRateLimit());
+
         }).doOnNext(chatResponse -> {
-            // logger.info("Aggregation Next:" + chatResponse);
+
             if (chatResponse.getResult() != null) {
+                if (chatResponse.getResult().getMetadata() != null
+                        && chatResponse.getResult().getMetadata() != ChatGenerationMetadata.NULL) {
+                    generationMetadataRef.set(chatResponse.getResult().getMetadata());
+                }
                 if (chatResponse.getResult().getOutput().getContent() != null) {
                     stringBufferRef.get().append(chatResponse.getResult().getOutput().getContent());
                 }
@@ -57,15 +90,80 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
                     mapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata());
                 }
             }
+            if (chatResponse.getMetadata() != null) {
+                if (chatResponse.getMetadata().getUsage() != null) {
+                    Usage usage = chatResponse.getMetadata().getUsage();
+                    metadataUsagePromptTokensRef.set(
+                            usage.getPromptTokens() > 0 ? usage.getPromptTokens() : metadataUsagePromptTokensRef.get());
+                    metadataUsageGenerationTokensRef.set(usage.getGenerationTokens() > 0 ? usage.getGenerationTokens()
+                            : metadataUsageGenerationTokensRef.get());
+                    metadataUsageTotalTokensRef
+                        .set(usage.getTotalTokens() > 0 ? usage.getTotalTokens() : metadataUsageTotalTokensRef.get());
+                }
+                if (chatResponse.getMetadata().getPromptMetadata() != null
+                        && chatResponse.getMetadata().getPromptMetadata().iterator().hasNext()) {
+                    metadataPromptMetadataRef.set(chatResponse.getMetadata().getPromptMetadata());
+                }
+                if (chatResponse.getMetadata().getRateLimit() != null
+                        && !(metadataRateLimitRef.get() instanceof EmptyRateLimit)) {
+                    metadataRateLimitRef.set(chatResponse.getMetadata().getRateLimit());
+                }
+                if (StringUtils.hasText(chatResponse.getMetadata().getId())) {
+                    metadataIdRef.set(chatResponse.getMetadata().getId());
+                }
+                if (StringUtils.hasText(chatResponse.getMetadata().getModel())) {
+                    metadataModelRef.set(chatResponse.getMetadata().getModel());
+                }
+            }
         }).doOnComplete(() -> {
-            // logger.debug("Aggregation Complete");
-            onAggregationComplete
-                .accept(new ChatResponse(List.of(new Generation(stringBufferRef.get().toString(), mapRef.get()))));
+
+            var usage = new DefaultUsage(metadataUsagePromptTokensRef.get(), metadataUsageGenerationTokensRef.get(),
+                    metadataUsageTotalTokensRef.get());
+
+            var chatResponseMetadata = ChatResponseMetadata.builder()
+                .withId(metadataIdRef.get())
+                .withModel(metadataModelRef.get())
+                .withRateLimit(metadataRateLimitRef.get())
+                .withUsage(usage)
+                .withPromptMetadata(metadataPromptMetadataRef.get())
+                .build();
+
+            onAggregationComplete.accept(new ChatResponse(
+                    List.of(new Generation(new AssistantMessage(stringBufferRef.get().toString(), mapRef.get()),
+                            generationMetadataRef.get())),
+                    chatResponseMetadata));
+
             stringBufferRef.set(new StringBuilder());
             mapRef.set(new HashMap<>());
+            metadataIdRef.set("");
+            metadataModelRef.set("");
+            metadataUsagePromptTokensRef.set(0L);
+            metadataUsageGenerationTokensRef.set(0L);
+            metadataUsageTotalTokensRef.set(0L);
+            metadataPromptMetadataRef.set(PromptMetadata.empty());
+            metadataRateLimitRef.set(new EmptyRateLimit());
+
         }).doOnError(e -> {
             logger.error("Aggregation Error", e);
         });
     }
 
+    public record DefaultUsage(long promptTokens, long generationTokens, long totalTokens) implements Usage {
+
+        @Override
+        public Long getPromptTokens() {
+            return promptTokens();
+        }
+
+        @Override
+        public Long getGenerationTokens() {
+            return generationTokens();
+        }
+
+        @Override
+        public Long getTotalTokens() {
+            return totalTokens();
+        }
+    }
+
 }
\ No newline at end of file

From 7cfef1f89bedcd60dfc4571694660986bccfdec4 Mon Sep 17 00:00:00 2001
From: Christian Tzolov
Date: Thu, 8 Aug 2024 11:33:44 +0200
Subject: [PATCH 5/5] Disable retry for streaming

---
 .../ai/openai/OpenAiChatModel.java            | 14 +++++-----
 .../ai/openai/chat/OpenAiRetryTests.java      |  3 +++
 .../ai/chat/model/MessageAggregator.java      | 26 ++++++++---------
 3 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index f85f99c1c0a..b4d479a91e7 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -25,8 +25,6 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
-import io.micrometer.observation.Observation;
-import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.ai.chat.messages.AssistantMessage;
@@ -75,10 +73,11 @@
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 
+import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
-import reactor.core.publisher.SignalType;
 
 /**
  * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
@@ -278,8 +277,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
         return Flux.deferContextual(contextView -> {
             ChatCompletionRequest request = createRequest(prompt, true);
 
-            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.retryTemplate
-                .execute(ctx -> this.openAiApi.chatCompletionStream(request, getAdditionalHttpHeaders(prompt)));
+            Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
+                    getAdditionalHttpHeaders(prompt));
 
             // For chunked responses, only the first chunk contains the choice role.
             // The rest of the chunks with same ID share the same role.
@@ -354,8 +353,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
             })
             .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
             // @formatter:on
-            return new MessageAggregator().aggregate(flux, cr -> {
-                observationContext.setResponse(cr);
+
+            return new MessageAggregator().aggregate(flux, mergedChatResponse -> {
+                observationContext.setResponse(mergedChatResponse);
             });
 
         });

diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java
index c48ade0c682..0f486d4c5bf 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java
@@ -19,6 +19,7 @@
 import java.util.Optional;
 
 import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
@@ -163,6 +164,7 @@ public void openAiChatNonTransientError() {
     }
 
     @Test
+    @Disabled("Currently stream() does not implement retry")
     public void openAiChatStreamTransientError() {
 
         var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
@@ -184,6 +186,7 @@ public void openAiChatStreamTransientError() {
     }
 
     @Test
+    @Disabled("Currently stream() does not implement retry")
     public void openAiChatStreamNonTransientError() {
         when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class), any()))
             .thenThrow(new RuntimeException("Non Transient Error"));

diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java
index ad02dddf291..6aef10ed777 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java
@@ -49,12 +49,15 @@ public class MessageAggregator {
     public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
             Consumer<ChatResponse> onAggregationComplete) {
 
-        AtomicReference<StringBuilder> stringBufferRef = new AtomicReference<>(new StringBuilder());
-        AtomicReference<Map<String, Object>> mapRef = new AtomicReference<>();
+        // Assistant Message
+        AtomicReference<StringBuilder> messageTextContentRef = new AtomicReference<>(new StringBuilder());
+        AtomicReference<Map<String, Object>> messageMetadataMapRef = new AtomicReference<>();
 
+        // ChatGeneration Metadata
         AtomicReference<ChatGenerationMetadata> generationMetadataRef = new AtomicReference<>(
                 ChatGenerationMetadata.NULL);
 
+        // Usage
         AtomicReference<Long> metadataUsagePromptTokensRef = new AtomicReference<>(0L);
         AtomicReference<Long> metadataUsageGenerationTokensRef = new AtomicReference<>(0L);
         AtomicReference<Long> metadataUsageTotalTokensRef = new AtomicReference<>(0L);
@@ -66,8 +69,8 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
         AtomicReference<String> metadataModelRef = new AtomicReference<>("");
 
         return fluxChatResponse.doOnSubscribe(subscription -> {
-            stringBufferRef.set(new StringBuilder());
-            mapRef.set(new HashMap<>());
+            messageTextContentRef.set(new StringBuilder());
+            messageMetadataMapRef.set(new HashMap<>());
             metadataIdRef.set("");
             metadataModelRef.set("");
             metadataUsagePromptTokensRef.set(0L);
@@ -84,10 +87,10 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
                     generationMetadataRef.set(chatResponse.getResult().getMetadata());
                 }
                 if (chatResponse.getResult().getOutput().getContent() != null) {
-                    stringBufferRef.get().append(chatResponse.getResult().getOutput().getContent());
+                    messageTextContentRef.get().append(chatResponse.getResult().getOutput().getContent());
                 }
                 if (chatResponse.getResult().getOutput().getMetadata() != null) {
-                    mapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata());
+                    messageMetadataMapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata());
                 }
             }
             if (chatResponse.getMetadata() != null) {
@@ -128,13 +131,12 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
                 .withPromptMetadata(metadataPromptMetadataRef.get())
                 .build();
 
-            onAggregationComplete.accept(new ChatResponse(
-                    List.of(new Generation(new AssistantMessage(stringBufferRef.get().toString(), mapRef.get()),
-                            generationMetadataRef.get())),
-                    chatResponseMetadata));
+            onAggregationComplete.accept(new ChatResponse(List.of(new Generation(
+                    new AssistantMessage(messageTextContentRef.get().toString(), messageMetadataMapRef.get()),
+                    generationMetadataRef.get())), chatResponseMetadata));
 
-            stringBufferRef.set(new StringBuilder());
-            mapRef.set(new HashMap<>());
+            messageTextContentRef.set(new StringBuilder());
+            messageMetadataMapRef.set(new HashMap<>());
             metadataIdRef.set("");
             metadataModelRef.set("");
             metadataUsagePromptTokensRef.set(0L);