Skip to content

Commit d42cd60

Browse files
committed
Fix: AbstractBedrockApi streaming reused a single Flux Sink across invocations
Resolves: #474
1 parent 34c4703 commit d42cd60

File tree

7 files changed

+169
-16
lines changed

7 files changed

+169
-16
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java

Lines changed: 23 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -18,17 +18,20 @@
1818

1919
import java.io.UncheckedIOException;
2020
import java.nio.charset.StandardCharsets;
21+
import java.time.Duration;
2122

2223
import com.fasterxml.jackson.annotation.JsonInclude;
2324
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2425
import com.fasterxml.jackson.annotation.JsonProperty;
2526
import com.fasterxml.jackson.core.JsonProcessingException;
2627
import com.fasterxml.jackson.databind.DeserializationFeature;
2728
import com.fasterxml.jackson.databind.ObjectMapper;
28-
import org.apache.commons.logging.Log;
29-
import org.apache.commons.logging.LogFactory;
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
3031
import reactor.core.publisher.Flux;
3132
import reactor.core.publisher.Sinks;
33+
import reactor.core.publisher.Sinks.EmitFailureHandler;
34+
import reactor.core.publisher.Sinks.EmitResult;
3235
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
3336
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
3437
import software.amazon.awssdk.core.SdkBytes;
@@ -60,15 +63,14 @@
6063
*/
6164
public abstract class AbstractBedrockApi<I, O, SO> {
6265

63-
private static final Log logger = LogFactory.getLog(AbstractBedrockApi.class);
66+
private static final Logger logger = LoggerFactory.getLogger(AbstractBedrockApi.class);
6467

6568
private final String modelId;
6669
private final ObjectMapper objectMapper;
6770
private final AwsCredentialsProvider credentialsProvider;
6871
private final String region;
6972
private final BedrockRuntimeClient client;
7073
private final BedrockRuntimeAsyncClient clientStreaming;
71-
private final Sinks.Many<SO> eventSink;
7274

7375
/**
7476
* Create a new AbstractBedrockApi instance using default credentials provider and object mapper.
@@ -96,8 +98,6 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
9698
this.credentialsProvider = credentialsProvider;
9799
this.region = region;
98100

99-
this.eventSink = Sinks.many().unicast().onBackpressureError();
100-
101101
this.client = BedrockRuntimeClient.builder()
102102
.region(Region.of(this.region))
103103
.credentialsProvider(this.credentialsProvider)
@@ -220,13 +220,16 @@ protected O internalInvocation(I request, Class<O> clazz) {
220220
*/
221221
protected Flux<SO> internalInvocationStream(I request, Class<SO> clazz) {
222222

223+
// final Sinks.Many<SO> eventSink = Sinks.many().unicast().onBackpressureError();
224+
final Sinks.Many<SO> eventSink = Sinks.many().multicast().onBackpressureBuffer();
225+
223226
SdkBytes body;
224227
try {
225228
body = SdkBytes.fromUtf8String(this.objectMapper.writeValueAsString(request));
226229
}
227230
catch (JsonProcessingException e) {
228-
this.eventSink.tryEmitError(e);
229-
return this.eventSink.asFlux();
231+
eventSink.tryEmitError(e);
232+
return eventSink.asFlux();
230233
}
231234

232235
InvokeModelWithResponseStreamRequest invokeRequest = InvokeModelWithResponseStreamRequest.builder()
@@ -240,29 +243,35 @@ protected Flux<SO> internalInvocationStream(I request, Class<SO> clazz) {
240243
try {
241244
logger.debug("Received chunk: " + chunk.bytes().asString(StandardCharsets.UTF_8));
242245
SO response = this.objectMapper.readValue(chunk.bytes().asByteArray(), clazz);
243-
this.eventSink.tryEmitNext(response);
246+
eventSink.tryEmitNext(response);
244247
}
245248
catch (Exception e) {
246249
logger.error("Failed to unmarshall", e);
247-
this.eventSink.tryEmitError(e);
250+
eventSink.tryEmitError(e);
248251
}
249252
})
250253
.onDefault((event) -> {
251254
logger.error("Unknown or unhandled event: " + event.toString());
252-
this.eventSink.tryEmitError(new Throwable("Unknown or unhandled event: " + event.toString()));
255+
eventSink.tryEmitError(new Throwable("Unknown or unhandled event: " + event.toString()));
253256
})
254257
.build();
255258

256259
InvokeModelWithResponseStreamResponseHandler responseHandler = InvokeModelWithResponseStreamResponseHandler
257260
.builder()
258261
.onComplete(
259262
() -> {
260-
this.eventSink.tryEmitComplete();
263+
EmitResult emitResult = eventSink.tryEmitComplete();
264+
while(!emitResult.isSuccess()){
265+
System.out.println("Emitting complete:" + emitResult);
266+
emitResult = eventSink.tryEmitComplete();
267+
};
268+
eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3)));
269+
// EmitResult emitResult = eventSink.tryEmitComplete();
261270
logger.debug("\nCompleted streaming response.");
262271
})
263272
.onError((error) -> {
264273
logger.error("\n\nError streaming response: " + error.getMessage());
265-
this.eventSink.tryEmitError(error);
274+
eventSink.tryEmitError(error);
266275
})
267276
.onEventStream((stream) -> {
268277
stream.subscribe(
@@ -274,7 +283,7 @@ protected Flux<SO> internalInvocationStream(I request, Class<SO> clazz) {
274283

275284
this.clientStreaming.invokeModelWithResponseStream(invokeRequest, responseHandler);
276285

277-
return this.eventSink.asFlux();
286+
return eventSink.asFlux();
278287
}
279288
}
280289
// @formatter:on

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java

Lines changed: 29 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2626
import org.slf4j.Logger;
2727
import org.slf4j.LoggerFactory;
28+
import reactor.core.publisher.Flux;
2829

2930
import org.springframework.ai.chat.ChatResponse;
3031
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -64,6 +65,33 @@ class BedrockAnthropicChatClientIT {
6465
@Value("classpath:/prompts/system-message.st")
6566
private Resource systemResource;
6667

68+
@Test
69+
void multipleStreamAttempts() {
70+
71+
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
72+
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
73+
74+
String joke1 = joke1Stream.collectList()
75+
.block()
76+
.stream()
77+
.map(ChatResponse::getResults)
78+
.flatMap(List::stream)
79+
.map(Generation::getOutput)
80+
.map(AssistantMessage::getContent)
81+
.collect(Collectors.joining());
82+
String joke2 = joke2Stream.collectList()
83+
.block()
84+
.stream()
85+
.map(ChatResponse::getResults)
86+
.flatMap(List::stream)
87+
.map(Generation::getOutput)
88+
.map(AssistantMessage::getContent)
89+
.collect(Collectors.joining());
90+
91+
assertThat(joke1).isNotBlank();
92+
assertThat(joke2).isNotBlank();
93+
}
94+
6795
@Test
6896
void roleTest() {
6997
UserMessage userMessage = new UserMessage(
@@ -176,7 +204,7 @@ public static class TestConfiguration {
176204
@Bean
177205
public AnthropicChatBedrockApi anthropicApi() {
178206
return new AnthropicChatBedrockApi(AnthropicChatBedrockApi.AnthropicChatModel.CLAUDE_V2.id(),
179-
EnvironmentVariableCredentialsProvider.create(), Region.EU_CENTRAL_1.id(), new ObjectMapper());
207+
EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper());
180208
}
181209

182210
@Bean

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicCreateRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
public class BedrockAnthropicCreateRequestTests {
3333

3434
private AnthropicChatBedrockApi anthropicChatApi = new AnthropicChatBedrockApi(AnthropicChatModel.CLAUDE_V2.id(),
35-
Region.EU_CENTRAL_1.id());
35+
Region.US_EAST_1.id());
3636

3737
@Test
3838
public void createRequestWithChatOptions() {

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClientIT.java

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,8 @@
2020
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2121
import org.slf4j.Logger;
2222
import org.slf4j.LoggerFactory;
23+
import reactor.core.publisher.Flux;
24+
2325
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi;
2426
import org.springframework.ai.chat.ChatResponse;
2527
import org.springframework.ai.chat.Generation;
@@ -67,6 +69,33 @@ class BedrockAnthropic3ChatClientIT {
6769
@Value("classpath:/prompts/system-message.st")
6870
private Resource systemResource;
6971

72+
@Test
73+
void multipleStreamAttempts() {
74+
75+
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
76+
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
77+
78+
String joke1 = joke1Stream.collectList()
79+
.block()
80+
.stream()
81+
.map(ChatResponse::getResults)
82+
.flatMap(List::stream)
83+
.map(Generation::getOutput)
84+
.map(AssistantMessage::getContent)
85+
.collect(Collectors.joining());
86+
String joke2 = joke2Stream.collectList()
87+
.block()
88+
.stream()
89+
.map(ChatResponse::getResults)
90+
.flatMap(List::stream)
91+
.map(Generation::getOutput)
92+
.map(AssistantMessage::getContent)
93+
.collect(Collectors.joining());
94+
95+
assertThat(joke1).isNotBlank();
96+
assertThat(joke2).isNotBlank();
97+
}
98+
7099
@Test
71100
void roleTest() {
72101
UserMessage userMessage = new UserMessage(

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@
2323
import com.fasterxml.jackson.databind.ObjectMapper;
2424
import org.junit.jupiter.api.Test;
2525
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
26+
import reactor.core.publisher.Flux;
27+
2628
import org.springframework.ai.chat.messages.AssistantMessage;
2729
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
2830
import software.amazon.awssdk.regions.Region;
@@ -60,6 +62,33 @@ class BedrockCohereChatClientIT {
6062
@Value("classpath:/prompts/system-message.st")
6163
private Resource systemResource;
6264

65+
@Test
66+
void multipleStreamAttempts() {
67+
68+
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
69+
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
70+
71+
String joke1 = joke1Stream.collectList()
72+
.block()
73+
.stream()
74+
.map(ChatResponse::getResults)
75+
.flatMap(List::stream)
76+
.map(Generation::getOutput)
77+
.map(AssistantMessage::getContent)
78+
.collect(Collectors.joining());
79+
String joke2 = joke2Stream.collectList()
80+
.block()
81+
.stream()
82+
.map(ChatResponse::getResults)
83+
.flatMap(List::stream)
84+
.map(Generation::getOutput)
85+
.map(AssistantMessage::getContent)
86+
.collect(Collectors.joining());
87+
88+
assertThat(joke1).isNotBlank();
89+
assertThat(joke2).isNotBlank();
90+
}
91+
6392
@Test
6493
void roleTest() {
6594
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@
2424
import org.junit.jupiter.api.Disabled;
2525
import org.junit.jupiter.api.Test;
2626
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
27+
import reactor.core.publisher.Flux;
28+
2729
import org.springframework.ai.chat.ChatResponse;
2830
import org.springframework.ai.chat.messages.AssistantMessage;
2931
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
@@ -61,6 +63,33 @@ class BedrockLlama2ChatClientIT {
6163
@Value("classpath:/prompts/system-message.st")
6264
private Resource systemResource;
6365

66+
@Test
67+
void multipleStreamAttempts() {
68+
69+
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
70+
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
71+
72+
String joke1 = joke1Stream.collectList()
73+
.block()
74+
.stream()
75+
.map(ChatResponse::getResults)
76+
.flatMap(List::stream)
77+
.map(Generation::getOutput)
78+
.map(AssistantMessage::getContent)
79+
.collect(Collectors.joining());
80+
String joke2 = joke2Stream.collectList()
81+
.block()
82+
.stream()
83+
.map(ChatResponse::getResults)
84+
.flatMap(List::stream)
85+
.map(Generation::getOutput)
86+
.map(AssistantMessage::getContent)
87+
.collect(Collectors.joining());
88+
89+
assertThat(joke1).isNotBlank();
90+
assertThat(joke2).isNotBlank();
91+
}
92+
6493
@Test
6594
void roleTest() {
6695
UserMessage userMessage = new UserMessage(

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@
2424
import org.junit.jupiter.api.Disabled;
2525
import org.junit.jupiter.api.Test;
2626
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
27+
import reactor.core.publisher.Flux;
28+
2729
import org.springframework.ai.chat.ChatResponse;
2830
import org.springframework.ai.chat.messages.AssistantMessage;
2931
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
@@ -61,6 +63,33 @@ class BedrockTitanChatClientIT {
6163
@Value("classpath:/prompts/system-message.st")
6264
private Resource systemResource;
6365

66+
@Test
67+
void multipleStreamAttempts() {
68+
69+
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
70+
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
71+
72+
String joke1 = joke1Stream.collectList()
73+
.block()
74+
.stream()
75+
.map(ChatResponse::getResults)
76+
.flatMap(List::stream)
77+
.map(Generation::getOutput)
78+
.map(AssistantMessage::getContent)
79+
.collect(Collectors.joining());
80+
String joke2 = joke2Stream.collectList()
81+
.block()
82+
.stream()
83+
.map(ChatResponse::getResults)
84+
.flatMap(List::stream)
85+
.map(Generation::getOutput)
86+
.map(AssistantMessage::getContent)
87+
.collect(Collectors.joining());
88+
89+
assertThat(joke1).isNotBlank();
90+
assertThat(joke2).isNotBlank();
91+
}
92+
6493
@Test
6594
void roleTest() {
6695
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";

0 commit comments

Comments
 (0)