Skip to content

Commit 33f431e

Browse files
ThomasVitalet authored and zolov committed
feat(openai) - Support audio input and output modality
- Added support for audio input and output in OpenAI Chat Completion API - Introduced new audio-related parameters, enums, and record types - Updated ChatCompletionMessage, ChatCompletionChunk, and related classes - Added new AudioParameters, AudioOutput, and InputAudio record types - Implemented method to handle audio media content conversion - Included new model enum for GPT-4o audio preview - Extended existing API classes to accommodate audio modalities - Modified usage tracking and metadata classes to handle audio-specific token details - Improved ModelOptionsUtils with additional JSON utility methods Tests: - Updated test classes to validate audio input and output functionality - Added integration tests for multimodal audio input with streaming and non-streaming methods - Created parameterized tests for audio-enabled models - Enhanced OpenAI API integration tests to cover audio-related scenarios Docs: - Updated documentation in spring-ai-docs to explain audio multimodality support Resolves #1560
1 parent be0f9fb commit 33f431e

File tree

13 files changed

+403
-72
lines changed

13 files changed

+403
-72
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.springframework.ai.chat.prompt.ChatOptions;
5656
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
5757
import org.springframework.ai.chat.prompt.Prompt;
58+
import org.springframework.ai.model.Media;
5859
import org.springframework.ai.model.ModelOptionsUtils;
5960
import org.springframework.ai.model.function.FunctionCallback;
6061
import org.springframework.ai.model.function.FunctionCallbackResolver;
@@ -76,6 +77,7 @@
7677
import org.springframework.util.Assert;
7778
import org.springframework.util.CollectionUtils;
7879
import org.springframework.util.MimeType;
80+
import org.springframework.util.MimeTypeUtils;
7981
import org.springframework.util.MultiValueMap;
8082
import org.springframework.util.StringUtils;
8183

@@ -406,7 +408,7 @@ private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionC
406408
chunkChoice.logprobs()))
407409
.toList();
408410

409-
return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
411+
return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.serviceTier(),
410412
chunk.systemFingerprint(), "chat.completion", chunk.usage());
411413
}
412414

@@ -423,11 +425,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
423425
List<MediaContent> contentList = new ArrayList<>(
424426
List.of(new MediaContent(message.getContent())));
425427

426-
contentList.addAll(userMessage.getMedia()
427-
.stream()
428-
.map(media -> new MediaContent(new MediaContent.ImageUrl(
429-
this.fromMediaData(media.getMimeType(), media.getData()))))
430-
.toList());
428+
contentList.addAll(userMessage.getMedia().stream().map(this::mapToMediaContent).toList());
431429

432430
content = contentList;
433431
}
@@ -446,7 +444,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
446444
}).toList();
447445
}
448446
return List.of(new ChatCompletionMessage(assistantMessage.getContent(),
449-
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null));
447+
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, null));
450448
}
451449
else if (message.getMessageType() == MessageType.TOOL) {
452450
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -456,7 +454,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
456454
return toolMessage.getResponses()
457455
.stream()
458456
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
459-
tr.id(), null, null))
457+
tr.id(), null, null, null))
460458
.toList();
461459
}
462460
else {
@@ -508,6 +506,29 @@ else if (prompt.getOptions() instanceof OpenAiChatOptions) {
508506
return request;
509507
}
510508

509+
private MediaContent mapToMediaContent(Media media) {
510+
var mimeType = media.getMimeType();
511+
if (MimeTypeUtils.parseMimeType("audio/mp3").equals(mimeType)) {
512+
return new MediaContent(
513+
new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.MP3));
514+
}
515+
if (MimeTypeUtils.parseMimeType("audio/wav").equals(mimeType)) {
516+
return new MediaContent(
517+
new MediaContent.InputAudio(fromAudioData(media.getData()), MediaContent.InputAudio.Format.WAV));
518+
}
519+
else {
520+
return new MediaContent(
521+
new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData())));
522+
}
523+
}
524+
525+
private String fromAudioData(Object audioData) {
526+
if (audioData instanceof byte[] bytes) {
527+
return Base64.getEncoder().encodeToString(bytes);
528+
}
529+
throw new IllegalArgumentException("Unsupported audio data type: " + audioData.getClass().getSimpleName());
530+
}
531+
511532
private String fromMediaData(MimeType mimeType, Object mediaContentData) {
512533
if (mediaContentData instanceof byte[] bytes) {
513534
// Assume the bytes are an image. So, convert the bytes to a base64 encoded

0 commit comments

Comments (0)