Skip to content

Commit 17a0173

Browse files
committed
fix: Improve error resilience of the Bedrock stream handling
The changes include: - Refactor the emitting of next, error, and complete events in the Bedrock stream handling to use a default EmitFailureHandler that retries for 10 seconds before failing. This helps improve the error resilience of the stream processing. - Disable a couple of integration tests related to the COHERE_COMMAND_V14 model, as that model version is no longer supported. - Adjust some configuration options in the Jurassic2 chat model integration test. Also resolves #1679
1 parent 7895875 commit 17a0173

File tree

6 files changed

+25
-28
lines changed

6 files changed

+25
-28
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import reactor.core.publisher.Mono;
3737
import reactor.core.publisher.Sinks;
3838
import reactor.core.publisher.Sinks.EmitFailureHandler;
39-
import reactor.core.publisher.Sinks.EmitResult;
4039
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
4140
import software.amazon.awssdk.core.SdkBytes;
4241
import software.amazon.awssdk.core.document.Document;
@@ -524,6 +523,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
524523
});
525524
}
526525

526+
public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler
527+
.busyLooping(Duration.ofSeconds(10));
528+
527529
/**
528530
* Invoke the model and return the response stream.
529531
*
@@ -541,26 +543,19 @@ public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseS
541543
ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder()
542544
.onDefault(output -> {
543545
logger.debug("Received converse stream output:{}", output);
544-
eventSink.tryEmitNext(output);
546+
eventSink.emitNext(output, DEFAULT_EMIT_FAILURE_HANDLER);
545547
})
546548
.build();
547549

548550
ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder()
549551
.onEventStream(stream -> stream.subscribe(e -> e.accept(visitor)))
550552
.onComplete(() -> {
551-
EmitResult emitResult = eventSink.tryEmitComplete();
552-
553-
while (!emitResult.isSuccess()) {
554-
logger.info("Emitting complete:{}", emitResult);
555-
emitResult = eventSink.tryEmitComplete();
556-
}
557-
558-
eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3)));
553+
eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER);
559554
logger.info("Completed streaming response.");
560555
})
561556
.onError(error -> {
562557
logger.error("Error handling Bedrock converse stream response", error);
563-
eventSink.tryEmitError(error);
558+
eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER);
564559
})
565560
.build();
566561

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import reactor.core.publisher.Flux;
3333
import reactor.core.publisher.Sinks;
3434
import reactor.core.publisher.Sinks.EmitFailureHandler;
35-
import reactor.core.publisher.Sinks.EmitResult;
3635
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
3736
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
3837
import software.amazon.awssdk.core.SdkBytes;
@@ -70,6 +69,10 @@ public abstract class AbstractBedrockApi<I, O, SO> {
7069

7170
private static final Logger logger = LoggerFactory.getLogger(AbstractBedrockApi.class);
7271

72+
public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler
73+
.busyLooping(Duration.ofSeconds(10));
74+
75+
7376
private final String modelId;
7477
private final ObjectMapper objectMapper;
7578
private final Region region;
@@ -264,7 +267,7 @@ protected Flux<SO> internalInvocationStream(I request, Class<SO> clazz) {
264267
body = SdkBytes.fromUtf8String(this.objectMapper.writeValueAsString(request));
265268
}
266269
catch (JsonProcessingException e) {
267-
eventSink.tryEmitError(e);
270+
eventSink.emitError(e, DEFAULT_EMIT_FAILURE_HANDLER);
268271
return eventSink.asFlux();
269272
}
270273

@@ -279,35 +282,29 @@ protected Flux<SO> internalInvocationStream(I request, Class<SO> clazz) {
279282
try {
280283
logger.debug("Received chunk: " + chunk.bytes().asString(StandardCharsets.UTF_8));
281284
SO response = this.objectMapper.readValue(chunk.bytes().asByteArray(), clazz);
282-
eventSink.tryEmitNext(response);
285+
eventSink.emitNext(response, DEFAULT_EMIT_FAILURE_HANDLER);
283286
}
284287
catch (Exception e) {
285288
logger.error("Failed to unmarshall", e);
286-
eventSink.tryEmitError(e);
289+
eventSink.emitError(e, DEFAULT_EMIT_FAILURE_HANDLER);
287290
}
288291
})
289292
.onDefault(event -> {
290293
logger.error("Unknown or unhandled event: " + event.toString());
291-
eventSink.tryEmitError(new Throwable("Unknown or unhandled event: " + event.toString()));
294+
eventSink.emitError(new Throwable("Unknown or unhandled event: " + event.toString()),DEFAULT_EMIT_FAILURE_HANDLER);
292295
})
293296
.build();
294297

295298
InvokeModelWithResponseStreamResponseHandler responseHandler = InvokeModelWithResponseStreamResponseHandler
296299
.builder()
297300
.onComplete(
298301
() -> {
299-
EmitResult emitResult = eventSink.tryEmitComplete();
300-
while (!emitResult.isSuccess()) {
301-
System.out.println("Emitting complete:" + emitResult);
302-
emitResult = eventSink.tryEmitComplete();
303-
}
304-
eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3)));
305-
// EmitResult emitResult = eventSink.tryEmitComplete();
306-
logger.debug("\nCompleted streaming response.");
302+
eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER);
303+
logger.info("Completed streaming response.");
307304
})
308305
.onError(error -> {
309306
logger.error("\n\nError streaming response: " + error.getMessage());
310-
eventSink.tryEmitError(error);
307+
eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER);
311308
})
312309
.onEventStream(stream -> stream.subscribe(
313310
(ResponseStream e) -> e.accept(visitor)))

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.stream.Collectors;
2424

2525
import com.fasterxml.jackson.databind.ObjectMapper;
26+
import org.junit.jupiter.api.Disabled;
2627
import org.junit.jupiter.api.Test;
2728
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2829
import reactor.core.publisher.Flux;
@@ -55,6 +56,7 @@
5556
@SpringBootTest
5657
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
5758
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
59+
@Disabled("COHERE_COMMAND_V14 is not supported anymore")
5860
class BedrockCohereChatModelIT {
5961

6062
@Autowired

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.time.Duration;
2020
import java.util.List;
2121

22+
import org.junit.jupiter.api.Disabled;
2223
import org.junit.jupiter.api.Test;
2324
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2425
import reactor.core.publisher.Flux;
@@ -71,6 +72,7 @@ public void requestBuilder() {
7172
}
7273

7374
@Test
75+
@Disabled("Due to model version has reached the end of its life")
7476
public void chatCompletion() {
7577

7678
var request = CohereChatRequest
@@ -95,6 +97,7 @@ public void chatCompletion() {
9597
assertThat(response.generations().get(0).text()).isNotEmpty();
9698
}
9799

100+
@Disabled("Due to model version has reached the end of its life")
98101
@Test
99102
public void chatCompletionStream() {
100103

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ public BedrockAi21Jurassic2ChatModel bedrockAi21Jurassic2ChatModel(
158158
return new BedrockAi21Jurassic2ChatModel(jurassic2ChatBedrockApi,
159159
BedrockAi21Jurassic2ChatOptions.builder()
160160
.withTemperature(0.5)
161-
.withMaxTokens(100)
162-
.withTopP(0.9)
161+
.withMaxTokens(500)
162+
// .withTopP(0.9)
163163
.build());
164164
}
165165

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
*/
5252
@AutoConfiguration
5353
@EnableConfigurationProperties({ BedrockConverseProxyChatProperties.class, BedrockAwsConnectionConfiguration.class })
54-
@ConditionalOnClass({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class })
54+
@ConditionalOnClass({ BedrockProxyChatModel.class, BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class })
5555
@ConditionalOnProperty(prefix = BedrockConverseProxyChatProperties.CONFIG_PREFIX, name = "enabled",
5656
havingValue = "true", matchIfMissing = true)
5757
@Import(BedrockAwsConnectionConfiguration.class)

0 commit comments

Comments
 (0)