Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.core.publisher.Sinks.EmitFailureHandler;
import reactor.core.publisher.Sinks.EmitResult;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.document.Document;
Expand Down Expand Up @@ -524,6 +523,9 @@ public Flux<ChatResponse> stream(Prompt prompt) {
});
}

public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler
.busyLooping(Duration.ofSeconds(10));

/**
* Invoke the model and return the response stream.
*
Expand All @@ -541,26 +543,19 @@ public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseS
ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder()
.onDefault(output -> {
logger.debug("Received converse stream output:{}", output);
eventSink.tryEmitNext(output);
eventSink.emitNext(output, DEFAULT_EMIT_FAILURE_HANDLER);
})
.build();

ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder()
.onEventStream(stream -> stream.subscribe(e -> e.accept(visitor)))
.onComplete(() -> {
EmitResult emitResult = eventSink.tryEmitComplete();

while (!emitResult.isSuccess()) {
logger.info("Emitting complete:{}", emitResult);
emitResult = eventSink.tryEmitComplete();
}

eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3)));
eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER);
logger.info("Completed streaming response.");
})
.onError(error -> {
logger.error("Error handling Bedrock converse stream response", error);
eventSink.tryEmitError(error);
eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER);
})
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
import reactor.core.publisher.Sinks.EmitFailureHandler;
import reactor.core.publisher.Sinks.EmitResult;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
Expand Down Expand Up @@ -70,6 +69,10 @@ public abstract class AbstractBedrockApi<I, O, SO> {

private static final Logger logger = LoggerFactory.getLogger(AbstractBedrockApi.class);

public static final EmitFailureHandler DEFAULT_EMIT_FAILURE_HANDLER = EmitFailureHandler
.busyLooping(Duration.ofSeconds(10));


private final String modelId;
private final ObjectMapper objectMapper;
private final Region region;
Expand Down Expand Up @@ -264,7 +267,7 @@ protected Flux<SO> internalInvocationStream(I request, Class<SO> clazz) {
body = SdkBytes.fromUtf8String(this.objectMapper.writeValueAsString(request));
}
catch (JsonProcessingException e) {
eventSink.tryEmitError(e);
eventSink.emitError(e, DEFAULT_EMIT_FAILURE_HANDLER);
return eventSink.asFlux();
}

Expand All @@ -279,35 +282,29 @@ protected Flux<SO> internalInvocationStream(I request, Class<SO> clazz) {
try {
logger.debug("Received chunk: " + chunk.bytes().asString(StandardCharsets.UTF_8));
SO response = this.objectMapper.readValue(chunk.bytes().asByteArray(), clazz);
eventSink.tryEmitNext(response);
eventSink.emitNext(response, DEFAULT_EMIT_FAILURE_HANDLER);
}
catch (Exception e) {
logger.error("Failed to unmarshall", e);
eventSink.tryEmitError(e);
eventSink.emitError(e, DEFAULT_EMIT_FAILURE_HANDLER);
}
})
.onDefault(event -> {
logger.error("Unknown or unhandled event: " + event.toString());
eventSink.tryEmitError(new Throwable("Unknown or unhandled event: " + event.toString()));
eventSink.emitError(new Throwable("Unknown or unhandled event: " + event.toString()),DEFAULT_EMIT_FAILURE_HANDLER);
})
.build();

InvokeModelWithResponseStreamResponseHandler responseHandler = InvokeModelWithResponseStreamResponseHandler
.builder()
.onComplete(
() -> {
EmitResult emitResult = eventSink.tryEmitComplete();
while (!emitResult.isSuccess()) {
System.out.println("Emitting complete:" + emitResult);
emitResult = eventSink.tryEmitComplete();
}
eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3)));
// EmitResult emitResult = eventSink.tryEmitComplete();
logger.debug("\nCompleted streaming response.");
eventSink.emitComplete(DEFAULT_EMIT_FAILURE_HANDLER);
logger.info("Completed streaming response.");
})
.onError(error -> {
logger.error("\n\nError streaming response: " + error.getMessage());
eventSink.tryEmitError(error);
eventSink.emitError(error, DEFAULT_EMIT_FAILURE_HANDLER);
})
.onEventStream(stream -> stream.subscribe(
(ResponseStream e) -> e.accept(visitor)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.stream.Collectors;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -55,6 +56,7 @@
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
@Disabled("COHERE_COMMAND_V14 is not supported anymore")
class BedrockCohereChatModelIT {

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Duration;
import java.util.List;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import reactor.core.publisher.Flux;
Expand Down Expand Up @@ -71,6 +72,7 @@ public void requestBuilder() {
}

@Test
@Disabled("Due to model version has reached the end of its life")
public void chatCompletion() {

var request = CohereChatRequest
Expand All @@ -95,6 +97,7 @@ public void chatCompletion() {
assertThat(response.generations().get(0).text()).isNotEmpty();
}

@Disabled("Due to model version has reached the end of its life")
@Test
public void chatCompletionStream() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ public BedrockAi21Jurassic2ChatModel bedrockAi21Jurassic2ChatModel(
return new BedrockAi21Jurassic2ChatModel(jurassic2ChatBedrockApi,
BedrockAi21Jurassic2ChatOptions.builder()
.withTemperature(0.5)
.withMaxTokens(100)
.withTopP(0.9)
.withMaxTokens(500)
// .withTopP(0.9)
.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
*/
@AutoConfiguration
@EnableConfigurationProperties({ BedrockConverseProxyChatProperties.class, BedrockAwsConnectionConfiguration.class })
@ConditionalOnClass({ BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class })
@ConditionalOnClass({ BedrockProxyChatModel.class, BedrockRuntimeClient.class, BedrockRuntimeAsyncClient.class })
@ConditionalOnProperty(prefix = BedrockConverseProxyChatProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
@Import(BedrockAwsConnectionConfiguration.class)
Expand Down