Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,25 @@
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
import org.springframework.ai.openai.api.common.OpenAiApiException;
import org.springframework.ai.openai.audio.speech.Speech;
import org.springframework.ai.openai.audio.speech.SpeechModel;
import org.springframework.ai.openai.audio.speech.SpeechPrompt;
import org.springframework.ai.openai.audio.speech.SpeechResponse;
import org.springframework.ai.openai.audio.speech.StreamingSpeechModel;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;

import java.time.Duration;

/**
* OpenAI audio speech client implementation backed by {@link OpenAiAudioApi}.
*
* @author Ahmed Yousri
* @author Hyunjoon Choi
* @author Thomas Vitale
* @see OpenAiAudioApi
* @since 1.0.0-M1
*/
Expand All @@ -63,11 +62,7 @@ public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel
/**
* The retry template used to retry the OpenAI Audio API calls.
*/
public final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(OpenAiApiException.class)
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
.build();
private final RetryTemplate retryTemplate;

/**
* Low-level access to the OpenAI Audio API.
Expand Down Expand Up @@ -98,10 +93,25 @@ public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi) {
* options.
*/
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) {
this(audioApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

/**
 * Initializes a new instance of the OpenAiAudioSpeechModel class with the provided
 * OpenAiAudioApi, default options and retry template.
 * @param audioApi The OpenAiAudioApi to use for speech synthesis. Must not be null.
 * @param options The OpenAiAudioSpeechOptions containing the default speech synthesis
 * options. Must not be null.
 * @param retryTemplate The retry template used to retry the OpenAI Audio API calls.
 * Must not be null.
 */
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options,
        RetryTemplate retryTemplate) {
    Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
    Assert.notNull(options, "OpenAiSpeechOptions must not be null");
    // Check the retry template itself: the previous code re-checked `options` here,
    // so a null template was accepted and only failed later on the first API call.
    Assert.notNull(retryTemplate, "RetryTemplate must not be null");
    this.audioApi = audioApi;
    this.defaultOptions = options;
    this.retryTemplate = retryTemplate;
}

@Override
Expand All @@ -113,40 +123,43 @@ public byte[] call(String text) {
@Override
public SpeechResponse call(SpeechPrompt speechPrompt) {

return this.retryTemplate.execute(ctx -> {
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt);

OpenAiAudioApi.SpeechRequest speechRequest = createRequestBody(speechPrompt);
ResponseEntity<byte[]> speechEntity = this.retryTemplate
.execute(ctx -> this.audioApi.createSpeech(speechRequest));

ResponseEntity<byte[]> speechEntity = this.audioApi.createSpeech(speechRequest);
var speech = speechEntity.getBody();

if (speech == null) {
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
return new SpeechResponse(new Speech(new byte[0]));
}
var speech = speechEntity.getBody();

RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(speechEntity);
if (speech == null) {
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
return new SpeechResponse(new Speech(new byte[0]));
}

return new SpeechResponse(new Speech(speech), new OpenAiAudioSpeechResponseMetadata(rateLimits));
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(speechEntity);

});
return new SpeechResponse(new Speech(speech), new OpenAiAudioSpeechResponseMetadata(rateLimits));
}

/**
* Streams the audio response for the given speech prompt.
* @param prompt The speech prompt containing the text and options for speech
* @param speechPrompt The speech prompt containing the text and options for speech
* synthesis.
* @return A Flux of SpeechResponse objects containing the streamed audio and
* metadata.
*/
@Override
public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
return this.audioApi.stream(this.createRequestBody(prompt))
.map(entity -> new SpeechResponse(new Speech(entity.getBody()), new OpenAiAudioSpeechResponseMetadata(
OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
public Flux<SpeechResponse> stream(SpeechPrompt speechPrompt) {

OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt);

Flux<ResponseEntity<byte[]>> speechEntity = this.retryTemplate
.execute(ctx -> this.audioApi.stream(speechRequest));

return speechEntity.map(entity -> new SpeechResponse(new Speech(entity.getBody()),
new OpenAiAudioSpeechResponseMetadata(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
}

private OpenAiAudioApi.SpeechRequest createRequestBody(SpeechPrompt request) {
private OpenAiAudioApi.SpeechRequest createRequest(SpeechPrompt request) {
OpenAiAudioSpeechOptions options = this.defaultOptions;

if (request.getOptions() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
*
* @author Michael Lavelle
* @author Christian Tzolov
* @author Thomas Vitale
* @see OpenAiAudioApi
* @since 0.8.1
*/
Expand All @@ -65,7 +66,7 @@ public class OpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPr

private final OpenAiAudioTranscriptionOptions defaultOptions;

public final RetryTemplate retryTemplate;
private final RetryTemplate retryTemplate;

private final OpenAiAudioApi audioApi;

Expand All @@ -80,8 +81,7 @@ public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi) {
.withModel(OpenAiAudioApi.WhisperModel.WHISPER_1.getValue())
.withResponseFormat(OpenAiAudioApi.TranscriptResponseFormat.JSON)
.withTemperature(0.7f)
.build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
.build());
}

/**
Expand Down Expand Up @@ -119,74 +119,71 @@ public String call(Resource audioResource) {
}

@Override
public AudioTranscriptionResponse call(AudioTranscriptionPrompt request) {
public AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPrompt) {

return this.retryTemplate.execute(ctx -> {
Resource audioResource = transcriptionPrompt.getInstructions();

Resource audioResource = request.getInstructions();
OpenAiAudioApi.TranscriptionRequest request = createRequest(transcriptionPrompt);

OpenAiAudioApi.TranscriptionRequest requestBody = createRequestBody(request);
if (request.responseFormat().isJsonType()) {

if (requestBody.responseFormat().isJsonType()) {
ResponseEntity<StructuredResponse> transcriptionEntity = this.retryTemplate
.execute(ctx -> this.audioApi.createTranscription(request, StructuredResponse.class));

ResponseEntity<StructuredResponse> transcriptionEntity = this.audioApi.createTranscription(requestBody,
StructuredResponse.class);
var transcription = transcriptionEntity.getBody();

var transcription = transcriptionEntity.getBody();

if (transcription == null) {
logger.warn("No transcription returned for request: {}", audioResource);
return new AudioTranscriptionResponse(null);
}
if (transcription == null) {
logger.warn("No transcription returned for request: {}", audioResource);
return new AudioTranscriptionResponse(null);
}

AudioTranscription transcript = new AudioTranscription(transcription.text());
AudioTranscription transcript = new AudioTranscription(transcription.text());

RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);

return new AudioTranscriptionResponse(transcript,
OpenAiAudioTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
.withRateLimit(rateLimits));
return new AudioTranscriptionResponse(transcript,
OpenAiAudioTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
.withRateLimit(rateLimits));

}
else {
}
else {

ResponseEntity<String> transcriptionEntity = this.audioApi.createTranscription(requestBody,
String.class);
ResponseEntity<String> transcriptionEntity = this.retryTemplate
.execute(ctx -> this.audioApi.createTranscription(request, String.class));

var transcription = transcriptionEntity.getBody();
var transcription = transcriptionEntity.getBody();

if (transcription == null) {
logger.warn("No transcription returned for request: {}", audioResource);
return new AudioTranscriptionResponse(null);
}
if (transcription == null) {
logger.warn("No transcription returned for request: {}", audioResource);
return new AudioTranscriptionResponse(null);
}

AudioTranscription transcript = new AudioTranscription(transcription);
AudioTranscription transcript = new AudioTranscription(transcription);

RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);

return new AudioTranscriptionResponse(transcript,
OpenAiAudioTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
.withRateLimit(rateLimits));
}
});
return new AudioTranscriptionResponse(transcript,
OpenAiAudioTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
.withRateLimit(rateLimits));
}
}

OpenAiAudioApi.TranscriptionRequest createRequestBody(AudioTranscriptionPrompt request) {
OpenAiAudioApi.TranscriptionRequest createRequest(AudioTranscriptionPrompt transcriptionPrompt) {

OpenAiAudioTranscriptionOptions options = this.defaultOptions;

if (request.getOptions() != null) {
if (request.getOptions() instanceof OpenAiAudioTranscriptionOptions runtimeOptions) {
if (transcriptionPrompt.getOptions() != null) {
if (transcriptionPrompt.getOptions() instanceof OpenAiAudioTranscriptionOptions runtimeOptions) {
options = this.merge(runtimeOptions, options);
}
else {
throw new IllegalArgumentException("Prompt options are not of type TranscriptionOptions: "
+ request.getOptions().getClass().getSimpleName());
+ transcriptionPrompt.getOptions().getClass().getSimpleName());
}
}

return OpenAiAudioApi.TranscriptionRequest.builder()
.withFile(toBytes(request.getInstructions()))
.withFile(toBytes(transcriptionPrompt.getInstructions()))
.withResponseFormat(options.getResponseFormat())
.withPrompt(options.getPrompt())
.withTemperature(options.getTemperature())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ public OpenAiEmbeddingModel(OpenAiApi openAiApi) {
*/
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode) {
this(openAiApi, metadataMode,
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build());
}

/**
Expand Down Expand Up @@ -110,42 +109,45 @@ public List<Double> embed(Document document) {
return this.embed(document.getFormattedContent(this.metadataMode));
}

@SuppressWarnings("unchecked")
@Override
public EmbeddingResponse call(EmbeddingRequest request) {

return this.retryTemplate.execute(ctx -> {

org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = (this.defaultOptions != null)
? new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat(),
this.defaultOptions.getDimensions(), this.defaultOptions.getUser())
: new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
OpenAiApi.DEFAULT_EMBEDDING_MODEL);

if (request.getOptions() != null && !EmbeddingOptions.EMPTY.equals(request.getOptions())) {
apiRequest = ModelOptionsUtils.merge(request.getOptions(), apiRequest,
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest.class);
}
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = createRequest(request);

EmbeddingList<OpenAiApi.Embedding> apiEmbeddingResponse = this.openAiApi.embeddings(apiRequest).getBody();
EmbeddingList<OpenAiApi.Embedding> apiEmbeddingResponse = this.retryTemplate
.execute(ctx -> this.openAiApi.embeddings(apiRequest).getBody());

if (apiEmbeddingResponse == null) {
logger.warn("No embeddings returned for request: {}", request);
return new EmbeddingResponse(List.of());
}
if (apiEmbeddingResponse == null) {
logger.warn("No embeddings returned for request: {}", request);
return new EmbeddingResponse(List.of());
}

var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
OpenAiUsage.from(apiEmbeddingResponse.usage()));
var metadata = new EmbeddingResponseMetadata(apiEmbeddingResponse.model(),
OpenAiUsage.from(apiEmbeddingResponse.usage()));

List<Embedding> embeddings = apiEmbeddingResponse.data()
.stream()
.map(e -> new Embedding(e.embedding(), e.index()))
.toList();
List<Embedding> embeddings = apiEmbeddingResponse.data()
.stream()
.map(e -> new Embedding(e.embedding(), e.index()))
.toList();

return new EmbeddingResponse(embeddings, metadata);
return new EmbeddingResponse(embeddings, metadata);
}

});
/**
 * Builds the low-level OpenAI embedding API request for the given embedding request.
 * The request starts from this model's default options (or from the default
 * embedding model when no default options are set) and is then merged with the
 * request's own options when those are present and not empty.
 * @param request The embedding request carrying the input texts and optional
 * runtime options.
 * @return The API request to send to the OpenAI embeddings endpoint.
 */
@SuppressWarnings("unchecked")
private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request) {
    OpenAiApi.EmbeddingRequest<List<String>> apiRequest;

    if (this.defaultOptions == null) {
        // No defaults configured: fall back to the default embedding model only.
        apiRequest = new OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
                OpenAiApi.DEFAULT_EMBEDDING_MODEL);
    }
    else {
        apiRequest = new OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
                this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat(),
                this.defaultOptions.getDimensions(), this.defaultOptions.getUser());
    }

    var runtimeOptions = request.getOptions();
    boolean hasRuntimeOptions = runtimeOptions != null && !EmbeddingOptions.EMPTY.equals(runtimeOptions);

    if (hasRuntimeOptions) {
        // Merge the per-request options into the request built above.
        apiRequest = ModelOptionsUtils.merge(runtimeOptions, apiRequest, OpenAiApi.EmbeddingRequest.class);
    }

    return apiRequest;
}

}
Loading