Skip to content

Commit 6168d3a

Browse files
jharnack2realJ3H
authored andcommitted
feat(openai): add support for pdf files as media
Signed-off-by: Jan-Eric Harnack <[email protected]>
1 parent 81137ca commit 6168d3a

File tree

3 files changed

+63
-15
lines changed

3 files changed

+63
-15
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,10 @@ private MediaContent mapToMediaContent(Media media) {
635635
return new MediaContent(
636636
new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.WAV));
637637
}
638+
if (MimeTypeUtils.parseMimeType("application/pdf").equals(mimeType)) {
639+
return new MediaContent(new MediaContent.InputFile(media.getName(),
640+
this.fromMediaData(media.getMimeType(), media.getData())));
641+
}
638642
else {
639643
return new MediaContent(
640644
new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData())));

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,30 +1367,39 @@ public record MediaContent(// @formatter:off
13671367
@JsonProperty("type") String type,
13681368
@JsonProperty("text") String text,
13691369
@JsonProperty("image_url") ImageUrl imageUrl,
1370-
@JsonProperty("input_audio") InputAudio inputAudio) { // @formatter:on
1370+
@JsonProperty("input_audio") InputAudio inputAudio,
1371+
@JsonProperty("file") InputFile inputFile) { // @formatter:on
13711372

13721373
/**
13731374
* Shortcut constructor for a text content.
13741375
* @param text The text content of the message.
13751376
*/
13761377
public MediaContent(String text) {
1377-
this("text", text, null, null);
1378+
this("text", text, null, null, null);
13781379
}
13791380

13801381
/**
13811382
* Shortcut constructor for an image content.
13821383
* @param imageUrl The image content of the message.
13831384
*/
13841385
public MediaContent(ImageUrl imageUrl) {
1385-
this("image_url", null, imageUrl, null);
1386+
this("image_url", null, imageUrl, null, null);
13861387
}
13871388

13881389
/**
13891390
* Shortcut constructor for an audio content.
13901391
* @param inputAudio The audio content of the message.
13911392
*/
13921393
public MediaContent(InputAudio inputAudio) {
1393-
this("input_audio", null, null, inputAudio);
1394+
this("input_audio", null, null, inputAudio, null);
1395+
}
1396+
1397+
/**
1398+
* Shortcut constructor for a file content
1399+
* @param inputFile The file content of the message.
1400+
*/
1401+
public MediaContent(InputFile inputFile) {
1402+
this("file", null, null, null, inputFile);
13941403
}
13951404

13961405
/**
@@ -1428,6 +1437,18 @@ public ImageUrl(String url) {
14281437

14291438
}
14301439

1440+
/**
1441+
* Constructor for base64-encoded file
1442+
*
1443+
* @param filename name of the file
1444+
* @param fileData file data with format
1445+
* "data:{mimetype};base64,{base64-encoded-image-data}".
1446+
*/
1447+
public record InputFile(@JsonProperty("filename") String filename,
1448+
@JsonProperty("file_data") String fileData) {
1449+
1450+
}
1451+
14311452
}
14321453

14331454
/**

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.net.MalformedURLException;
2020
import java.net.URI;
21+
import java.nio.charset.StandardCharsets;
2122
import java.util.List;
2223
import java.util.Map;
2324

@@ -29,6 +30,8 @@
2930
import org.mockito.Mock;
3031
import org.mockito.Mockito;
3132
import org.mockito.junit.jupiter.MockitoExtension;
33+
import org.springframework.core.io.ByteArrayResource;
34+
import org.springframework.util.MimeType;
3235
import reactor.core.publisher.Flux;
3336

3437
import org.springframework.ai.chat.messages.SystemMessage;
@@ -126,11 +129,8 @@ public void userMessageWithMediaType() throws MalformedURLException {
126129
given(this.openAiApi.chatCompletionEntity(this.pomptCaptor.capture(), this.headersCaptor.capture()))
127130
.willReturn(Mockito.mock(ResponseEntity.class));
128131

129-
URI mediaUri = URI.create("http://test");
130-
this.chatModel.call(new Prompt(List.of(UserMessage.builder()
131-
.text("test message")
132-
.media(List.of(Media.builder().mimeType(MimeTypeUtils.IMAGE_JPEG).data(mediaUri).build()))
133-
.build())));
132+
this.chatModel
133+
.call(new Prompt(List.of(UserMessage.builder().text("test message").media(this.buildMediaList()).build())));
134134

135135
validateComplexContent(this.pomptCaptor.getValue());
136136
}
@@ -141,11 +141,10 @@ public void streamUserMessageWithMediaType() throws MalformedURLException {
141141
given(this.openAiApi.chatCompletionStream(this.pomptCaptor.capture(), this.headersCaptor.capture()))
142142
.willReturn(this.fluxResponse);
143143

144-
URI mediaUrl = URI.create("http://test");
145-
this.chatModel.stream(new Prompt(List.of(UserMessage.builder()
146-
.text("test message")
147-
.media(List.of(Media.builder().mimeType(MimeTypeUtils.IMAGE_JPEG).data(mediaUrl).build()))
148-
.build()))).subscribe();
144+
this.chatModel
145+
.stream(new Prompt(
146+
List.of(UserMessage.builder().text("test message").media(this.buildMediaList()).build())))
147+
.subscribe();
149148

150149
validateComplexContent(this.pomptCaptor.getValue());
151150
}
@@ -161,16 +160,40 @@ private void validateComplexContent(ChatCompletionRequest chatCompletionRequest)
161160
@SuppressWarnings({ "unused", "unchecked" })
162161
List<Map<String, Object>> mediaContents = (List<Map<String, Object>>) userMessage.rawContent();
163162

164-
assertThat(mediaContents).hasSize(2);
163+
assertThat(mediaContents).hasSize(3);
165164

165+
// Assert text content
166166
Map<String, Object> textContent = mediaContents.get(0);
167167
assertThat(textContent.get("type")).isEqualTo("text");
168168
assertThat(textContent.get("text")).isEqualTo("test message");
169169

170+
// Assert image content
170171
Map<String, Object> imageContent = mediaContents.get(1);
171172

172173
assertThat(imageContent.get("type")).isEqualTo("image_url");
173174
assertThat(imageContent).containsKey("image_url");
175+
176+
// Assert file content
177+
Map<String, Object> fileContent = mediaContents.get(2);
178+
assertThat(fileContent.get("type")).isEqualTo("file");
179+
assertThat(fileContent).containsKey("file");
180+
assertThat(fileContent.get("file")).isInstanceOf(Map.class);
181+
182+
Map<String, Object> fileMap = (Map<String, Object>) fileContent.get("file");
183+
assertThat(fileMap.get("file_data")).isEqualTo("data:application/pdf;base64,JVBERi0xLjc=");
184+
}
185+
186+
private List<Media> buildMediaList() {
187+
URI imageUri = URI.create("http://test");
188+
Media imageMedia = Media.builder().mimeType(MimeTypeUtils.IMAGE_JPEG).data(imageUri).build();
189+
190+
byte[] pdfData = "%PDF-1.7".getBytes(StandardCharsets.UTF_8);
191+
Media pdfMedia = Media.builder()
192+
.mimeType(MimeType.valueOf("application/pdf"))
193+
.data(new ByteArrayResource(pdfData))
194+
.build();
195+
196+
return List.of(imageMedia, pdfMedia);
174197
}
175198

176199
}

0 commit comments

Comments
 (0)