diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 50dac57a6da..7d0ba4e652c 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -473,7 +473,7 @@ else if (message.getMessageType() == MessageType.TOOL) { updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(((FunctionCallingOptions) prompt.getOptions()), FunctionCallingOptions.class, OpenAiChatOptions.class); } - else if (prompt.getOptions() instanceof OpenAiChatOptions) { + else { updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, OpenAiChatOptions.class); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index a8e5c54696d..6b1a4a59548 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -33,6 +33,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; @@ -441,10 +442,11 @@ void multiModalityInputAudio(String modelName) { List.of(new Media(MimeTypeUtils.parseMimeType("audio/mp3"), audioResource))); ChatResponse response = chatModel - .call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build())); + .call(new Prompt(List.of(userMessage), ChatOptionsBuilder.builder().withModel(modelName).build())); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).containsIgnoringCase("hobbits"); + assertThat(response.getMetadata().getModel()).containsIgnoringCase(modelName); } @ParameterizedTest(name = "{0} : {displayName} ")