Skip to content

Commit 6c88654

Browse files
committed
Consolidate retry config for OpenAI
Signed-off-by: Thomas Vitale <[email protected]>
1 parent 6270d62 commit 6c88654

File tree

7 files changed

+146
-160
lines changed

7 files changed

+146
-160
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,25 @@
2222
import org.springframework.ai.chat.metadata.RateLimit;
2323
import org.springframework.ai.openai.api.OpenAiAudioApi;
2424
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
25-
import org.springframework.ai.openai.api.common.OpenAiApiException;
2625
import org.springframework.ai.openai.audio.speech.Speech;
2726
import org.springframework.ai.openai.audio.speech.SpeechModel;
2827
import org.springframework.ai.openai.audio.speech.SpeechPrompt;
2928
import org.springframework.ai.openai.audio.speech.SpeechResponse;
3029
import org.springframework.ai.openai.audio.speech.StreamingSpeechModel;
3130
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
3231
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
32+
import org.springframework.ai.retry.RetryUtils;
3333
import org.springframework.http.ResponseEntity;
3434
import org.springframework.retry.support.RetryTemplate;
3535
import org.springframework.util.Assert;
3636
import reactor.core.publisher.Flux;
3737

38-
import java.time.Duration;
39-
4038
/**
4139
* OpenAI audio speech client implementation backed by {@link OpenAiAudioApi}.
4240
*
4341
* @author Ahmed Yousri
4442
* @author Hyunjoon Choi
43+
* @author Thomas Vitale
4544
* @see OpenAiAudioApi
4645
* @since 1.0.0-M1
4746
*/
@@ -63,11 +62,7 @@ public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel
6362
/**
6463
* The retry template used to retry the OpenAI Audio API calls.
6564
*/
66-
public final RetryTemplate retryTemplate = RetryTemplate.builder()
67-
.maxAttempts(10)
68-
.retryOn(OpenAiApiException.class)
69-
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
70-
.build();
65+
private final RetryTemplate retryTemplate;
7166

7267
/**
7368
* Low-level access to the OpenAI Audio API.
@@ -98,10 +93,25 @@ public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi) {
9893
* options.
9994
*/
10095
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) {
96+
// Delegates to the three-argument constructor, passing RetryUtils.DEFAULT_RETRY_TEMPLATE as the retry template.
this(audioApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
97+
}
98+
99+
/**
100+
* Initializes a new instance of the OpenAiAudioSpeechModel class with the provided
101+
* OpenAiAudioApi, options and retry template.
102+
* @param audioApi The OpenAiAudioApi to use for speech synthesis. Must not be null.
103+
* @param options The OpenAiAudioSpeechOptions containing the speech synthesis
104+
* options. Must not be null.
105+
* @param retryTemplate The retry template used to retry the OpenAI Audio API calls. Must not be null.
106+
*/
107+
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options,
108+
RetryTemplate retryTemplate) {
101109
Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
102110
Assert.notNull(options, "OpenAiSpeechOptions must not be null");
111+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
103112
this.audioApi = audioApi;
104113
this.defaultOptions = options;
114+
this.retryTemplate = retryTemplate;
105115
}
106116

107117
@Override
@@ -113,40 +123,43 @@ public byte[] call(String text) {
113123
@Override
114124
public SpeechResponse call(SpeechPrompt speechPrompt) {
115125

116-
return this.retryTemplate.execute(ctx -> {
126+
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt);
117127

118-
OpenAiAudioApi.SpeechRequest speechRequest = createRequestBody(speechPrompt);
128+
ResponseEntity<byte[]> speechEntity = this.retryTemplate
129+
.execute(ctx -> this.audioApi.createSpeech(speechRequest));
119130

120-
ResponseEntity<byte[]> speechEntity = this.audioApi.createSpeech(speechRequest);
121-
var speech = speechEntity.getBody();
122-
123-
if (speech == null) {
124-
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
125-
return new SpeechResponse(new Speech(new byte[0]));
126-
}
131+
var speech = speechEntity.getBody();
127132

128-
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(speechEntity);
133+
if (speech == null) {
134+
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
135+
return new SpeechResponse(new Speech(new byte[0]));
136+
}
129137

130-
return new SpeechResponse(new Speech(speech), new OpenAiAudioSpeechResponseMetadata(rateLimits));
138+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(speechEntity);
131139

132-
});
140+
return new SpeechResponse(new Speech(speech), new OpenAiAudioSpeechResponseMetadata(rateLimits));
133141
}
134142

