From 0d864aae3583203097ff2b7e7bb89c391c3a35ae Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Tue, 5 Aug 2025 09:28:02 +0800 Subject: [PATCH] refactor: GH-3998 Refactor `org.springframework.ai.ollama.OllamaChatModel#ollamaChatRequest` to support custom implementations of `AbstractMessage` and align with other `ChatModel` implementations. Signed-off-by: Sun Yuhan --- .../ai/ollama/OllamaChatModel.java | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index c6bd6c2676e..4d736a9a3be 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -28,13 +28,10 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.*; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; @@ -441,18 +438,21 @@ Prompt buildRequestPrompt(Prompt prompt) { OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) { List ollamaMessages = prompt.getInstructions().stream().map(message -> { - if (message instanceof UserMessage userMessage) { + if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) { var messageBuilder = OllamaApi.Message.builder(Role.USER).content(message.getText()); - if (!CollectionUtils.isEmpty(userMessage.getMedia())) { - messageBuilder.images( - userMessage.getMedia().stream().map(media -> this.fromMediaData(media.getData())).toList()); + if (message instanceof UserMessage userMessage) { + if (!CollectionUtils.isEmpty(userMessage.getMedia())) { + messageBuilder.images(userMessage.getMedia() + .stream() + .map(media -> this.fromMediaData(media.getData())) + .toList()); + } } + return List.of(messageBuilder.build()); } - else if (message instanceof SystemMessage systemMessage) { - return List.of(OllamaApi.Message.builder(Role.SYSTEM).content(systemMessage.getText()).build()); - } - else if (message instanceof AssistantMessage assistantMessage) { + else if (message.getMessageType() == MessageType.ASSISTANT) { + var assistantMessage = (AssistantMessage) message; List toolCalls = null; if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> { @@ -467,7 +467,8 @@ else if (message instanceof AssistantMessage assistantMessage) { .toolCalls(toolCalls) .build()); } - else if (message instanceof ToolResponseMessage toolMessage) { + else if (message.getMessageType() == MessageType.TOOL) { + ToolResponseMessage toolMessage = (ToolResponseMessage) message; return toolMessage.getResponses() .stream() .map(tr -> OllamaApi.Message.builder(Role.TOOL).content(tr.responseData()).build())