Skip to content

Commit e6fbc28

Browse files
mmazurkevich authored and ilayaperumalg committed
[anthropic] fix issue #1370 with tool call duplication
Without this fix, when `EventType.MESSAGE_STOP` occurred during stream event handling, the latest content block was resent, which caused an additional tool call (if it was the latest event). [anthropic] Replace `switchMap` with `flatMap` to avoid cancellation of the original request. Previously, internalStream used switchMap to process ChatCompletionResponses, which caused the active stream (including potential recursive calls) to be canceled whenever a new response arrived. This led to incomplete processing of streaming tool calls and unexpected behavior when handling tool_use events. Replaced switchMap with flatMap to ensure that each response is fully processed without being interrupted, allowing recursive internalStream calls to complete as expected. Without this fix, when `EventType.MESSAGE_STOP` occurred during stream event handling, the latest content block was resent, which caused an additional tool call (if it was the latest event). Signed-off-by: Mikhail Mazurkevich <[email protected]>
1 parent b5c3216 commit e6fbc28

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
253253
this.getAdditionalHttpHeaders(prompt));
254254

255255
// @formatter:off
256-
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
256+
Flux<ChatResponse> chatResponseFlux = response.flatMap(chatCompletionResponse -> {
257257
AnthropicApi.Usage usage = chatCompletionResponse.usage();
258258
Usage currentChatResponseUsage = usage != null ? this.getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage();
259259
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,14 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
179179
}
180180
}
181181
else if (event.type().equals(EventType.MESSAGE_STOP)) {
182-
// pass through
182+
// Don't return the latest Content block as it was before. Instead, return it
183+
// with an updated event type and general information like: model, message
184+
// type, id and usage
185+
contentBlockReference.get()
186+
.withType(event.type().name())
187+
.withContent(List.of())
188+
.withStopReason(null)
189+
.withStopSequence(null);
183190
}
184191
else {
185192
contentBlockReference.get().withType(event.type().name()).withContent(List.of());

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.anthropic.api;
1818

19+
import java.util.ArrayList;
1920
import java.util.List;
2021

2122
import org.junit.jupiter.api.Test;
@@ -27,6 +28,7 @@
2728
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse;
2829
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock;
2930
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
31+
import org.springframework.ai.model.ModelOptionsUtils;
3032
import org.springframework.http.ResponseEntity;
3133

3234
import static org.assertj.core.api.Assertions.assertThat;
@@ -42,6 +44,24 @@ public class AnthropicApiIT {
4244

4345
AnthropicApi anthropicApi = AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build();
4446

47+
List<AnthropicApi.Tool> tools = List.of(new AnthropicApi.Tool("getCurrentWeather",
48+
"Get the weather in location. Return temperature in 30°F or 30°C format.", ModelOptionsUtils.jsonToMap("""
49+
{
50+
"type": "object",
51+
"properties": {
52+
"location": {
53+
"type": "string",
54+
"description": "The city and state e.g. San Francisco, CA"
55+
},
56+
"unit": {
57+
"type": "string",
58+
"enum": ["C", "F"]
59+
}
60+
},
61+
"required": ["location", "unit"]
62+
}
63+
""")));
64+
4565
@Test
4666
void chatCompletionEntity() {
4767

@@ -106,6 +126,47 @@ void chatCompletionStream() {
106126
bla.stream().forEach(r -> System.out.println(r));
107127
}
108128

129+
@Test
130+
void chatCompletionStreamWithToolCall() {
131+
List<AnthropicMessage> messageConversation = new ArrayList<>();
132+
133+
AnthropicMessage chatCompletionMessage = new AnthropicMessage(
134+
List.of(new ContentBlock("What's the weather like in San Francisco? Show the temperature in Celsius.")),
135+
Role.USER);
136+
137+
messageConversation.add(chatCompletionMessage);
138+
139+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
140+
.model(AnthropicApi.ChatModel.CLAUDE_3_OPUS)
141+
.messages(messageConversation)
142+
.maxTokens(1500)
143+
.stream(true)
144+
.temperature(0.8)
145+
.tools(tools)
146+
.build();
147+
148+
List<ChatCompletionResponse> responses = this.anthropicApi.chatCompletionStream(chatCompletionRequest)
149+
.collectList()
150+
.block();
151+
152+
// Check that the tool-use response is returned only once
153+
List<ChatCompletionResponse> toolCompletionResponses = responses.stream()
154+
.filter(r -> r.stopReason() != null && r.stopReason().equals(ContentBlock.Type.TOOL_USE.value))
155+
.toList();
156+
assertThat(toolCompletionResponses).size().isEqualTo(1);
157+
List<ContentBlock> toolContentBlocks = toolCompletionResponses.get(0).content();
158+
assertThat(toolContentBlocks).size().isEqualTo(1);
159+
ContentBlock toolContentBlock = toolContentBlocks.get(0);
160+
assertThat(toolContentBlock.type()).isEqualTo(ContentBlock.Type.TOOL_USE);
161+
assertThat(toolContentBlock.name()).isEqualTo("getCurrentWeather");
162+
163+
// Check that the message-stop response is also returned
164+
List<ChatCompletionResponse> messageStopEvents = responses.stream()
165+
.filter(r -> r.type().equals(AnthropicApi.EventType.MESSAGE_STOP.name()))
166+
.toList();
167+
assertThat(messageStopEvents).size().isEqualTo(1);
168+
}
169+
109170
@Test
110171
void chatCompletionStreamError() {
111172
AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?")),

0 commit comments

Comments
 (0)