Skip to content

Commit bdb66e5

Browse files
committed
OpenAI - Support audio input modality
* Extend OpenAiApi to support the latest version of the Chat Completion API, including input and output audio modality.
* Support input audio modality in OpenAiChatModel via the existing multimodality support in Spring AI.

Fixes gh-1560
1 parent 8eef6e6 commit bdb66e5

File tree

11 files changed

+368
-59
lines changed

11 files changed

+368
-59
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 29 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@
4848
import org.springframework.ai.chat.prompt.ChatOptions;
4949
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
5050
import org.springframework.ai.chat.prompt.Prompt;
51+
import org.springframework.ai.model.Media;
5152
import org.springframework.ai.model.ModelOptionsUtils;
5253
import org.springframework.ai.model.function.FunctionCallback;
5354
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -69,6 +70,7 @@
6970
import org.springframework.util.Assert;
7071
import org.springframework.util.CollectionUtils;
7172
import org.springframework.util.MimeType;
73+
import org.springframework.util.MimeTypeUtils;
7274
import org.springframework.util.MultiValueMap;
7375
import org.springframework.util.StringUtils;
7476

@@ -408,7 +410,7 @@ private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionC
408410
chunkChoice.logprobs()))
409411
.toList();
410412

411-
return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
413+
return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(),
412414
chunk.systemFingerprint(), "chat.completion", chunk.usage());
413415
}
414416

@@ -425,11 +427,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
425427
List<MediaContent> contentList = new ArrayList<>(
426428
List.of(new MediaContent(message.getContent())));
427429

428-
contentList.addAll(userMessage.getMedia()
429-
.stream()
430-
.map(media -> new MediaContent(new MediaContent.ImageUrl(
431-
this.fromMediaData(media.getMimeType(), media.getData()))))
432-
.toList());
430+
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
433431

434432
content = contentList;
435433
}
@@ -448,7 +446,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
448446
}).toList();
449447
}
450448
return List.of(new ChatCompletionMessage(assistantMessage.getContent(),
451-
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null));
449+
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, null));
452450
}
453451
else if (message.getMessageType() == MessageType.TOOL) {
454452
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -460,7 +458,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
460458
return toolMessage.getResponses()
461459
.stream()
462460
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
463-
tr.id(), null, null))
461+
tr.id(), null, null, null))
464462
.toList();
465463
}
466464
else {
@@ -512,6 +510,29 @@ else if (prompt.getOptions() instanceof OpenAiChatOptions) {
512510
return request;
513511
}
514512

513+
private MediaContent mapToMediaContent(Media media) {
514+
var mimeType = media.getMimeType();
515+
if (MimeTypeUtils.parseMimeType("audio/mp3").equals(mimeType)) {
516+
return new MediaContent(
517+
new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.MP3));
518+
}
519+
if (MimeTypeUtils.parseMimeType("audio/wav").equals(mimeType)) {
520+
return new MediaContent(
521+
new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.WAV));
522+
}
523+
else {
524+
return new MediaContent(
525+
new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData())));
526+
}
527+
}
528+
529+
private String fromAudioData(Object audioData) {
530+
if (audioData instanceof byte[] bytes) {
531+
return Base64.getEncoder().encodeToString(bytes);
532+
}
533+
throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName());
534+
}
535+
515536
private String fromMediaData(MimeType mimeType, Object mediaContentData) {
516537
if (mediaContentData instanceof byte[] bytes) {
517538
// Assume the bytes are an image. So, convert the bytes to a base64 encoded

0 commit comments

Comments
 (0)