diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 4687504e075..c437786dd5e 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -36,7 +36,6 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import reactor.core.publisher.Sinks.EmitFailureHandler; -import reactor.core.publisher.Sinks.EmitResult; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.document.Document; @@ -524,6 +523,9 @@ public Flux stream(Prompt prompt) { }); } + public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler + .busyLooping(Duration.ofSeconds(10)); + /** * Invoke the model and return the response stream. * @@ -541,26 +543,19 @@ public Flux converseStream(ConverseStreamRequest converseS ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder() .onDefault(output -> { logger.debug("Received converse stream output:{}", output); - eventSink.tryEmitNext(output); + eventSink.emitNext(output, DEFAULT_EMIT_FAILURE_HANDLER); }) .build(); ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder() .onEventStream(stream -> stream.subscribe(e -> e.accept(visitor))) .onComplete(() -> { - EmitResult emitResult = eventSink.tryEmitComplete(); - - while (!emitResult.isSuccess()) { - logger.info("Emitting complete:{}", emitResult); - emitResult = eventSink.tryEmitComplete(); - } - - eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3))); + eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER); logger.info("Completed streaming response."); }) .onError(error -> { logger.error("Error handling Bedrock converse stream response", error); - eventSink.tryEmitError(error); + eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER); }) .build(); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java index e671e01bcb2..439c67251b8 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java @@ -32,7 +32,6 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; import reactor.core.publisher.Sinks.EmitFailureHandler; -import reactor.core.publisher.Sinks.EmitResult; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; @@ -70,6 +69,10 @@ public abstract class AbstractBedrockApi { private static final Logger logger = LoggerFactory.getLogger(AbstractBedrockApi.class); + public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler + .busyLooping(Duration.ofSeconds(10)); + + private final String modelId; private final ObjectMapper objectMapper; private final Region region; @@ -264,7 +267,7 @@ protected Flux internalInvocationStream(I request, Class clazz) { body = SdkBytes.fromUtf8String(this.objectMapper.writeValueAsString(request)); } catch (JsonProcessingException e) { - eventSink.tryEmitError(e); + eventSink.emitError(e, DEFAULT_EMIT_FAILURE_HANDLER); return eventSink.asFlux(); } @@ -279,16 +282,16 @@ protected Flux internalInvocationStream(I request, Class clazz) { try { logger.debug("Received chunk: " + chunk.bytes().asString(StandardCharsets.UTF_8)); SO response = this.objectMapper.readValue(chunk.bytes().asByteArray(), clazz); - eventSink.tryEmitNext(response); + eventSink.emitNext(response, DEFAULT_EMIT_FAILURE_HANDLER); } catch (Exception e) { logger.error("Failed to unmarshall", e); - eventSink.tryEmitError(e); + eventSink.emitError(e, DEFAULT_EMIT_FAILURE_HANDLER); } }) .onDefault(event -> { logger.error("Unknown or unhandled event: " + event.toString()); - eventSink.tryEmitError(new Throwable("Unknown or unhandled event: " + event.toString())); + eventSink.emitError(new Throwable("Unknown or unhandled event: " + event.toString()),DEFAULT_EMIT_FAILURE_HANDLER); }) .build(); @@ -296,18 +299,12 @@ protected Flux internalInvocationStream(I request, Class clazz) { .builder() .onComplete( () -> { - EmitResult emitResult = eventSink.tryEmitComplete(); - while (!emitResult.isSuccess()) { - System.out.println("Emitting complete:" + emitResult); - emitResult = eventSink.tryEmitComplete(); - } - eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3))); - // EmitResult emitResult = eventSink.tryEmitComplete(); - logger.debug("\nCompleted streaming response."); + eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER); + logger.info("Completed streaming response."); }) .onError(error -> { logger.error("\n\nError streaming response: " + error.getMessage()); - eventSink.tryEmitError(error); + eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER); }) .onEventStream(stream -> stream.subscribe( (ResponseStream e) -> e.accept(visitor))) diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java index 340b541b0f7..2dc38802a31 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java @@ -23,6 +23,7 @@ import java.util.stream.Collectors; import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import reactor.core.publisher.Flux; @@ -55,6 +56,7 @@ @SpringBootTest @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +@Disabled("COHERE_COMMAND_V14 is not supported anymore") class BedrockCohereChatModelIT { @Autowired diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java index 41861e9ecac..092fce651d3 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java @@ -19,6 +19,7 @@ import java.time.Duration; import java.util.List; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import reactor.core.publisher.Flux; @@ -71,6 +72,7 @@ public void requestBuilder() { } @Test + @Disabled("Due to model version has reached the end of its life") public void chatCompletion() { var request = CohereChatRequest @@ -95,6 +97,7 @@ public void chatCompletion() { assertThat(response.generations().get(0).text()).isNotEmpty(); } + @Disabled("Due to model version has reached the end of its life") @Test public void chatCompletionStream() { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java index c0919cd03f4..9919e4b853a 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatModelIT.java @@ -158,8 +158,8 @@ public BedrockAi21Jurassic2ChatModel bedrockAi21Jurassic2ChatModel( return new BedrockAi21Jurassic2ChatModel(jurassic2ChatBedrockApi, BedrockAi21Jurassic2ChatOptions.builder() .withTemperature(0.5) - .withMaxTokens(100) - .withTopP(0.9) + .withMaxTokens(500) + // .withTopP(0.9) .build()); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java index 869885eb409..cb5bfd4c706 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatAutoConfiguration.java @@ -51,7 +51,7 @@ */ @AutoConfiguration @EnableConfigurationProperties({ BedrockConverseProxyChatProperties.class, BedrockAwsConnectionConfiguration.class }) -@ConditionalOnClass({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class }) +@ConditionalOnClass({ BedrockProxyChatModel.class, BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class }) @ConditionalOnProperty(prefix = BedrockConverseProxyChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) @Import(BedrockAwsConnectionConfiguration.class)