Skip to content

Commit 893a8df

Browse files
committed
added test to validate bug fix
1 parent fe3db2a commit 893a8df

File tree

3 files changed

+72
-9
lines changed

3 files changed

+72
-9
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.List;
2323
import java.util.Map;
2424
import java.util.Set;
25+
import java.util.concurrent.atomic.AtomicReference;
2526
import java.util.stream.Collectors;
2627

2728
import io.micrometer.observation.Observation;
@@ -280,15 +281,21 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
280281
Flux<ChatCompletionResponse> response = this.anthropicApi.chatCompletionStream(request);
281282

282283
// @formatter:off
284+
AtomicReference<String> toolCallId = new AtomicReference<>("");
283285
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
284286
AnthropicApi.Usage usage = chatCompletionResponse.usage();
285287
Usage currentChatResponseUsage = usage != null ? AnthropicUsage.from(chatCompletionResponse.usage()) : new EmptyUsage();
286288
Usage accumulatedUsage = UsageUtils.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
287289
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
288290

289291
if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) {
290-
var toolCallConversation = handleToolCalls(prompt, chatResponse);
291-
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
292+
var toolCallConversation = prompt.getInstructions();
293+
if (toolCallId.get().equalsIgnoreCase(chatResponse.getMetadata().getId())) {
294+
toolCallConversation = handleToolCalls(prompt, chatResponse);
295+
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
296+
} else {
297+
toolCallId.set(chatResponse.getMetadata().getId());
298+
}
292299
}
293300

294301
return Mono.just(chatResponse);
@@ -493,7 +500,7 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
493500
}).toList();
494501
}
495502

496-
private ChatOptions buildRequestOptions(AnthropicApi.ChatCompletionRequest request) {
503+
private ChatOptions buildRequestOptions(ChatCompletionRequest request) {
497504
return ChatOptions.builder()
498505
.model(request.model())
499506
.maxTokens(request.maxTokens())

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,25 @@
1717
package org.springframework.ai.anthropic;
1818

1919
import java.io.IOException;
20-
import java.util.ArrayList;
21-
import java.util.Arrays;
22-
import java.util.List;
23-
import java.util.Map;
20+
import java.util.*;
2421
import java.util.stream.Collectors;
2522

23+
import ch.qos.logback.classic.Level;
24+
import ch.qos.logback.classic.LoggerContext;
25+
import ch.qos.logback.classic.spi.ILoggingEvent;
26+
import ch.qos.logback.core.Appender;
27+
import ch.qos.logback.core.AppenderBase;
28+
import ch.qos.logback.core.read.ListAppender;
29+
import org.junit.Rule;
2630
import org.junit.jupiter.api.Test;
2731
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
32+
import org.junit.jupiter.api.extension.ExtendWith;
2833
import org.junit.jupiter.params.ParameterizedTest;
2934
import org.junit.jupiter.params.provider.ValueSource;
3035
import org.slf4j.Logger;
3136
import org.slf4j.LoggerFactory;
37+
import org.springframework.boot.test.system.OutputCaptureExtension;
38+
import org.springframework.boot.test.system.OutputCaptureRule;
3239
import reactor.core.publisher.Flux;
3340

3441
import org.springframework.ai.anthropic.api.AnthropicApi;
@@ -336,10 +343,17 @@ void streamFunctionCallUsageTest() {
336343

337344
List<Message> messages = new ArrayList<>(List.of(userMessage));
338345

346+
var mockService = new MockWeatherService();
347+
348+
MemoryAppender appender = new MemoryAppender();
349+
appender.setContext((LoggerContext) LoggerFactory.getILoggerFactory());
350+
MockWeatherService.log.addAppender(appender);
351+
appender.start();
352+
339353
var promptOptions = AnthropicChatOptions.builder()
340354
.model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
341355
.functionCallbacks(List.of(FunctionCallback.builder()
342-
.function("getCurrentWeather", new MockWeatherService())
356+
.function("getCurrentWeather", mockService)
343357
.description(
344358
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
345359
.inputType(MockWeatherService.Request.class)
@@ -352,9 +366,12 @@ void streamFunctionCallUsageTest() {
352366

353367
logger.info("Response: {}", chatResponse);
354368
Usage usage = chatResponse.getMetadata().getUsage();
369+
appender.stop();
355370

356371
assertThat(usage).isNotNull();
372+
assertThat(appender.getLoggedEvents().size()).isEqualTo(3);
357373
assertThat(usage.getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
374+
358375
}
359376

360377
@Test
@@ -417,4 +434,39 @@ public AnthropicChatModel openAiChatModel(AnthropicApi api) {
417434

418435
}
419436

437+
public static class MemoryAppender extends ListAppender<ILoggingEvent> {
438+
439+
public void reset() {
440+
this.list.clear();
441+
}
442+
443+
public boolean contains(String string, Level level) {
444+
return this.list.stream()
445+
.anyMatch(event -> event.toString().contains(string) && event.getLevel().equals(level));
446+
}
447+
448+
public int countEventsForLogger(String loggerName) {
449+
return (int) this.list.stream().filter(event -> event.getLoggerName().contains(loggerName)).count();
450+
}
451+
452+
public List<ILoggingEvent> search(String string) {
453+
return this.list.stream().filter(event -> event.toString().contains(string)).collect(Collectors.toList());
454+
}
455+
456+
public List<ILoggingEvent> search(String string, Level level) {
457+
return this.list.stream()
458+
.filter(event -> event.toString().contains(string) && event.getLevel().equals(level))
459+
.collect(Collectors.toList());
460+
}
461+
462+
public int getSize() {
463+
return this.list.size();
464+
}
465+
466+
public List<ILoggingEvent> getLoggedEvents() {
467+
return Collections.unmodifiableList(this.list);
468+
}
469+
470+
}
471+
420472
}

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/MockWeatherService.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,24 @@
1818

1919
import java.util.function.Function;
2020

21+
import ch.qos.logback.classic.Logger;
2122
import com.fasterxml.jackson.annotation.JsonClassDescription;
2223
import com.fasterxml.jackson.annotation.JsonInclude;
2324
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2425
import com.fasterxml.jackson.annotation.JsonProperty;
2526
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
27+
import org.slf4j.LoggerFactory;
2628

2729
/**
2830
* @author Christian Tzolov
2931
*/
3032
public class MockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> {
3133

34+
public static final Logger log = (Logger) LoggerFactory.getLogger(MockWeatherService.class.getName());
35+
3236
@Override
3337
public Response apply(Request request) {
34-
38+
log.info("Weather Request: {}", request.toString());
3539
double temperature = 0;
3640
if (request.location().contains("Paris")) {
3741
temperature = 15;

0 commit comments

Comments
 (0)