Skip to content

Commit 54364fa

Browse files
committed
Consolidate SpeechModel APIs
* Consolidate SpeechModel APIs into the spring-ai-core module, make it null-safe and covered by unit tests. * Refactor OpenAiSpeechModel APIs to implement the new consolidated APIs. * Delete leftover ImageResponseMetadata class in the spring-ai-openai module. Fixes gh-1496
1 parent cddb00a commit 54364fa

File tree

27 files changed

+690
-266
lines changed

27 files changed

+690
-266
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/ImageResponseMetadata.java

Lines changed: 0 additions & 5 deletions
This file was deleted.

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,37 @@
1919
import org.apache.commons.lang3.StringUtils;
2020
import org.slf4j.Logger;
2121
import org.slf4j.LoggerFactory;
22+
import org.springframework.ai.audio.speech.Speech;
23+
import org.springframework.ai.audio.speech.SpeechModel;
24+
import org.springframework.ai.audio.speech.SpeechOptions;
25+
import org.springframework.ai.audio.speech.SpeechPrompt;
26+
import org.springframework.ai.audio.speech.SpeechResponse;
27+
import org.springframework.ai.audio.speech.StreamingSpeechModel;
2228
import org.springframework.ai.chat.metadata.RateLimit;
29+
import org.springframework.ai.model.ModelOptionsUtils;
2330
import org.springframework.ai.openai.api.OpenAiAudioApi;
2431
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
25-
import org.springframework.ai.openai.audio.speech.Speech;
26-
import org.springframework.ai.openai.audio.speech.SpeechModel;
27-
import org.springframework.ai.openai.audio.speech.SpeechPrompt;
28-
import org.springframework.ai.openai.audio.speech.SpeechResponse;
29-
import org.springframework.ai.openai.audio.speech.StreamingSpeechModel;
3032
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
3133
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
3234
import org.springframework.ai.retry.RetryUtils;
3335
import org.springframework.http.ResponseEntity;
36+
import org.springframework.lang.Nullable;
3437
import org.springframework.retry.support.RetryTemplate;
3538
import org.springframework.util.Assert;
3639
import reactor.core.publisher.Flux;
3740

3841
/**
39-
* OpenAI audio speech client implementation for backed by {@link OpenAiAudioApi}.
42+
* OpenAI audio speech client implementation backed by {@link OpenAiAudioApi}.
4043
*
4144
* @author Ahmed Yousri
4245
* @author Hyunjoon Choi
4346
* @author Thomas Vitale
4447
* @see OpenAiAudioApi
45-
* @since 1.0.0-M1
48+
* @since 1.0.0
4649
*/
4750
public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel {
4851

49-
private final Logger logger = LoggerFactory.getLogger(getClass());
52+
private final static Logger logger = LoggerFactory.getLogger(OpenAiAudioSpeechModel.class);
5053

5154
/**
5255
* The default options used for the audio completion requests.
@@ -114,16 +117,10 @@ public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions
114117
this.retryTemplate = retryTemplate;
115118
}
116119

117-
@Override
118-
public byte[] call(String text) {
119-
SpeechPrompt speechRequest = new SpeechPrompt(text);
120-
return call(speechRequest).getResult().getOutput();
121-
}
122-
123120
@Override
124121
public SpeechResponse call(SpeechPrompt speechPrompt) {
125-
126-
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt);
122+
OpenAiAudioSpeechOptions requestSpeechOptions = mergeOptions(speechPrompt.getOptions(), this.defaultOptions);
123+
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt, requestSpeechOptions);
127124

128125
ResponseEntity<byte[]> speechEntity = this.retryTemplate
129126
.execute(ctx -> this.audioApi.createSpeech(speechRequest));
@@ -149,53 +146,54 @@ public SpeechResponse call(SpeechPrompt speechPrompt) {
149146
*/
150147
@Override
151148
public Flux<SpeechResponse> stream(SpeechPrompt speechPrompt) {
149+
OpenAiAudioSpeechOptions requestSpeechOptions = mergeOptions(speechPrompt.getOptions(), this.defaultOptions);
150+
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt, requestSpeechOptions);
152151

153-
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt);
154-
155-
Flux<ResponseEntity<byte[]>> speechEntity = this.retryTemplate
156-
.execute(ctx -> this.audioApi.stream(speechRequest));
152+
Flux<ResponseEntity<byte[]>> speechEntity = this.audioApi.stream(speechRequest);
157153

