Skip to content

Commit 41eab27

Browse files
tzolov authored and markpollack committed
Refactor High-level API function calling support
- Add ToolCall to AssistantMessage.
- Rename FunctionMessage to ToolResponseMessage and add id and name fields.
- Refactor OpenAiChatModel's function calling handling.
- Prompt copy now copies the AssistantMessage and ToolResponseMessage contents.

Other ChatModel implementations are to adopt these changes in subsequent commits.
1 parent e570ef5 commit 41eab27

File tree

14 files changed

+566
-194
lines changed

14 files changed

+566
-194
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/MessageToPromptConverter.java

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
* Converts a list of messages to a prompt for bedrock models.
2626
*
2727
* @author Christian Tzolov
28-
* @since 0.8.0
28+
* @since 1.0.0
2929
*/
3030
public class MessageToPromptConverter {
3131

@@ -75,10 +75,8 @@ public String toPrompt(List<Message> messages) {
7575
.collect(Collectors.joining(System.lineSeparator()));
7676

7777
// Related to: https://github.com/spring-projects/spring-ai/issues/404
78-
final String prompt = systemMessages + this.lineSeparator + this.lineSeparator + userMessages
79-
+ this.lineSeparator + ASSISTANT_PROMPT;
80-
81-
return prompt;
78+
return systemMessages + this.lineSeparator + this.lineSeparator + userMessages + this.lineSeparator
79+
+ ASSISTANT_PROMPT;
8280
}
8381

8482
protected String messageToString(Message message) {
@@ -89,7 +87,7 @@ protected String messageToString(Message message) {
8987
return humanPrompt + " " + message.getContent();
9088
case ASSISTANT:
9189
return assistantPrompt + " " + message.getContent();
92-
case FUNCTION:
90+
case TOOL:
9391
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
9492
}
9593

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 119 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -15,18 +15,12 @@
1515
*/
1616
package org.springframework.ai.openai;
1717

18-
import java.util.ArrayList;
19-
import java.util.Base64;
20-
import java.util.HashMap;
21-
import java.util.HashSet;
22-
import java.util.List;
23-
import java.util.Map;
24-
import java.util.Optional;
25-
import java.util.Set;
26-
import java.util.concurrent.ConcurrentHashMap;
27-
2818
import org.slf4j.Logger;
2919
import org.slf4j.LoggerFactory;
20+
import org.springframework.ai.chat.messages.AssistantMessage;
21+
import org.springframework.ai.chat.messages.Message;
22+
import org.springframework.ai.chat.messages.MessageType;
23+
import org.springframework.ai.chat.messages.ToolResponseMessage;
3024
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3125
import org.springframework.ai.chat.metadata.RateLimit;
3226
import org.springframework.ai.chat.model.ChatModel;
@@ -36,15 +30,15 @@
3630
import org.springframework.ai.chat.prompt.ChatOptions;
3731
import org.springframework.ai.chat.prompt.Prompt;
3832
import org.springframework.ai.model.ModelOptionsUtils;
39-
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
33+
import org.springframework.ai.model.function.AbstractToolCallSupport;
4034
import org.springframework.ai.model.function.FunctionCallbackContext;
4135
import org.springframework.ai.openai.api.OpenAiApi;
4236
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
4337
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
4438
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionFinishReason;
4539
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
40+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCompletionFunction;
4641
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent;
47-
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
4842
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
4943
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
5044
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
@@ -55,8 +49,16 @@
5549
import org.springframework.util.Assert;
5650
import org.springframework.util.CollectionUtils;
5751
import org.springframework.util.MimeType;
58-
5952
import reactor.core.publisher.Flux;
53+
import reactor.core.publisher.Mono;
54+
55+
import java.util.ArrayList;
56+
import java.util.Base64;
57+
import java.util.HashSet;
58+
import java.util.List;
59+
import java.util.Map;
60+
import java.util.Set;
61+
import java.util.concurrent.ConcurrentHashMap;
6062

6163
/**
6264
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
@@ -77,9 +79,7 @@
7779
* @see StreamingChatModel
7880
* @see OpenAiApi
7981
*/
80-
public class OpenAiChatModel extends
81-
AbstractFunctionCallSupport<ChatCompletionMessage, OpenAiApi.ChatCompletionRequest, ResponseEntity<ChatCompletion>>
82-
implements ChatModel {
82+
public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> implements ChatModel {
8383

8484
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class);
8585

@@ -145,14 +145,25 @@ public ChatResponse call(Prompt prompt) {
145145

146146
return this.retryTemplate.execute(ctx -> {
147147

148-
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
148+
ResponseEntity<ChatCompletion> completionEntity = this.openAiApi.chatCompletionEntity(request);
149149

150150
var chatCompletion = completionEntity.getBody();
151+
151152
if (chatCompletion == null) {
152153
logger.warn("No chat completion returned for prompt: {}", prompt);
153154
return new ChatResponse(List.of());
154155
}
155156

157+
if (isToolFunctionCall(chatCompletion)) {
158+
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
159+
chatCompletion);
160+
// Recursively call the call method with the tool call message
161+
// conversation that contains the call responses.
162+
163+
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
164+
}
165+
166+
// Non function calling.
156167
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
157168

158169
List<Choice> choices = chatCompletion.choices();
@@ -162,7 +173,10 @@ public ChatResponse call(Prompt prompt) {
162173
}
163174

164175
List<Generation> generations = choices.stream().map(choice -> {
165-
var generation = new Generation(choice.message().content(), toMap(chatCompletion.id(), choice));
176+
Map<String, Object> metadata = Map.of("id", chatCompletion.id(), "role",
177+
choice.message().role() != null ? choice.message().role().name() : "", "finishReason",
178+
choice.finishReason() != null ? choice.finishReason().name() : "");
179+
var generation = new Generation(choice.message().content(), metadata);
166180
if (choice.finishReason() != null) {
167181
generation = generation
168182
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
@@ -176,20 +190,6 @@ public ChatResponse call(Prompt prompt) {
176190
});
177191
}
178192

179-
private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
180-
Map<String, Object> map = new HashMap<>();
181-
182-
var message = choice.message();
183-
if (message.role() != null) {
184-
map.put("role", message.role().name());
185-
}
186-
if (choice.finishReason() != null) {
187-
map.put("finishReason", choice.finishReason().name());
188-
}
189-
map.put("id", id);
190-
return map;
191-
}
192-
193193
@Override
194194
public Flux<ChatResponse> stream(Prompt prompt) {
195195

@@ -205,16 +205,23 @@ public Flux<ChatResponse> stream(Prompt prompt) {
205205

206206
// Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
207207
// the function call handling logic.
208-
return completionChunks.map(chunk -> chunkToChatCompletion(chunk))
209-
.switchMap(
210-
cc -> handleFunctionCallOrReturnStream(request, Flux.just(ResponseEntity.of(Optional.of(cc)))))
211-
.map(ResponseEntity::getBody)
212-
.map(chatCompletion -> {
208+
return completionChunks.map(this::chunkToChatCompletion).switchMap(chatCompletion -> {
209+
210+
if (this.isToolFunctionCall(chatCompletion)) {
211+
var toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
212+
chatCompletion);
213+
// Recursively call the stream method with the tool call message
214+
// conversation that contains the call responses.
215+
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
216+
}
217+
218+
// Non function calling.
219+
return Mono.just(chatCompletion).map(chatCompletion2 -> {
213220
try {
214221
@SuppressWarnings("null")
215-
String id = chatCompletion.id();
222+
String id = chatCompletion2.id();
216223

217-
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
224+
List<Generation> generations = chatCompletion2.choices().stream().map(choice -> {
218225
if (choice.message().role() != null) {
219226
roleMap.putIfAbsent(id, choice.message().role().name());
220227
}
@@ -228,8 +235,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
228235
return generation;
229236
}).toList();
230237

231-
if (chatCompletion.usage() != null) {
232-
return new ChatResponse(generations, OpenAiChatResponseMetadata.from(chatCompletion));
238+
if (chatCompletion2.usage() != null) {
239+
return new ChatResponse(generations, OpenAiChatResponseMetadata.from(chatCompletion2));
233240
}
234241
else {
235242
return new ChatResponse(generations);
@@ -241,9 +248,33 @@ public Flux<ChatResponse> stream(Prompt prompt) {
241248
}
242249

243250
});
251+
});
244252
});
245253
}
246254

255+
private List<Message> handleToolCallRequests(List<Message> previousMessages, ChatCompletion chatCompletion) {
256+
257+
ChatCompletionMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion);
258+
259+
List<AssistantMessage.ToolCall> assistantToolCalls = nativeAssistantMessage.toolCalls()
260+
.stream()
261+
.map(toolCall -> new AssistantMessage.ToolCall(toolCall.id(), "function", toolCall.function().name(),
262+
toolCall.function().arguments()))
263+
.toList();
264+
265+
AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.content(), Map.of(),
266+
assistantToolCalls);
267+
268+
List<ToolResponseMessage> toolResponseMessages = this.executeFuncitons(assistantMessage);
269+
270+
// History
271+
List<Message> messages = new ArrayList<>(previousMessages);
272+
messages.add(assistantMessage);
273+
messages.addAll(toolResponseMessages);
274+
275+
return messages;
276+
}
277+
247278
/**
248279
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is set to null.
249280
* @param chunk the ChatCompletionChunk to convert
@@ -252,38 +283,66 @@ public Flux<ChatResponse> stream(Prompt prompt) {
252283
private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionChunk chunk) {
253284
List<Choice> choices = chunk.choices()
254285
.stream()
255-
.map(cc -> new Choice(cc.finishReason(), cc.index(), cc.delta(), cc.logprobs()))
286+
.map(chunkChoice -> new Choice(chunkChoice.finishReason(), chunkChoice.index(), chunkChoice.delta(),
287+
chunkChoice.logprobs()))
256288
.toList();
257289

258290
return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
259291
chunk.systemFingerprint(), "chat.completion", chunk.usage());
260292
}
261293

294+
private ChatCompletionMessage extractAssistantMessage(ChatCompletion chatCompletion) {
295+
return chatCompletion.choices().iterator().next().message();
296+
}
297+
262298
/**
263299
* Accessible for testing.
264300
*/
265301
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
266302

267303
Set<String> functionsForThisRequest = new HashSet<>();
268304

269-
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m -> {
270-
Object content;
271-
if (CollectionUtils.isEmpty(m.getMedia())) {
272-
content = m.getContent();
273-
}
274-
else {
275-
List<MediaContent> contentList = new ArrayList<>(List.of(new MediaContent(m.getContent())));
305+
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
306+
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
307+
Object content;
308+
if (CollectionUtils.isEmpty(message.getMedia())) {
309+
content = message.getContent();
310+
}
311+
else {
312+
List<MediaContent> contentList = new ArrayList<>(List.of(new MediaContent(message.getContent())));
276313

277-
contentList.addAll(m.getMedia()
278-
.stream()
279-
.map(media -> new MediaContent(
280-
new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))))
281-
.toList());
314+
contentList.addAll(message.getMedia()
315+
.stream()
316+
.map(media -> new MediaContent(
317+
new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))))
318+
.toList());
282319

283-
content = contentList;
284-
}
320+
content = contentList;
321+
}
285322

286-
return new ChatCompletionMessage(content, ChatCompletionMessage.Role.valueOf(m.getMessageType().name()));
323+
return new ChatCompletionMessage(content,
324+
ChatCompletionMessage.Role.valueOf(message.getMessageType().name()));
325+
}
326+
else if (message.getMessageType() == MessageType.ASSISTANT) {
327+
var assistantMessage = (AssistantMessage) message;
328+
List<ToolCall> toolCalls = null;
329+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
330+
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
331+
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
332+
return new ToolCall(toolCall.id(), toolCall.type(), function);
333+
}).toList();
334+
}
335+
return new ChatCompletionMessage(assistantMessage.getContent(), ChatCompletionMessage.Role.ASSISTANT,
336+
null, null, toolCalls);
337+
}
338+
else if (message.getMessageType() == MessageType.TOOL) {
339+
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
340+
return new ChatCompletionMessage(toolMessage.getContent(), ChatCompletionMessage.Role.TOOL,
341+
toolMessage.getName(), toolMessage.getId(), null);
342+
}
343+
else {
344+
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
345+
}
287346
}).toList();
288347

