
Commit df07fca

feat(openai) - Support for audio output in OpenAI chat model
- Introduced new options for audio output modalities in ChatCompletionRequest
- Added AudioParameters configuration for voice and audio format selection
- Enhanced OpenAiChatModel to handle audio generation and embedding
- Updated AssistantMessage and Media classes to support audio media
- Added integration tests for audio output functionality
- Implemented support for text and audio multi-modal responses
- Updated Spring AI's chat model comparison table to clarify OpenAI's input/output modalities
- Added new configuration properties for audio output:
  * spring.ai.openai.chat.options.modalities
  * spring.ai.openai.chat.options.audio-parameters
- Extended documentation to explain audio output generation with the gpt-4o-audio-preview model
- Updated Spring Boot configuration metadata to support new audio-related properties
- Included auto-configuration integration test for chat model with audio response generation

Resolves #1841
1 parent 20ccbca commit df07fca
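
The new options surface through OpenAiChatOptions. The sketch below is a minimal usage example built from pieces added in this commit (withModalities, withAudio, AudioParameters with its Voice and AudioResponseFormat enums, GPT_4_O_AUDIO_PREVIEW); the surrounding calls and import locations (OpenAiChatOptions.builder(), withModel(), Prompt, ChatResponse, the openAiChatModel instance) are assumptions drawn from the existing Spring AI API and may differ slightly in your version.

import java.util.List;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;

// Request both text and audio output from the gpt-4o-audio-preview model.
// Note: audio output is rejected for streaming requests (see OpenAiChatModel.stream below).
OpenAiChatOptions options = OpenAiChatOptions.builder()
		.withModel(OpenAiApi.ChatModel.GPT_4_O_AUDIO_PREVIEW.getValue())
		.withModalities(List.of("text", "audio"))
		.withAudio(new AudioParameters(AudioParameters.Voice.NOVA,
				AudioParameters.AudioResponseFormat.MP3))
		.build();

// openAiChatModel is an existing OpenAiChatModel bean/instance.
ChatResponse response = openAiChatModel.call(new Prompt("Tell me a short story.", options));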

14 files changed: +407 additions, −47 deletions

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 44 additions & 7 deletions
@@ -64,6 +64,7 @@
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.AudioOutput;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCompletionFunction;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
@@ -72,6 +73,8 @@
 import org.springframework.ai.openai.metadata.OpenAiUsage;
 import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
 import org.springframework.ai.retry.RetryUtils;
+import org.springframework.core.io.ByteArrayResource;
+import org.springframework.core.io.Resource;
 import org.springframework.http.ResponseEntity;
 import org.springframework.retry.support.RetryTemplate;
 import org.springframework.util.Assert;
@@ -251,7 +254,7 @@ public ChatResponse call(Prompt prompt) {
 				"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
 				"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
 			// @formatter:on
-			return buildGeneration(choice, metadata);
+			return buildGeneration(choice, metadata, request);
 		}).toList();

 		// Non function calling.
@@ -282,6 +285,17 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 		return Flux.deferContextual(contextView -> {
 			ChatCompletionRequest request = createRequest(prompt, true);

+			if (request.outputModalities() != null) {
+				if (request.outputModalities().stream().anyMatch(m -> m.equals("audio"))) {
+					logger.warn("Audio output is not supported for streaming requests. Removing audio output.");
+					throw new IllegalArgumentException("Audio output is not supported for streaming requests.");
+				}
+			}
+			if (request.audioParameters() != null) {
+				logger.warn("Audio parameters are not supported for streaming requests. Removing audio parameters.");
+				throw new IllegalArgumentException("Audio parameters are not supported for streaming requests.");
+			}
+
 			Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request,
 					getAdditionalHttpHeaders(prompt));

@@ -320,7 +334,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 				"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
 				"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");

-				return buildGeneration(choice, metadata);
+				return buildGeneration(choice, metadata, request);
 			}).toList();
 			// @formatter:on

@@ -367,7 +381,7 @@ private MultiValueMap<String, String> getAdditionalHttpHeaders(Prompt prompt) {
 			headers.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> List.of(e.getValue()))));
 	}

-	private Generation buildGeneration(Choice choice, Map<String, Object> metadata) {
+	private Generation buildGeneration(Choice choice, Map<String, Object> metadata, ChatCompletionRequest request) {
 		List<AssistantMessage.ToolCall> toolCalls = choice.message().toolCalls() == null ? List.of()
 				: choice.message()
 					.toolCalls()
@@ -376,10 +390,26 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata)
 							toolCall.function().name(), toolCall.function().arguments()))
 					.toList();

-		var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls);
 		String finishReason = (choice.finishReason() != null ? choice.finishReason().name() : "");
-		var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build();
-		return new Generation(assistantMessage, generationMetadata);
+		var generationMetadataBuilder = ChatGenerationMetadata.builder().finishReason(finishReason);
+
+		List<Media> media = new ArrayList<>();
+		String textContent = choice.message().content();
+		var audioOutput = choice.message().audioOutput();
+		if (audioOutput != null) {
+			String mimeType = String.format("audio/%s", request.audioParameters().format().name().toLowerCase());
+			byte[] audioData = Base64.getDecoder().decode(audioOutput.data());
+			Resource resource = new ByteArrayResource(audioData);
+			media.add(new Media(MimeTypeUtils.parseMimeType(mimeType), resource, audioOutput.id()));
+			if (!StringUtils.hasText(textContent)) {
+				textContent = audioOutput.transcript();
+			}
+			generationMetadataBuilder.metadata("audioId", audioOutput.id());
+			generationMetadataBuilder.metadata("audioExpiresAt", audioOutput.expiresAt());
+		}
+
+		var assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media);
+		return new Generation(assistantMessage, generationMetadataBuilder.build());
 	}

 	private ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit) {
@@ -443,8 +473,15 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
 					return new ToolCall(toolCall.id(), toolCall.type(), function);
 				}).toList();
 			}
+			AudioOutput audioOutput = null;
+			if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) {
+				Assert.isTrue(assistantMessage.getMedia().size() == 1,
+						"Only one media content is supported for assistant messages");
+				audioOutput = new AudioOutput(assistantMessage.getMedia().get(0).getId(), null, null, null);
+
+			}
 			return List.of(new ChatCompletionMessage(assistantMessage.getContent(),
-					ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, null));
+					ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));
 		}
 		else if (message.getMessageType() == MessageType.TOOL) {
 			ToolResponseMessage toolMessage = (ToolResponseMessage) message;
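
For reference, a sketch of how a caller might consume the audio that buildGeneration() now attaches to the AssistantMessage, continuing the usage sketch above. getResult().getOutput(), getContent(), and getMedia() are existing message API; the getDataAsByteArray() accessor on Media and the import locations are assumptions, so substitute whatever accessor your Media version exposes for the raw bytes. The "audioId" and "audioExpiresAt" metadata keys come straight from this diff.

import java.nio.file.Files;
import java.nio.file.Path;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.model.Media;

AssistantMessage output = response.getResult().getOutput();

// When only audio was requested, the text content falls back to the audio transcript (see buildGeneration).
String text = output.getContent();

if (!output.getMedia().isEmpty()) {
	Media audio = output.getMedia().get(0);          // MIME type follows the requested format, e.g. audio/mp3
	byte[] audioBytes = audio.getDataAsByteArray();  // assumed accessor for the decoded audio bytes
	Files.write(Path.of("story.mp3"), audioBytes);
}

// The generation metadata also carries "audioId" and "audioExpiresAt" for multi-turn follow-ups.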

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 53 additions & 2 deletions
@@ -33,6 +33,7 @@
 import org.springframework.ai.model.function.FunctionCallback;
 import org.springframework.ai.model.function.FunctionCallingOptions;
 import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
 import org.springframework.ai.openai.api.ResponseFormat;
@@ -92,6 +93,27 @@ public class OpenAiChatOptions implements FunctionCallingOptions {
 	 * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
 	 */
 	private @JsonProperty("n") Integer n;
+
+	/**
+	 * Output types that you would like the model to generate for this request.
+	 * Most models are capable of generating text, which is the default.
+	 * The gpt-4o-audio-preview model can also be used to generate audio.
+	 * To request that this model generate both text and audio responses,
+	 * you can use: ["text", "audio"].
+	 * Note that the audio modality is only available for the gpt-4o-audio-preview model
+	 * and is not supported for streaming completions.
+	 */
+	private @JsonProperty("modalities") List<String> modalities;
+
+	/**
+	 * Audio parameters for the audio generation. Required when audio output is requested with
+	 * modalities: ["audio"]
+	 * Note: that the audio modality is only available for the gpt-4o-audio-preview model
+	 * and is not supported for streaming completions.
+	 */
+	private @JsonProperty("audio") AudioParameters audio;
+
 	/**
 	 * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they
 	 * appear in the text so far, increasing the model's likelihood to talk about new topics.
@@ -206,6 +228,8 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
 			.withMaxTokens(fromOptions.getMaxTokens())
 			.withMaxCompletionTokens(fromOptions.getMaxCompletionTokens())
 			.withN(fromOptions.getN())
+			.withModalities(fromOptions.getModalities())
+			.withAudio(fromOptions.getAudio())
 			.withPresencePenalty(fromOptions.getPresencePenalty())
 			.withResponseFormat(fromOptions.getResponseFormat())
 			.withStreamUsage(fromOptions.getStreamUsage())
@@ -300,6 +324,22 @@ public void setN(Integer n) {
 		this.n = n;
 	}

+	public List<String> getModalities() {
+		return modalities;
+	}
+
+	public void setModalities(List<String> modalities) {
+		this.modalities = modalities;
+	}
+
+	public AudioParameters getAudio() {
+		return audio;
+	}
+
+	public void setAudio(AudioParameters audio) {
+		this.audio = audio;
+	}
+
 	@Override
 	public Double getPresencePenalty() {
 		return this.presencePenalty;
@@ -465,7 +505,7 @@ public int hashCode() {
 				this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat,
 				this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
 				this.user, this.parallelToolCalls, this.functionCallbacks, this.functions, this.httpHeaders,
-				this.proxyToolCalls, this.toolContext);
+				this.proxyToolCalls, this.toolContext, this.modalities, this.audio);
 	}

 	@Override
@@ -493,7 +533,8 @@ public boolean equals(Object o) {
 				&& Objects.equals(this.functions, other.functions)
 				&& Objects.equals(this.httpHeaders, other.httpHeaders)
 				&& Objects.equals(this.toolContext, other.toolContext)
-				&& Objects.equals(this.proxyToolCalls, other.proxyToolCalls);
+				&& Objects.equals(this.proxyToolCalls, other.proxyToolCalls)
+				&& Objects.equals(this.modalities, other.modalities) && Objects.equals(this.audio, other.audio);
 	}

 	@Override
@@ -558,6 +599,16 @@ public Builder withN(Integer n) {
 			return this;
 		}

+		public Builder withModalities(List<String> modalities) {
+			this.options.modalities = modalities;
+			return this;
+		}
+
+		public Builder withAudio(AudioParameters audio) {
+			this.options.audio = audio;
+			return this;
+		}
+
 		public Builder withPresencePenalty(Double presencePenalty) {
 			this.options.presencePenalty = presencePenalty;
 			return this;

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 28 additions & 28 deletions
@@ -839,10 +839,10 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 	 * @param model ID of the model to use.
 	 * @param audio Parameters for audio output. Required when audio output is requested with outputModalities: ["audio"].
 	 */
-	public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, AudioParameters audio) {
+	public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, AudioParameters audio, boolean stream) {
 		this(messages, model, null, null, null, null, null, null,
 				null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
-				null, null, null, false, null, null, null,
+				null, null, null, stream, null, null, null,
 				null, null, null, null);
 	}

@@ -938,34 +938,34 @@ public record AudioParameters(
 		 * Specifies the voice type.
 		 */
 		public enum Voice {
-			@JsonProperty("alloy")
-			ALLOY,
-			@JsonProperty("echo")
-			ECHO,
-			@JsonProperty("fable")
-			FABLE,
-			@JsonProperty("onyx")
-			ONYX,
-			@JsonProperty("nova")
-			NOVA,
-			@JsonProperty("shimmer")
-			SHIMMER
+			/** Alloy voice */
+			@JsonProperty("alloy") ALLOY,
+			/** Echo voice */
+			@JsonProperty("echo") ECHO,
+			/** Fable voice */
+			@JsonProperty("fable") FABLE,
+			/** Onyx voice */
+			@JsonProperty("onyx") ONYX,
+			/** Nova voice */
+			@JsonProperty("nova") NOVA,
+			/** Shimmer voice */
+			@JsonProperty("shimmer") SHIMMER
 		}

 		/**
 		 * Specifies the output audio format.
 		 */
 		public enum AudioResponseFormat {
-			@JsonProperty("mp3")
-			MP3,
-			@JsonProperty("flac")
-			FLAC,
-			@JsonProperty("opus")
-			OPUS,
-			@JsonProperty("pcm16")
-			PCM16,
-			@JsonProperty("wav")
-			WAV
+			/** MP3 format */
+			@JsonProperty("mp3") MP3,
+			/** FLAC format */
+			@JsonProperty("flac") FLAC,
+			/** OPUS format */
+			@JsonProperty("opus") OPUS,
+			/** PCM16 format */
+			@JsonProperty("pcm16") PCM16,
+			/** WAV format */
+			@JsonProperty("wav") WAV
 		}
 	}

@@ -1119,10 +1119,10 @@ public record InputAudio(// @formatter:off
 		@JsonProperty("format") Format format) {

 		public enum Format {
-			@JsonProperty("mp3")
-			MP3,
-			@JsonProperty("wav")
-			WAV
+			/** MP3 audio format */
+			@JsonProperty("mp3") MP3,
+			/** WAV audio format */
+			@JsonProperty("wav") WAV
 		} // @formatter:on
 	}


models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java

Lines changed: 21 additions & 1 deletion
@@ -35,6 +35,7 @@
 import org.springframework.http.ResponseEntity;

 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;

 /**
  * @author Christian Tzolov
@@ -105,7 +106,7 @@ void outputAudio() {
 				ChatCompletionRequest.AudioParameters.Voice.NOVA,
 				ChatCompletionRequest.AudioParameters.AudioResponseFormat.MP3);
 		ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(List.of(chatCompletionMessage),
-				OpenAiApi.ChatModel.GPT_4_O_AUDIO_PREVIEW.getValue(), audioParameters);
+				OpenAiApi.ChatModel.GPT_4_O_AUDIO_PREVIEW.getValue(), audioParameters, false);
 		ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(chatCompletionRequest);

 		assertThat(response).isNotNull();
@@ -119,4 +120,23 @@ void outputAudio() {
 			.containsIgnoringCase("leviosa");
 	}

+	@Test
+	void streamOutputAudio() {
+		ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage(
+				"What is the magic spell to make objects fly?", Role.USER);
+		ChatCompletionRequest.AudioParameters audioParameters = new ChatCompletionRequest.AudioParameters(
+				ChatCompletionRequest.AudioParameters.Voice.NOVA,
+				ChatCompletionRequest.AudioParameters.AudioResponseFormat.MP3);
+		ChatCompletionRequest chatCompletionRequest = new ChatCompletionRequest(List.of(chatCompletionMessage),
+				OpenAiApi.ChatModel.GPT_4_O_AUDIO_PREVIEW.getValue(), audioParameters, true);
+		// Flux<ChatCompletionChunk> response =
+		// this.openAiApi.chatCompletionStream(chatCompletionRequest);
+
+		// var responseList = response.collectList().block();
+
+		assertThatThrownBy(() -> this.openAiApi.chatCompletionStream(chatCompletionRequest).collectList().block())
+			.isInstanceOf(RuntimeException.class)
+			.hasMessageContaining("400 Bad Request from POST https://api.openai.com/v1/chat/completions");
+	}
+
 }
