Skip to content

Commit f4a9b12

Browse files
authored
Merge branch 'spring-projects:main' into main
2 parents 0dc4ed6 + 8a5635d commit f4a9b12

File tree

7 files changed

+217
-92
lines changed

7 files changed

+217
-92
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiAudioTranscriptionModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.StructuredResponse.Word;
3535
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.TranscriptResponseFormat;
3636
import org.springframework.ai.azure.openai.metadata.AzureOpenAiAudioTranscriptionResponseMetadata;
37-
import org.springframework.ai.model.Model;
3837
import org.springframework.ai.model.ModelOptionsUtils;
38+
import org.springframework.ai.audio.transcription.TranscriptionModel;
3939
import org.springframework.core.io.Resource;
4040
import org.springframework.util.Assert;
4141
import org.springframework.util.StringUtils;
@@ -47,7 +47,7 @@
4747
*
4848
* @author Piotr Olaszewski
4949
*/
50-
public class AzureOpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
50+
public class AzureOpenAiAudioTranscriptionModel implements TranscriptionModel {
5151

5252
private static final List<AudioTranscriptionFormat> JSON_FORMATS = List.of(AudioTranscriptionFormat.JSON,
5353
AudioTranscriptionFormat.VERBOSE_JSON);

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt;
2424
import org.springframework.ai.audio.transcription.AudioTranscriptionResponse;
2525
import org.springframework.ai.chat.metadata.RateLimit;
26-
import org.springframework.ai.model.Model;
26+
import org.springframework.ai.audio.transcription.TranscriptionModel;
2727
import org.springframework.ai.openai.api.OpenAiAudioApi;
2828
import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse;
2929
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata;
@@ -45,7 +45,7 @@
4545
* @see OpenAiAudioApi
4646
* @since 0.8.1
4747
*/
48-
public class OpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
48+
public class OpenAiAudioTranscriptionModel implements TranscriptionModel {
4949

5050
private final Logger logger = LoggerFactory.getLogger(getClass());
5151

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.audio.transcription;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt;
21+
import org.springframework.ai.audio.transcription.AudioTranscriptionResponse;
22+
import org.springframework.ai.audio.transcription.TranscriptionModel;
23+
import org.springframework.ai.model.SimpleApiKey;
24+
import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
25+
import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions;
26+
import org.springframework.ai.openai.api.OpenAiAudioApi;
27+
import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat;
28+
import org.springframework.ai.retry.RetryUtils;
29+
import org.springframework.beans.factory.annotation.Autowired;
30+
import org.springframework.boot.test.autoconfigure.web.client.RestClientTest;
31+
import org.springframework.context.annotation.Bean;
32+
import org.springframework.context.annotation.Configuration;
33+
import org.springframework.core.io.ClassPathResource;
34+
import org.springframework.http.HttpMethod;
35+
import org.springframework.http.MediaType;
36+
import org.springframework.test.web.client.MockRestServiceServer;
37+
import org.springframework.util.LinkedMultiValueMap;
38+
import org.springframework.web.client.RestClient;
39+
import org.springframework.web.reactive.function.client.WebClient;
40+
41+
import static org.assertj.core.api.Assertions.assertThat;
42+
import static org.springframework.test.web.client.match.MockRestRequestMatchers.method;
43+
import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo;
44+
import static org.springframework.test.web.client.response.MockRestResponseCreators.withSuccess;
45+
46+
@RestClientTest(OpenAiAudioTranscriptionModelTests.Config.class)
47+
class OpenAiAudioTranscriptionModelTests {
48+
49+
@Autowired
50+
private MockRestServiceServer server;
51+
52+
@Autowired
53+
private TranscriptionModel transcriptionModel;
54+
55+
@Test
56+
void transcribeRequestReturnsResponseCorrectly() {
57+
String mockResponse = """
58+
{
59+
"text": "All your bases are belong to us"
60+
}
61+
""".stripIndent();
62+
63+
this.server.expect(requestTo("https://api.openai.com/v1/audio/transcriptions"))
64+
.andExpect(method(HttpMethod.POST))
65+
.andRespond(withSuccess(mockResponse, MediaType.APPLICATION_JSON));
66+
67+
String transcription = this.transcriptionModel.transcribe(new ClassPathResource("/speech.flac"));
68+
69+
assertThat(transcription).isEqualTo("All your bases are belong to us");
70+
this.server.verify();
71+
}
72+
73+
@Test
74+
void callWithDefaultOptions() {
75+
String mockResponse = """
76+
{
77+
"text": "Hello, this is a test transcription."
78+
}
79+
""".stripIndent();
80+
81+
this.server.expect(requestTo("https://api.openai.com/v1/audio/transcriptions"))
82+
.andExpect(method(HttpMethod.POST))
83+
.andRespond(withSuccess(mockResponse, MediaType.APPLICATION_JSON));
84+
85+
AudioTranscriptionPrompt prompt = new AudioTranscriptionPrompt(new ClassPathResource("/speech.flac"));
86+
AudioTranscriptionResponse response = this.transcriptionModel.call(prompt);
87+
88+
assertThat(response.getResult().getOutput()).isEqualTo("Hello, this is a test transcription.");
89+
this.server.verify();
90+
}
91+
92+
@Test
93+
void transcribeWithOptions() {
94+
String mockResponse = """
95+
{
96+
"text": "Hello, this is a test transcription with options."
97+
}
98+
""".stripIndent();
99+
100+
this.server.expect(requestTo("https://api.openai.com/v1/audio/transcriptions"))
101+
.andExpect(method(HttpMethod.POST))
102+
.andRespond(withSuccess(mockResponse, MediaType.APPLICATION_JSON));
103+
104+
OpenAiAudioTranscriptionOptions options = OpenAiAudioTranscriptionOptions.builder()
105+
.temperature(0.5f)
106+
.responseFormat(TranscriptResponseFormat.JSON)
107+
.build();
108+
109+
String transcription = this.transcriptionModel.transcribe(new ClassPathResource("/speech.flac"), options);
110+
111+
assertThat(transcription).isEqualTo("Hello, this is a test transcription with options.");
112+
this.server.verify();
113+
}
114+
115+
@Configuration
116+
static class Config {
117+
118+
@Bean
119+
public OpenAiAudioApi openAiAudioApi(RestClient.Builder builder) {
120+
return new OpenAiAudioApi("https://api.openai.com", new SimpleApiKey("test-api-key"),
121+
new LinkedMultiValueMap<>(), builder, WebClient.builder(),
122+
RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
123+
}
124+
125+
@Bean
126+
public OpenAiAudioTranscriptionModel openAiAudioTranscriptionModel(OpenAiAudioApi audioApi) {
127+
return new OpenAiAudioTranscriptionModel(audioApi);
128+
}
129+
130+
}
131+
132+
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionModelIT.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class OpenAiTranscriptionModelIT extends AbstractIT {
4040
private Resource audioFile;
4141

4242
@Test
43-
void transcriptionTest() {
43+
void callTest() {
4444
OpenAiAudioTranscriptionOptions transcriptionOptions = OpenAiAudioTranscriptionOptions.builder()
4545
.responseFormat(TranscriptResponseFormat.TEXT)
4646
.temperature(0f)
@@ -53,7 +53,7 @@ void transcriptionTest() {
5353
}
5454

5555
@Test
56-
void transcriptionTestWithOptions() {
56+
void callTestWithOptions() {
5757
OpenAiAudioApi.TranscriptResponseFormat responseFormat = OpenAiAudioApi.TranscriptResponseFormat.VTT;
5858

5959
OpenAiAudioTranscriptionOptions transcriptionOptions = OpenAiAudioTranscriptionOptions.builder()
@@ -69,4 +69,24 @@ void transcriptionTestWithOptions() {
6969
assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
7070
}
7171

72+
@Test
73+
void transcribeTest() {
74+
String response = this.transcriptionModel.transcribe(this.audioFile);
75+
assertThat(response).isNotNull();
76+
assertThat(response.toLowerCase().contains("fellow")).isTrue();
77+
}
78+
79+
@Test
80+
void transcribeTestWithOptions() {
81+
OpenAiAudioTranscriptionOptions transcriptionOptions = OpenAiAudioTranscriptionOptions.builder()
82+
.language("en")
83+
.prompt("Ask not this, but ask that")
84+
.temperature(0f)
85+
.responseFormat(TranscriptResponseFormat.TEXT)
86+
.build();
87+
String response = this.transcriptionModel.transcribe(this.audioFile, transcriptionOptions);
88+
assertThat(response).isNotNull();
89+
assertThat(response.toLowerCase().contains("fellow")).isTrue();
90+
}
91+
7292
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/TranscriptionModelTests.java

Lines changed: 0 additions & 86 deletions
This file was deleted.

models/spring-ai-openai/src/test/resources/speech.flac

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.audio.transcription;
18+
19+
import org.springframework.ai.model.Model;
20+
import org.springframework.core.io.Resource;
21+
22+
/**
23+
* A transcription model is a type of AI model that converts audio to text. This is also
24+
* known as Speech-to-Text.
25+
*
26+
* @author Mudabir Hussain
27+
* @since 1.0.0
28+
*/
29+
public interface TranscriptionModel extends Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
30+
31+
/**
32+
* Transcribes the audio from the given prompt.
33+
* @param transcriptionPrompt The prompt containing the audio resource and options.
34+
* @return The transcription response.
35+
*/
36+
AudioTranscriptionResponse call(AudioTranscriptionPrompt transcriptionPrompt);
37+
38+
/**
39+
* A convenience method for transcribing an audio resource.
40+
* @param resource The audio resource to transcribe.
41+
* @return The transcribed text.
42+
*/
43+
default String transcribe(Resource resource) {
44+
AudioTranscriptionPrompt prompt = new AudioTranscriptionPrompt(resource);
45+
return this.call(prompt).getResult().getOutput();
46+
}
47+
48+
/**
49+
* A convenience method for transcribing an audio resource with the given options.
50+
* @param resource The audio resource to transcribe.
51+
* @param options The transcription options.
52+
* @return The transcribed text.
53+
*/
54+
default String transcribe(Resource resource, AudioTranscriptionOptions options) {
55+
AudioTranscriptionPrompt prompt = new AudioTranscriptionPrompt(resource, options);
56+
return this.call(prompt).getResult().getOutput();
57+
}
58+
59+
}

0 commit comments

Comments
 (0)