289348
ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
@@ -351,66 +410,12 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
351410
}
352411

353412
@Override
354-
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
355-
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
356-
357-
// Every tool-call item requires a separate function call and a response (TOOL)
358-
// message.
359-
for (ToolCall toolCall : responseMessage.toolCalls()) {
360-
361-
var functionName = toolCall.function().name();
362-
String functionArguments = toolCall.function().arguments();
363-
364-
if (!this.functionCallbackRegister.containsKey(functionName)) {
365-
throw new IllegalStateException("No function callback found for function name: " + functionName);
366-
}
367-
368-
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
369-
370-
// Add the function response to the conversation.
371-
conversationHistory
372-
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
373-
}
374-
375-
// Recursively call chatCompletionWithTools until the model doesn't call a
376-
// functions anymore.
377-
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, previousRequest.stream());
378-
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
379-
380-
return newRequest;
381-
}
382-
383-
@Override
384-
protected List<ChatCompletionMessage> doGetUserMessages(ChatCompletionRequest request) {
385-
return request.messages();
386-
}
387-
388-
@Override
389-
protected ChatCompletionMessage doGetToolResponseMessage(ResponseEntity<ChatCompletion> chatCompletion) {
390-
return chatCompletion.getBody().choices().iterator().next().message();
391-
}
392-
393-
@Override
394-
protected ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
395-
return this.openAiApi.chatCompletionEntity(request);
396-
}
397-
398-
@Override
399-
protected Flux<ResponseEntity<ChatCompletion>> doChatCompletionStream(ChatCompletionRequest request) {
400-
return this.openAiApi.chatCompletionStream(request)
401-
.map(this::chunkToChatCompletion)
402-
.map(Optional::ofNullable)
403-
.map(ResponseEntity::of);
404-
}
405-
406-
@Override
407-
protected boolean isToolFunctionCall(ResponseEntity<ChatCompletion> chatCompletion) {
408-
var body = chatCompletion.getBody();
409-
if (body == null) {
413+
protected boolean isToolFunctionCall(ChatCompletion chatCompletion) {
414+
if (chatCompletion == null) {
410415
return false;
411416
}
412417

413-
var choices = body.choices();
418+
var choices = chatCompletion.choices();
414419
if (CollectionUtils.isEmpty(choices)) {
415420
return false;
416421
}

0 commit comments

Comments
 (0)