135143
/**
136144
* Streams the audio response for the given speech prompt.
137-
* @param prompt The speech prompt containing the text and options for speech
145+
* @param speechPrompt The speech prompt containing the text and options for speech
138146
* synthesis.
139147
* @return A Flux of SpeechResponse objects containing the streamed audio and
140148
* metadata.
141149
*/
142150
@Override
143-
public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
144-
return this.audioApi.stream(this.createRequestBody(prompt))
145-
.map(entity -> new SpeechResponse(new Speech(entity.getBody()), new OpenAiAudioSpeechResponseMetadata(
146-
OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
151+
public Flux<SpeechResponse> stream(SpeechPrompt speechPrompt) {
152+
153+
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt);
154+
155+
Flux<ResponseEntity<byte[]>> speechEntity = this.retryTemplate
156+
.execute(ctx -> this.audioApi.stream(speechRequest));
157+
158+
return speechEntity.map(entity -> new SpeechResponse(new Speech(entity.getBody()),
159+
new OpenAiAudioSpeechResponseMetadata(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
147160
}
148161

149-
private OpenAiAudioApi.SpeechRequest createRequestBody(SpeechPrompt request) {
162+
private OpenAiAudioApi.SpeechRequest createRequest(SpeechPrompt request) {
150163
OpenAiAudioSpeechOptions options = this.defaultOptions;
151164

152165
if (request.getOptions() != null) {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
*
5757
* @author Michael Lavelle
5858
* @author Christian Tzolov
59+
* @author Thomas Vitale
5960
* @see OpenAiAudioApi
6061
* @since 0.8.1
6162
*/
@@ -65,7 +66,7 @@ public class OpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPr
6566

6667
private final OpenAiAudioTranscriptionOptions defaultOptions;
6768

68-
public final RetryTemplate retryTemplate;
69+
private final RetryTemplate retryTemplate;
6970

7071
private final OpenAiAudioApi audioApi;
7172

@@ -80,8 +81,7 @@ public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi) {
8081
.withModel(OpenAiAudioApi.WhisperModel.WHISPER_1.getValue())
8182
.withResponseFormat(OpenAiAudioApi.TranscriptResponseFormat.JSON)
8283
.withTemperature(0.7f)
83-
.build(),
84-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
84+
.build());
8585
}
8686

8787
/**
@@ -119,74 +119,71 @@ public String call(Resource audioResource) {
119119
}
120120

121121
@Override
122-
public AudioTranscriptionResponse call(AudioTranscriptionPrompt request) {
122+
public AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPrompt) {
123123

124-
return this.retryTemplate.execute(ctx -> {
124+
Resource audioResource = transcriptionPrompt.getInstructions();
125125

126-
Resource audioResource = request.getInstructions();
126+
OpenAiAudioApi.TranscriptionRequest request = createRequest(transcriptionPrompt);
127127

128-
OpenAiAudioApi.TranscriptionRequest requestBody = createRequestBody(request);
128+
if (request.responseFormat().isJsonType()) {
129129

130-
if (requestBody.responseFormat().isJsonType()) {
130+
ResponseEntity<StructuredResponse> transcriptionEntity = this.retryTemplate
131+
.execute(ctx -> this.audioApi.createTranscription(request, StructuredResponse.class));
131132

132-
ResponseEntity<StructuredResponse> transcriptionEntity = this.audioApi.createTranscription(requestBody,
133-
StructuredResponse.class);
133+
var transcription = transcriptionEntity.getBody();
134134

135-
var transcription = transcriptionEntity.getBody();
136-
137-
if (transcription == null) {
138-
logger.warn("No transcription returned for request: {}", audioResource);
139-
return new AudioTranscriptionResponse(null);
140-
}
135+
if (transcription == null) {
136+
logger.warn("No transcription returned for request: {}", audioResource);
137+
return new AudioTranscriptionResponse(null);
138+
}
141139

142-
AudioTranscription transcript = new AudioTranscription(transcription.text());
140+
AudioTranscription transcript = new AudioTranscription(transcription.text());
143141

144-
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
142+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
145143

146-
return new AudioTranscriptionResponse(transcript,
147-
OpenAiAudioTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
148-
.withRateLimit(rateLimits));
144+
return new AudioTranscriptionResponse(transcript,
145+
OpenAiAudioTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
146+
.withRateLimit(rateLimits));
149147

150-
}
151-
else {
148+
}
149+
else {
152150

153-
ResponseEntity<String> transcriptionEntity = this.audioApi.createTranscription(requestBody,
154-
String.class);
151+
ResponseEntity<String> transcriptionEntity = this.retryTemplate
152+
.execute(ctx -> this.audioApi.createTranscription(request, String.class));
155153

156-
var transcription = transcriptionEntity.getBody();
154+
var transcription = transcriptionEntity.getBody();
157155

158-
if (transcription == null) {
159-
logger.warn("No transcription returned for request: {}", audioResource);
160-
return new AudioTranscriptionResponse(null);
161-
}
156+
if (transcription == null) {
157+
logger.warn("No transcription returned for request: {}", audioResource);
158+
return new AudioTranscriptionResponse(null);
159+
}
162160

163-
AudioTranscription transcript = new AudioTranscription(transcription);
161+
AudioTranscription transcript = new AudioTranscription(transcription);
164162

165-
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
163+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
166164

167-
return new AudioTranscriptionResponse(transcript,
168-
OpenAiAudioTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
169-
.withRateLimit(rateLimits));
170-
}
171-
});
165+
return new AudioTranscriptionResponse(transcript,
166+
OpenAiAudioTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
167+
.withRateLimit(rateLimits));
168+
}
172169
}
173170

174-
OpenAiAudioApi.TranscriptionRequest createRequestBody(AudioTranscriptionPrompt request) {
171+
OpenAiAudioApi.TranscriptionRequest createRequest(AudioTranscriptionPrompt transcriptionPrompt) {
175172

176173
OpenAiAudioTranscriptionOptions options = this.defaultOptions;
177174

178-
if (request.getOptions() != null) {
179-
if (request.getOptions() instanceof OpenAiAudioTranscriptionOptions runtimeOptions) {
175+
if (transcriptionPrompt.getOptions() != null) {
176+
if (transcriptionPrompt.getOptions() instanceof OpenAiAudioTranscriptionOptions runtimeOptions) {
180177
options = this.merge(runtimeOptions, options);
181178
}
182179
else {
183180
throw new IllegalArgumentException("Prompt options are not of type TranscriptionOptions: "
184-
+ request.getOptions().getClass().getSimpleName());
181+
+ transcriptionPrompt.getOptions().getClass().getSimpleName());
185182
}
186183
}
187184

188185
return OpenAiAudioApi.TranscriptionRequest.builder()
189-
.withFile(toBytes(request.getInstructions()))
186+
.withFile(toBytes(transcriptionPrompt.getInstructions()))
190187
.withResponseFormat(options.getResponseFormat())
191188
.withPrompt(options.getPrompt())
192189
.withTemperature(options.getTemperature())

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ public OpenAiEmbeddingModel(OpenAiApi openAiApi) {
6969
*/
7070
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode) {
7171
this(openAiApi, metadataMode,
72-
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(),
73-
RetryUtils.DEFAULT_RETRY_TEMPLATE);
72+
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build());
7473
}
7574

7675
/**
@@ -110,42 +109,45 @@ public List<Double> embed(Document document) {
110109
return this.embed(document.getFormattedContent(this.metadataMode));
111110
}
112111

113-
@SuppressWarnings("unchecked")
114112
@Override
115113
public EmbeddingResponse call(EmbeddingRequest request) {
116114

117-
return this.retryTemplate.execute(ctx -> {
118-
119-
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = (this.defaultOptions != null)
120-
? new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
121-
this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat(),
122-
this.defaultOptions.getDimensions(), this.defaultOptions.getUser())
123-
: new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
124-
OpenAiApi.DEFAULT_EMBEDDING_MODEL);
125-
126-
if (request.getOptions() != null && !EmbeddingOptions.EMPTY.equals(request.getOptions())) {
127-
apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest,
128-
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class);
129-
}
115+
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = createRequest(request);
130116

131-
EmbeddingList<OpenAiApi.Embedding> apiEmbeddingResponse = this.openAiApi.embeddings(apiRequest).getBody();
117+
EmbeddingList<OpenAiApi.Embedding> apiEmbeddingResponse = this.retryTemplate
118+
.execute(ctx -> this.openAiApi.embeddings(apiRequest).getBody());
132119

133-
if (apiEmbeddingResponse == null) {
134-
logger.warn("No embeddings returned for request: {}", request);
135-
return new EmbeddingResponse(List.of());
136-
}
120+
if (apiEmbeddingResponse == null) {
121+
logger.warn("No embeddings returned for request: {}", request);
122+
return new EmbeddingResponse(List.of());
123+
}
137124

138-
var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
139-
OpenAiUsage.from(apiEmbeddingResponse.usage()));
125+
var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
126+
OpenAiUsage.from(apiEmbeddingResponse.usage()));
140127

141-
List<Embedding> embeddings = apiEmbeddingResponse.data()
142-
.stream()
143-
.map(e -> new Embedding(e.embedding(), e.index()))
144-
.toList();
128+
List<Embedding> embeddings = apiEmbeddingResponse.data()
129+
.stream()
130+
.map(e -> new Embedding(e.embedding(), e.index()))
131+
.toList();
145132

146-
return new EmbeddingResponse(embeddings, metadata);
133+
return new EmbeddingResponse(embeddings, metadata);
134+
}
147135

148-
});
136+
/**
* Builds the low-level OpenAI API embedding request for the given embedding request.
* <p>
* When default options are configured, the API request is created from the request
* instructions together with the default model, encoding format, dimensions and user;
* otherwise it is created from the instructions and
* {@code OpenAiApi.DEFAULT_EMBEDDING_MODEL}. If the request carries options that are
* non-null and not equal to {@code EmbeddingOptions.EMPTY}, they are merged into the
* API request via {@code ModelOptionsUtils.merge}.
* <p>
* The merge is invoked with the raw {@code EmbeddingRequest} class literal, which is
* presumably why the unchecked warning is suppressed on this method.
* @param request the embedding request providing the instructions and optional
* runtime options.
* @return the API request to send to the OpenAI embeddings endpoint.
*/
@SuppressWarnings("unchecked")
137+
private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request) {
138+
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = (this.defaultOptions != null)
139+
? new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
140+
this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat(),
141+
this.defaultOptions.getDimensions(), this.defaultOptions.getUser())
142+
: new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
143+
OpenAiApi.DEFAULT_EMBEDDING_MODEL);
144+
145+
if (request.getOptions() != null && !EmbeddingOptions.EMPTY.equals(request.getOptions())) {
146+
apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest,
147+
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class);
148+
}
149+
150+
return apiRequest;
149151
}
150152

151153
}

0 commit comments

Comments
 (0)