158-
return speechEntity.map(entity -> new SpeechResponse(new Speech(entity.getBody()),
154+
return speechEntity.map(entity -> new SpeechResponse(
155+
new Speech(entity.getBody() != null ? entity.getBody() : new byte[0]),
159156
new OpenAiAudioSpeechResponseMetadata(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
160157
}
161158

162-
private OpenAiAudioApi.SpeechRequest createRequest(SpeechPrompt request) {
163-
OpenAiAudioSpeechOptions options = this.defaultOptions;
164-
165-
if (request.getOptions() != null) {
166-
if (request.getOptions() instanceof OpenAiAudioSpeechOptions runtimeOptions) {
167-
options = this.merge(runtimeOptions, options);
168-
}
169-
else {
170-
throw new IllegalArgumentException("Prompt options are not of type SpeechOptions: "
171-
+ request.getOptions().getClass().getSimpleName());
172-
}
173-
}
174-
175-
String input = StringUtils.isNotBlank(options.getInput()) ? options.getInput()
159+
private OpenAiAudioApi.SpeechRequest createRequest(SpeechPrompt request,
160+
OpenAiAudioSpeechOptions requestSpeechOptions) {
161+
String input = StringUtils.isNotBlank(requestSpeechOptions.getInput()) ? requestSpeechOptions.getInput()
176162
: request.getInstructions().getText();
177163

178164
OpenAiAudioApi.SpeechRequest.Builder requestBuilder = OpenAiAudioApi.SpeechRequest.builder()
179-
.withModel(options.getModel())
165+
.withModel(requestSpeechOptions.getModel())
180166
.withInput(input)
181-
.withVoice(options.getVoice())
182-
.withResponseFormat(options.getResponseFormat())
183-
.withSpeed(options.getSpeed());
167+
.withResponseFormat(requestSpeechOptions.getResponseFormat())
168+
.withSpeed(requestSpeechOptions.getSpeed())
169+
.withVoice(requestSpeechOptions.getVoice());
184170

185171
return requestBuilder.build();
186172
}
187173

188-
private OpenAiAudioSpeechOptions merge(OpenAiAudioSpeechOptions source, OpenAiAudioSpeechOptions target) {
189-
OpenAiAudioSpeechOptions.Builder mergedBuilder = OpenAiAudioSpeechOptions.builder();
174+
/**
175+
* Merge runtime and default {@link SpeechOptions} to compute the final options to use
176+
* in the request.
177+
*/
178+
private OpenAiAudioSpeechOptions mergeOptions(@Nullable SpeechOptions runtimeOptions,
179+
OpenAiAudioSpeechOptions defaultOptions) {
180+
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, SpeechOptions.class,
181+
OpenAiAudioSpeechOptions.class);
190182

191-
mergedBuilder.withModel(source.getModel() != null ? source.getModel() : target.getModel());
192-
mergedBuilder.withInput(source.getInput() != null ? source.getInput() : target.getInput());
193-
mergedBuilder.withVoice(source.getVoice() != null ? source.getVoice() : target.getVoice());
194-
mergedBuilder.withResponseFormat(
195-
source.getResponseFormat() != null ? source.getResponseFormat() : target.getResponseFormat());
196-
mergedBuilder.withSpeed(source.getSpeed() != null ? source.getSpeed() : target.getSpeed());
183+
if (runtimeOptionsForProvider == null) {
184+
return defaultOptions;
185+
}
197186

198-
return mergedBuilder.build();
187+
return OpenAiAudioSpeechOptions.builder()
188+
// Handle portable options
189+
.withModel(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
190+
// Handle OpenAI specific options
191+
.withInput(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getInput(), defaultOptions.getInput()))
192+
.withResponseFormat(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getResponseFormat(),
193+
defaultOptions.getResponseFormat()))
194+
.withSpeed(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getSpeed(), defaultOptions.getSpeed()))
195+
.withVoice(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getVoice(), defaultOptions.getVoice()))
196+
.build();
199197
}
200198

201199
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechOptions.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import com.fasterxml.jackson.annotation.JsonInclude;
2020
import com.fasterxml.jackson.annotation.JsonProperty;
21-
import org.springframework.ai.model.ModelOptions;
21+
import org.springframework.ai.audio.speech.SpeechOptions;
2222
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
2323
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice;
2424

@@ -27,10 +27,11 @@
2727
*
2828
* @author Ahmed Yousri
2929
* @author Hyunjoon Choi
30-
* @since 1.0.0-M1
30+
* @author Thomas Vitale
31+
* @since 1.0.0
3132
*/
3233
@JsonInclude(JsonInclude.Include.NON_NULL)
33-
public class OpenAiAudioSpeechOptions implements ModelOptions {
34+
public class OpenAiAudioSpeechOptions implements SpeechOptions {
3435

3536
/**
3637
* ID of the model to use for generating the audio. One of the available TTS models:
@@ -105,6 +106,7 @@ public OpenAiAudioSpeechOptions build() {
105106

106107
}
107108

109+
@Override
108110
public String getModel() {
109111
return model;
110112
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/audio/speech/SpeechPrompt.java

Lines changed: 0 additions & 81 deletions
This file was deleted.

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/audio/OpenAiAudioSpeechResponseMetadata.java

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,22 @@
1616

1717
package org.springframework.ai.openai.metadata.audio;
1818

19+
import org.springframework.ai.audio.speech.SpeechResponseMetadata;
1920
import org.springframework.ai.chat.metadata.EmptyRateLimit;
2021
import org.springframework.ai.chat.metadata.RateLimit;
21-
import org.springframework.ai.model.MutableResponseMetadata;
22-
import org.springframework.ai.model.ResponseMetadata;
23-
import org.springframework.ai.openai.api.OpenAiAudioApi;
2422
import org.springframework.lang.Nullable;
25-
import org.springframework.util.Assert;
26-
27-
import java.util.HashMap;
2823

2924
/**
3025
* Audio speech metadata implementation for {@literal OpenAI}.
3126
*
3227
* @author Ahmed Yousri
28+
* @author Thomas Vitale
3329
* @see RateLimit
3430
*/
35-
public class OpenAiAudioSpeechResponseMetadata extends MutableResponseMetadata {
31+
public class OpenAiAudioSpeechResponseMetadata extends SpeechResponseMetadata {
3632

3733
protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }";
3834

39-
public static final OpenAiAudioSpeechResponseMetadata NULL = new OpenAiAudioSpeechResponseMetadata() {
40-
};
41-
42-
public static OpenAiAudioSpeechResponseMetadata from(OpenAiAudioApi.StructuredResponse result) {
43-
Assert.notNull(result, "OpenAI speech must not be null");
44-
OpenAiAudioSpeechResponseMetadata speechResponseMetadata = new OpenAiAudioSpeechResponseMetadata();
45-
return speechResponseMetadata;
46-
}
47-
48-
public static OpenAiAudioSpeechResponseMetadata from(String result) {
49-
Assert.notNull(result, "OpenAI speech must not be null");
50-
OpenAiAudioSpeechResponseMetadata speechResponseMetadata = new OpenAiAudioSpeechResponseMetadata();
51-
return speechResponseMetadata;
52-
}
53-
5435
@Nullable
5536
private RateLimit rateLimit;
5637

@@ -62,17 +43,11 @@ public OpenAiAudioSpeechResponseMetadata(@Nullable RateLimit rateLimit) {
6243
this.rateLimit = rateLimit;
6344
}
6445

65-
@Nullable
6646
public RateLimit getRateLimit() {
6747
RateLimit rateLimit = this.rateLimit;
6848
return rateLimit != null ? rateLimit : new EmptyRateLimit();
6949
}
7050

71-
public OpenAiAudioSpeechResponseMetadata withRateLimit(RateLimit rateLimit) {
72-
this.rateLimit = rateLimit;
73-
return this;
74-
}
75-
7651
@Override
7752
public String toString() {
7853
return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit());
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
@NonNullApi
18+
@NonNullFields
19+
package org.springframework.ai.openai.metadata.audio;
20+
21+
import org.springframework.lang.NonNullApi;
22+
import org.springframework.lang.NonNullFields;
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
@NonNullApi
18+
@NonNullFields
19+
package org.springframework.ai.openai;
20+
21+
import org.springframework.lang.NonNullApi;
22+
import org.springframework.lang.NonNullFields;

0 commit comments

Comments
 (0)