diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 58a0062eccd..1b3b776fdc4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -635,6 +635,10 @@ private MediaContent mapToMediaContent(Media media) { return new MediaContent( new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.WAV)); } + if (MimeTypeUtils.parseMimeType("application/pdf").equals(mimeType)) { + return new MediaContent(new MediaContent.InputFile(media.getName(), + this.fromMediaData(media.getMimeType(), media.getData()))); + } else { return new MediaContent( new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 2b11927f908..d72534b4654 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -1367,14 +1367,15 @@ public record MediaContent(// @formatter:off @JsonProperty("type") String type, @JsonProperty("text") String text, @JsonProperty("image_url") ImageUrl imageUrl, - @JsonProperty("input_audio") InputAudio inputAudio) { // @formatter:on + @JsonProperty("input_audio") InputAudio inputAudio, + @JsonProperty("file") InputFile inputFile) { // @formatter:on /** * Shortcut constructor for a text content. * @param text The text content of the message. 
*/ public MediaContent(String text) { - this("text", text, null, null); + this("text", text, null, null, null); } /** @@ -1382,7 +1383,7 @@ public MediaContent(String text) { * @param imageUrl The image content of the message. */ public MediaContent(ImageUrl imageUrl) { - this("image_url", null, imageUrl, null); + this("image_url", null, imageUrl, null, null); } /** @@ -1390,7 +1391,15 @@ public MediaContent(ImageUrl imageUrl) { * @param inputAudio The audio content of the message. */ public MediaContent(InputAudio inputAudio) { - this("input_audio", null, null, inputAudio); + this("input_audio", null, null, inputAudio, null); + } + + /** + * Shortcut constructor for a file content. + * @param inputFile The file content of the message. + */ + public MediaContent(InputFile inputFile) { + this("file", null, null, null, inputFile); } /** @@ -1428,6 +1437,18 @@ public ImageUrl(String url) { } + /** + * File content carrying a base64-encoded file. + * + * @param filename name of the file + * @param fileData file data with format + * "data:{mimetype};base64,{base64-encoded-file-data}". 
+ */ + public record InputFile(@JsonProperty("filename") String filename, + @JsonProperty("file_data") String fileData) { + + } + } /** diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java index f24b71f62d1..f1a91dc6060 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java @@ -18,6 +18,7 @@ import java.net.MalformedURLException; import java.net.URI; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; @@ -29,6 +30,8 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.util.MimeType; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.SystemMessage; @@ -126,11 +129,8 @@ public void userMessageWithMediaType() throws MalformedURLException { given(this.openAiApi.chatCompletionEntity(this.pomptCaptor.capture(), this.headersCaptor.capture())) .willReturn(Mockito.mock(ResponseEntity.class)); - URI mediaUri = URI.create("http://test"); - this.chatModel.call(new Prompt(List.of(UserMessage.builder() - .text("test message") - .media(List.of(Media.builder().mimeType(MimeTypeUtils.IMAGE_JPEG).data(mediaUri).build())) - .build()))); + this.chatModel + .call(new Prompt(List.of(UserMessage.builder().text("test message").media(this.buildMediaList()).build()))); validateComplexContent(this.pomptCaptor.getValue()); } @@ -141,11 +141,10 @@ public void streamUserMessageWithMediaType() throws MalformedURLException { given(this.openAiApi.chatCompletionStream(this.pomptCaptor.capture(), this.headersCaptor.capture())) .willReturn(this.fluxResponse); - URI mediaUrl 
= URI.create("http://test"); - this.chatModel.stream(new Prompt(List.of(UserMessage.builder() - .text("test message") - .media(List.of(Media.builder().mimeType(MimeTypeUtils.IMAGE_JPEG).data(mediaUrl).build())) - .build()))).subscribe(); + this.chatModel + .stream(new Prompt( + List.of(UserMessage.builder().text("test message").media(this.buildMediaList()).build()))) + .subscribe(); validateComplexContent(this.pomptCaptor.getValue()); } @@ -161,16 +160,40 @@ private void validateComplexContent(ChatCompletionRequest chatCompletionRequest) @SuppressWarnings({ "unused", "unchecked" }) List<Map<String, Object>> mediaContents = (List<Map<String, Object>>) userMessage.rawContent(); - assertThat(mediaContents).hasSize(2); + assertThat(mediaContents).hasSize(3); + // Assert text content Map<String, Object> textContent = mediaContents.get(0); assertThat(textContent.get("type")).isEqualTo("text"); assertThat(textContent.get("text")).isEqualTo("test message"); + // Assert image content Map<String, Object> imageContent = mediaContents.get(1); assertThat(imageContent.get("type")).isEqualTo("image_url"); assertThat(imageContent).containsKey("image_url"); + + // Assert file content + Map<String, Object> fileContent = mediaContents.get(2); + assertThat(fileContent.get("type")).isEqualTo("file"); + assertThat(fileContent).containsKey("file"); + assertThat(fileContent.get("file")).isInstanceOf(Map.class); + + Map<String, Object> fileMap = (Map<String, Object>) fileContent.get("file"); + assertThat(fileMap.get("file_data")).isEqualTo("data:application/pdf;base64,JVBERi0xLjc="); + } + + private List<Media> buildMediaList() { + URI imageUri = URI.create("http://test"); + Media imageMedia = Media.builder().mimeType(MimeTypeUtils.IMAGE_JPEG).data(imageUri).build(); + + byte[] pdfData = "%PDF-1.7".getBytes(StandardCharsets.UTF_8); + Media pdfMedia = Media.builder() + .mimeType(MimeType.valueOf("application/pdf")) + .data(new ByteArrayResource(pdfData)) + .build(); + + return List.of(imageMedia, pdfMedia); } }