2626import java .util .Base64 ;
2727import java .util .List ;
2828import java .util .Map ;
29+ import java .util .Set ;
2930
3031import io .micrometer .observation .Observation ;
3132import io .micrometer .observation .ObservationRegistry ;
5758import software .amazon .awssdk .services .bedrockruntime .model .InferenceConfiguration ;
5859import software .amazon .awssdk .services .bedrockruntime .model .Message ;
5960import software .amazon .awssdk .services .bedrockruntime .model .S3Location ;
61+ import software .amazon .awssdk .services .bedrockruntime .model .StopReason ;
6062import software .amazon .awssdk .services .bedrockruntime .model .SystemContentBlock ;
6163import software .amazon .awssdk .services .bedrockruntime .model .Tool ;
6264import software .amazon .awssdk .services .bedrockruntime .model .ToolConfiguration ;
@@ -262,7 +264,8 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon
262264 });
263265
264266 if (ToolCallingChatOptions .isInternalToolExecutionEnabled (prompt .getOptions ()) && chatResponse != null
265- && chatResponse .hasToolCalls ()) {
267+ && chatResponse .hasToolCalls ()
268+ && chatResponse .hasFinishReasons (Set .of (StopReason .TOOL_USE .toString ()))) {
266269 var toolExecutionResult = this .toolCallingManager .executeToolCalls (prompt , chatResponse );
267270 if (toolExecutionResult .returnDirect ()) {
268271 // Return tool execution result directly to the client.
@@ -280,22 +283,6 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon
280283 return chatResponse ;
281284 }
282285
283- // private ToolCallingChatOptions buildRequestOptions(ConverseRequest request) {
284-
285- // ToolCallingChatOptions toolCallbackChatOptions = ToolCallingChatOptions.builder()
286- // .model(request.modelId())
287- // .maxTokens(request.inferenceConfig().maxTokens())
288- // .stopSequences(request.inferenceConfig().stopSequences())
289- // .temperature(request.inferenceConfig().temperature() != null
290- // ? request.inferenceConfig().temperature().doubleValue()
291- // : null)
292- // .topP(request.inferenceConfig().topP() != null ?
293- // request.inferenceConfig().topP().doubleValue() : null)
294- // .build();
295-
296- // return toolCallbackChatOptions;
297- // }
298-
299286 @ Override
300287 public ChatOptions getDefaultOptions () {
301288 return this .defaultOptions ;
@@ -708,28 +695,34 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
708695
709696 Flux <ConverseStreamOutput > response = converseStream (converseStreamRequest );
710697
711- // @formatter:off
712698 Flux <ChatResponse > chatResponses = ConverseApiUtils .toChatResponse (response , perviousChatResponse );
713699
714700 Flux <ChatResponse > chatResponseFlux = chatResponses .switchMap (chatResponse -> {
715701
716- if (ToolCallingChatOptions .isInternalToolExecutionEnabled (prompt .getOptions ()) && chatResponse .hasToolCalls ()) {
702+ if (ToolCallingChatOptions .isInternalToolExecutionEnabled (prompt .getOptions ())
703+ && chatResponse .hasToolCalls ()
704+ && chatResponse .hasFinishReasons (Set .of (StopReason .TOOL_USE .toString ()))) {
705+
717706 var toolExecutionResult = this .toolCallingManager .executeToolCalls (prompt , chatResponse );
707+
718708 if (toolExecutionResult .returnDirect ()) {
719709 // Return tool execution result directly to the client.
720- return Flux .just (ChatResponse .builder ().from (chatResponse )
721- .generations (ToolExecutionResult .buildGenerations (toolExecutionResult ))
722- .build ());
723- } else {
710+ return Flux .just (ChatResponse .builder ()
711+ .from (chatResponse )
712+ .generations (ToolExecutionResult .buildGenerations (toolExecutionResult ))
713+ .build ());
714+ }
715+ else {
724716 // Send the tool execution result back to the model.
725- return this .internalStream (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()),
726- chatResponse );
717+ return this .internalStream (
718+ new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()),
719+ chatResponse );
727720 }
728721 }
729722 else {
730723 return Flux .just (chatResponse );
731724 }
732- })
725+ })// @formatter:off
733726 .doOnError (observation ::error )
734727 .doFinally (s -> observation .stop ())
735728 .contextWrite (ctx -> ctx .put (ObservationThreadLocalAccessor .KEY , observation ));
0 commit comments