From 8bad84871c7752d16949d7c5892513cebc18c845 Mon Sep 17 00:00:00 2001 From: TarasVovk669 Date: Wed, 7 Aug 2024 11:55:31 +0200 Subject: [PATCH] fix refusal field in ChatCompletion model --- .../java/org/springframework/ai/openai/OpenAiChatModel.java | 4 ++-- .../java/org/springframework/ai/openai/api/OpenAiApi.java | 5 +++-- .../ai/openai/api/OpenAiStreamFunctionCallingHelper.java | 3 ++- .../ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index f798323d82d..30d5257ab95 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -424,7 +424,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { }).toList(); } return List.of(new ChatCompletionMessage(assistantMessage.getContent(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls)); + ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; @@ -437,7 +437,7 @@ else if (message.getMessageType() == MessageType.TOOL) { return toolMessage.getResponses() .stream() .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), - tr.id(), null)) + tr.id(), null, null)) .toList(); } else { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index c71cb02a5e6..f946c98d5f5 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -560,7 +560,8 @@ public record ChatCompletionMessage(// @formatter:off @JsonProperty("role") Role role, @JsonProperty("name") String name, @JsonProperty("tool_call_id") String toolCallId, - @JsonProperty("tool_calls") List toolCalls) {// @formatter:on + @JsonProperty("tool_calls") List toolCalls, + @JsonProperty("refusal") String refusal) {// @formatter:on /** * Get message content as String. @@ -582,7 +583,7 @@ public String content() { * @param role The role of the author of this message. */ public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null); + this(content, role, null, null, null, null); } /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java index b0e23ce3621..02bfd310800 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java @@ -91,6 +91,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null String name = (current.name() != null ? current.name() : previous.name()); String toolCallId = (current.toolCallId() != null ? current.toolCallId() : previous.toolCallId()); + String refusal = (current.refusal() != null ? current.refusal() : previous.refusal()); List toolCalls = new ArrayList<>(); ToolCall lastPreviousTooCall = null; @@ -120,7 +121,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti toolCalls.add(lastPreviousTooCall); } } - return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls); + return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal); } private ToolCall merge(ToolCall previous, ToolCall current) { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index 8b224102427..1753b6eacaa 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -122,7 +122,7 @@ public void toolFunctionCall() { // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), - Role.TOOL, functionName, toolCall.id(), null)); + Role.TOOL, functionName, toolCall.id(), null, null)); } }