2020import java .util .HashSet ;
2121import java .util .List ;
2222import java .util .Map ;
23- import java .util .Optional ;
2423import java .util .Set ;
2524import java .util .stream .Collectors ;
2625
3433import org .springframework .ai .anthropic .api .AnthropicApi .ContentBlock .ContentBlockType ;
3534import org .springframework .ai .anthropic .api .AnthropicApi .Role ;
3635import org .springframework .ai .anthropic .metadata .AnthropicChatResponseMetadata ;
36+ import org .springframework .ai .chat .messages .AssistantMessage ;
37+ import org .springframework .ai .chat .messages .Message ;
3738import org .springframework .ai .chat .messages .MessageType ;
39+ import org .springframework .ai .chat .messages .ToolResponseMessage ;
3840import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
3941import org .springframework .ai .chat .model .ChatModel ;
4042import org .springframework .ai .chat .model .ChatResponse ;
4143import org .springframework .ai .chat .model .Generation ;
4244import org .springframework .ai .chat .prompt .ChatOptions ;
4345import org .springframework .ai .chat .prompt .Prompt ;
4446import org .springframework .ai .model .ModelOptionsUtils ;
45- import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
47+ import org .springframework .ai .model .function .AbstractToolCallSupport ;
4648import org .springframework .ai .model .function .FunctionCallbackContext ;
4749import org .springframework .ai .retry .RetryUtils ;
4850import org .springframework .http .ResponseEntity ;
4951import org .springframework .retry .support .RetryTemplate ;
5052import org .springframework .util .Assert ;
5153import org .springframework .util .CollectionUtils ;
54+ import org .springframework .util .StringUtils ;
5255
5356import reactor .core .publisher .Flux ;
57+ import reactor .core .publisher .Mono ;
5458
5559/**
5660 * The {@link ChatModel} implementation for the Anthropic service.
6064 * @author Mariusz Bernacki
6165 * @since 1.0.0
6266 */
63- public class AnthropicChatModel extends
64- AbstractFunctionCallSupport <AnthropicApi .AnthropicMessage , AnthropicApi .ChatCompletionRequest , ResponseEntity <AnthropicApi .ChatCompletionResponse >>
65- implements ChatModel {
67+ public class AnthropicChatModel extends AbstractToolCallSupport <ChatCompletionResponse > implements ChatModel {
6668
6769 private static final Logger logger = LoggerFactory .getLogger (AnthropicChatModel .class );
6870
69- public static final String DEFAULT_MODEL_NAME = AnthropicApi .ChatModel .CLAUDE_3_OPUS .getValue ();
71+ public static final String DEFAULT_MODEL_NAME = AnthropicApi .ChatModel .CLAUDE_3_5_SONNET .getValue ();
7072
7173 public static final Integer DEFAULT_MAX_TOKENS = 500 ;
7274
@@ -148,7 +150,14 @@ public ChatResponse call(Prompt prompt) {
148150 ChatCompletionRequest request = createRequest (prompt , false );
149151
150152 return this .retryTemplate .execute (ctx -> {
151- ResponseEntity <ChatCompletionResponse > completionEntity = this .callWithFunctionSupport (request );
153+ ResponseEntity <ChatCompletionResponse > completionEntity = this .anthropicApi .chatCompletionEntity (request );
154+
155+ if (this .isToolFunctionCall (completionEntity .getBody ())) {
156+ List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
157+ completionEntity .getBody ());
158+ return this .call (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
159+ }
160+
152161 return toChatResponse (completionEntity .getBody ());
153162 });
154163 }
@@ -162,14 +171,52 @@ public Flux<ChatResponse> stream(Prompt prompt) {
162171
163172 Flux <ChatCompletionResponse > response = this .anthropicApi .chatCompletionStream (request );
164173
165- return response
166- .switchMap (chatCompletionResponse -> handleFunctionCallOrReturnStream (request ,
167- Flux .just (ResponseEntity .of (Optional .of (chatCompletionResponse )))))
168- .map (ResponseEntity ::getBody )
169- .map (this ::toChatResponse );
174+ return response .switchMap (chatCompletionResponse -> {
175+
176+ if (this .isToolFunctionCall (chatCompletionResponse )) {
177+ List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
178+ chatCompletionResponse );
179+ return this .stream (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
180+ }
181+
182+ return Mono .just (chatCompletionResponse ).map (this ::toChatResponse );
183+ });
170184 });
171185 }
172186
187+ private List <Message > handleToolCallRequests (List <Message > previousMessages ,
188+ ChatCompletionResponse chatCompletionResponse ) {
189+
190+ AnthropicMessage anthropicAssistantMessage = new AnthropicMessage (chatCompletionResponse .content (),
191+ Role .ASSISTANT );
192+
193+ List <ContentBlock > toolToUseList = anthropicAssistantMessage .content ()
194+ .stream ()
195+ .filter (c -> c .type () == ContentBlock .ContentBlockType .TOOL_USE )
196+ .toList ();
197+
198+ List <AssistantMessage .ToolCall > toolCalls = new ArrayList <>();
199+
200+ for (ContentBlock toolToUse : toolToUseList ) {
201+
202+ var functionCallId = toolToUse .id ();
203+ var functionName = toolToUse .name ();
204+ var functionArguments = ModelOptionsUtils .toJsonString (toolToUse .input ());
205+
206+ toolCalls .add (new AssistantMessage .ToolCall (functionCallId , "function" , functionName , functionArguments ));
207+ }
208+
209+ AssistantMessage assistantMessage = new AssistantMessage ("" , Map .of (), toolCalls );
210+ ToolResponseMessage toolResponseMessage = this .executeFuncitons (assistantMessage );
211+
212+ // History
213+ List <Message > toolCallMessageConversation = new ArrayList <>(previousMessages );
214+ toolCallMessageConversation .add (assistantMessage );
215+ toolCallMessageConversation .add (toolResponseMessage );
216+
217+ return toolCallMessageConversation ;
218+ }
219+
173220 private ChatResponse toChatResponse (ChatCompletionResponse chatCompletion ) {
174221 if (chatCompletion == null ) {
175222 logger .warn ("Null chat completion returned" );
@@ -203,18 +250,45 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
203250
204251 List <AnthropicMessage > userMessages = prompt .getInstructions ()
205252 .stream ()
206- .filter (m -> m .getMessageType () != MessageType .SYSTEM )
207- .map (m -> {
208- List <ContentBlock > contents = new ArrayList <>(List .of (new ContentBlock (m .getContent ())));
209- if (!CollectionUtils .isEmpty (m .getMedia ())) {
210- List <ContentBlock > mediaContent = m .getMedia ()
253+ .filter (message -> message .getMessageType () != MessageType .SYSTEM )
254+ .map (message -> {
255+ if (message .getMessageType () == MessageType .USER ) {
256+ List <ContentBlock > contents = new ArrayList <>(List .of (new ContentBlock (message .getContent ())));
257+ if (!CollectionUtils .isEmpty (message .getMedia ())) {
258+ List <ContentBlock > mediaContent = message .getMedia ()
259+ .stream ()
260+ .map (media -> new ContentBlock (media .getMimeType ().toString (),
261+ this .fromMediaData (media .getData ())))
262+ .toList ();
263+ contents .addAll (mediaContent );
264+ }
265+ return new AnthropicMessage (contents , Role .valueOf (message .getMessageType ().name ()));
266+ }
267+ else if (message .getMessageType () == MessageType .ASSISTANT ) {
268+ AssistantMessage assistantMessage = (AssistantMessage ) message ;
269+ List <ContentBlock > contentBlocks = new ArrayList <>();
270+ if (StringUtils .hasText (message .getContent ())) {
271+ contentBlocks .add (new ContentBlock (message .getContent ()));
272+ }
273+ if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
274+ for (AssistantMessage .ToolCall toolCall : assistantMessage .getToolCalls ()) {
275+ contentBlocks .add (new ContentBlock (ContentBlockType .TOOL_USE , toolCall .id (),
276+ toolCall .name (), ModelOptionsUtils .jsonToMap (toolCall .arguments ())));
277+ }
278+ }
279+ return new AnthropicMessage (contentBlocks , Role .ASSISTANT );
280+ }
281+ else if (message .getMessageType () == MessageType .TOOL ) {
282+ List <ContentBlock > toolResponses = ((ToolResponseMessage ) message ).getResponses ()
211283 .stream ()
212- .map (media -> new ContentBlock (media . getMimeType (). toString (),
213- this . fromMediaData ( media . getData () )))
284+ .map (toolResponse -> new ContentBlock (ContentBlockType . TOOL_RESULT , toolResponse . id (),
285+ toolResponse . responseData ( )))
214286 .toList ();
215- contents .addAll (mediaContent );
287+ return new AnthropicMessage (toolResponses , Role .USER );
288+ }
289+ else {
290+ throw new IllegalArgumentException ("Unsupported message type: " + message .getMessageType ());
216291 }
217- return new AnthropicMessage (contents , Role .valueOf (m .getMessageType ().name ()));
218292 })
219293 .toList ();
220294
@@ -265,74 +339,17 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
265339 }).toList ();
266340 }
267341
268- @ Override
269- protected ChatCompletionRequest doCreateToolResponseRequest (ChatCompletionRequest previousRequest ,
270- AnthropicMessage responseMessage , List <AnthropicMessage > conversationHistory ) {
271-
272- List <ContentBlock > toolToUseList = responseMessage .content ()
273- .stream ()
274- .filter (c -> c .type () == ContentBlock .ContentBlockType .TOOL_USE )
275- .toList ();
276-
277- List <ContentBlock > toolResults = new ArrayList <>();
278-
279- for (ContentBlock toolToUse : toolToUseList ) {
280-
281- var functionCallId = toolToUse .id ();
282- var functionName = toolToUse .name ();
283- var functionArguments = toolToUse .input ();
284-
285- if (!this .functionCallbackRegister .containsKey (functionName )) {
286- throw new IllegalStateException ("No function callback found for function name: " + functionName );
287- }
288-
289- String functionResponse = this .functionCallbackRegister .get (functionName )
290- .call (ModelOptionsUtils .toJsonString (functionArguments ));
291-
292- toolResults .add (new ContentBlock (ContentBlockType .TOOL_RESULT , functionCallId , functionResponse ));
293- }
294-
295- // Add the function response to the conversation.
296- conversationHistory .add (new AnthropicMessage (toolResults , Role .USER ));
297-
298- // Recursively call chatCompletionWithTools until the model doesn't call a
299- // functions anymore.
300- return ChatCompletionRequest .from (previousRequest ).withMessages (conversationHistory ).build ();
301- }
302-
303- @ Override
304- protected List <AnthropicMessage > doGetUserMessages (ChatCompletionRequest request ) {
305- return request .messages ();
306- }
307-
308- @ Override
309- protected AnthropicMessage doGetToolResponseMessage (ResponseEntity <ChatCompletionResponse > response ) {
310- return new AnthropicMessage (response .getBody ().content (), Role .ASSISTANT );
311- }
312-
313- @ Override
314- protected ResponseEntity <ChatCompletionResponse > doChatCompletion (ChatCompletionRequest request ) {
315- return this .anthropicApi .chatCompletionEntity (request );
316- }
317-
318342 @ SuppressWarnings ("null" )
319343 @ Override
320- protected boolean isToolFunctionCall (ResponseEntity < ChatCompletionResponse > response ) {
321- if (response == null || response . getBody () == null || CollectionUtils .isEmpty (response . getBody () .content ())) {
344+ protected boolean isToolFunctionCall (ChatCompletionResponse response ) {
345+ if (response == null || CollectionUtils .isEmpty (response .content ())) {
322346 return false ;
323347 }
324- return response .getBody ()
325- .content ()
348+ return response .content ()
326349 .stream ()
327350 .anyMatch (content -> content .type () == ContentBlock .ContentBlockType .TOOL_USE );
328351 }
329352
330- @ Override
331- protected Flux <ResponseEntity <ChatCompletionResponse >> doChatCompletionStream (ChatCompletionRequest request ) {
332-
333- return this .anthropicApi .chatCompletionStream (request ).map (Optional ::ofNullable ).map (ResponseEntity ::of );
334- }
335-
336353 @ Override
337354 public ChatOptions getDefaultOptions () {
338355 return AnthropicChatOptions .fromOptions (this .defaultOptions );
0 commit comments