Bug description
In our scenario, we need to perform a security check on every 2 streaming tokens, using buffer(2) followed by concatMap to call an external validation API.
When using OpenAiChatModel, whose streaming pipeline uses switchMap internally, some tokens/chunks are dropped unexpectedly.
The cause appears to be that switchMap cancels the previous inner stream as soon as a new token arrives, which conflicts with our buffered, sequential processing logic.
I believe switchMap's semantics are not ideal here; concatMap or flatMapSequential would be more appropriate.
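To illustrate the semantic difference in isolation, here is a minimal Reactor-only sketch (independent of Spring AI internals; the 30 ms / 50 ms timings are arbitrary choices for the demo): when each inner publisher is slower than the outer cadence, switchMap cancels in-flight inner streams and loses their elements, while concatMap completes each inner publisher before moving to the next.

```java
import java.time.Duration;
import java.util.List;
import reactor.core.publisher.Flux;

public class SwitchVsConcatDemo {

    // Source: integers 1..10, one emitted every 30 ms.
    static Flux<Integer> source() {
        return Flux.range(1, 10).delayElements(Duration.ofMillis(30));
    }

    // switchMap cancels the previous inner publisher whenever a new outer
    // element arrives; a 50 ms inner delay loses the race against the
    // 30 ms outer cadence, so in-flight elements are dropped.
    static List<Integer> viaSwitchMap() {
        return source()
                .switchMap(i -> Flux.just(i).delayElements(Duration.ofMillis(50)))
                .collectList()
                .block();
    }

    // concatMap subscribes to one inner publisher at a time and waits for
    // it to complete, so every element survives (at the cost of latency).
    static List<Integer> viaConcatMap() {
        return source()
                .concatMap(i -> Flux.just(i).delayElements(Duration.ofMillis(50)))
                .collectList()
                .block();
    }

    public static void main(String[] args) {
        System.out.println("switchMap: " + viaSwitchMap()); // elements missing
        System.out.println("concatMap: " + viaConcatMap()); // all ten elements
    }
}
```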
This issue might be related to #876
Environment
Spring AI version: 1.1.2
Java version: OpenJDK 21
Operating System: Windows 11
Steps to reproduce
- Run the provided unit tests below.
- Although we used an internal LLM endpoint, the same issue is expected with other model endpoints, since it is caused by the internal switchMap behavior.
Expected behavior
All tokens should be emitted and processed in order, allowing buffered calls to the security check API without losing any chunks.
Minimal Complete Reproducible example
```java
@Test
void testSwitchMapMissingPartialResult() throws InterruptedException {
    // Initialize the OpenAI API client
    OpenAiApi openAiApi = OpenAiApi.builder()
            .apiKey("YOUR_API_KEY")
            .restClientBuilder(RestClient.builder())
            .webClientBuilder(WebClient.builder())
            .build();

    // Build OpenAiChatModel with default options
    OpenAiChatModel chatModel = OpenAiChatModel.builder()
            .openAiApi(openAiApi)
            .defaultOptions(OpenAiChatOptions.builder().build())
            .build();

    // Prepare a user message requesting the integers from 1 to 300
    UserMessage userMessage = new UserMessage(
            "List all integers from 1 to 300, separated by spaces, with no other text.");
    Prompt prompt = new Prompt(List.of(userMessage));

    // Start streaming the chat response
    Flux<ChatResponse> chatResponseFlux = chatModel.stream(prompt);
    StringBuilder answer = new StringBuilder();
    CountDownLatch latch = new CountDownLatch(1);

    chatResponseFlux
            // Append each received token to the answer
            .doOnNext(chatResponse -> {
                String responseContent = chatResponse.getResult().getOutput().getText();
                answer.append(responseContent);
            })
            // Buffer every 2 tokens
            .buffer(2)
            // Simulate an external safety check per buffer (50 ms delay)
            .concatMap(bufferedList -> Mono.delay(Duration.ofMillis(50))
                    .thenMany(Flux.fromIterable(bufferedList)))
            // Release the latch when the stream completes
            .doFinally(signalType -> latch.countDown())
            .subscribe();

    latch.await();

    // Verify that every integer from 1 to 300 is present in the answer
    IntStream.rangeClosed(1, 300).forEach(n ->
            assertThat(answer).contains(String.valueOf(n)));
}
```
```java
@Test
void testSwitchMapMissingPartialResultWithMockResponse() throws InterruptedException {
    // Simpler reproduction without a real model dependency
    CountDownLatch latch = new CountDownLatch(1);
    StringBuilder stringBuilder = new StringBuilder();

    // Mock streaming response: integers 1 to 10, one emitted every 30 ms
    Flux<Integer> mockChatResponseFlux = Flux.range(1, 10)
            .delayElements(Duration.ofMillis(30))
            // switchMap simulates OpenAiChatModel's internal behavior
            .switchMap(Flux::just);

    mockChatResponseFlux
            // Append each emitted integer to the StringBuilder
            .doOnNext(stringBuilder::append)
            // Buffer every 2 elements
            .buffer(2)
            // Simulate an external safety check taking 50 ms, then re-emit the buffered items
            .concatMap(bufferedList -> Mono.delay(Duration.ofMillis(50))
                    .thenMany(Flux.fromIterable(bufferedList)))
            // Release the latch when processing completes
            .doFinally(signal -> latch.countDown())
            .subscribe();

    // Wait for completion
    latch.await();

    // Assert that all integers 1-10 were received in order
    assertThat(stringBuilder.toString()).isEqualTo("12345678910");
}
```