3030import org .springframework .ai .openai .api .common .ApiUtils ;
3131import org .springframework .boot .context .properties .bind .ConstructorBinding ;
3232import org .springframework .core .ParameterizedTypeReference ;
33+ import org .springframework .http .MediaType ;
3334import org .springframework .http .ResponseEntity ;
3435import org .springframework .util .Assert ;
3536import org .springframework .util .CollectionUtils ;
37+ import org .springframework .util .MultiValueMap ;
3638import org .springframework .web .client .RestClient ;
3739import org .springframework .web .reactive .function .client .WebClient ;
3840
4244 * OpenAI Embedding API: https://platform.openai.com/docs/api-reference/embeddings.
4345 *
4446 * @author Christian Tzolov
47+ * @author Michael Lavelle
4548 */
4649public class OpenAiApi {
4750
@@ -50,6 +53,9 @@ public class OpenAiApi {
5053 private static final Predicate <String > SSE_DONE_PREDICATE = "[DONE]" ::equals ;
5154
5255 private final RestClient restClient ;
56+
57+ private final RestClient multipartRestClient ;
58+
5359 private final WebClient webClient ;
5460
5561 /**
@@ -86,6 +92,15 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie
8692 .defaultStatusHandler (ApiUtils .DEFAULT_RESPONSE_ERROR_HANDLER )
8793 .build ();
8894
95+ this .multipartRestClient = restClientBuilder
96+ .baseUrl (baseUrl )
97+ .defaultHeaders (multipartFormDataHeaders -> {
98+ multipartFormDataHeaders .setBearerAuth (openAiToken );
99+ multipartFormDataHeaders .setContentType (MediaType .MULTIPART_FORM_DATA );
100+ })
101+ .defaultStatusHandler (ApiUtils .DEFAULT_RESPONSE_ERROR_HANDLER )
102+ .build ();
103+
89104 this .webClient = WebClient .builder ()
90105 .baseUrl (baseUrl )
91106 .defaultHeaders (ApiUtils .getJsonContentHeaders (openAiToken ))
@@ -97,7 +112,7 @@ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClie
97112 * <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">GPT-4 and GPT-4 Turbo</a> and
98113 * <a href="https://platform.openai.com/docs/models/gpt-3-5-turbo">GPT-3.5 Turbo</a>.
99114 */
100- enum ChatModel {
115+ public enum ChatModel {
101116 /**
102117 * (New) GPT-4 Turbo - latest GPT-4 model intended to reduce cases
103118 * of “laziness” where the model doesn’t complete a task.
@@ -169,42 +184,6 @@ public String getValue() {
169184 }
170185 }
171186
172- /**
173- * OpenAI Embeddings Models:
174- * <a href="https://platform.openai.com/docs/models/embeddings">Embeddings</a>.
175- */
176- enum EmbeddingModel {
177-
178- /**
179- * Most capable embedding model for both english and non-english tasks.
180- * DIMENSION: 3072
181- */
182- TEXT_EMBEDDING_3_LARGE ("text-embedding-3-large" ),
183-
184- /**
185- * Increased performance over 2nd generation ada embedding model.
186- * DIMENSION: 1536
187- */
188- TEXT_EMBEDDING_3_SMALL ("text-embedding-3-small" ),
189-
190- /**
191- * Most capable 2nd generation embedding model, replacing 16 first
192- * generation models.
193- * DIMENSION: 1536
194- */
195- TEXT_EMBEDDING_ADA_002 ("text-embedding-ada-002" );
196-
197- public final String value ;
198-
199- EmbeddingModel (String value ) {
200- this .value = value ;
201- }
202-
203- public String getValue () {
204- return value ;
205- }
206- }
207-
208187 /**
209188 * Represents a tool the model may call. Currently, only functions are supported as a tool.
210189 *
@@ -708,6 +687,44 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
708687 .map (content -> ModelOptionsUtils .jsonToObject (content , ChatCompletionChunk .class ));
709688 }
710689
690+ // Embeddings API
691+
692+ /**
693+ * OpenAI Embeddings Models:
694+ * <a href="https://platform.openai.com/docs/models/embeddings">Embeddings</a>.
695+ */
696+ public enum EmbeddingModel {
697+
698+ /**
699+ * Most capable embedding model for both english and non-english tasks.
700+ * DIMENSION: 3072
701+ */
702+ TEXT_EMBEDDING_3_LARGE ("text-embedding-3-large" ),
703+
704+ /**
705+ * Increased performance over 2nd generation ada embedding model.
706+ * DIMENSION: 1536
707+ */
708+ TEXT_EMBEDDING_3_SMALL ("text-embedding-3-small" ),
709+
710+ /**
711+ * Most capable 2nd generation embedding model, replacing 16 first
712+ * generation models.
713+ * DIMENSION: 1536
714+ */
715+ TEXT_EMBEDDING_ADA_002 ("text-embedding-ada-002" );
716+
717+ public final String value ;
718+
719+ EmbeddingModel (String value ) {
720+ this .value = value ;
721+ }
722+
723+ public String getValue () {
724+ return value ;
725+ }
726+ }
727+
711728 /**
712729 * Represents an embedding vector returned by embedding endpoint.
713730 *
@@ -824,5 +841,87 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
824841 .toEntity (new ParameterizedTypeReference <>() {
825842 });
826843 }
844+
845+ // Transcription API
846+
847+ // @JsonInclude(Include.NON_NULL)
848+ // public record Transcription(
849+ // @JsonProperty("text") String text) {
850+ // }
851+
852+ // /**
853+ // *
854+ // * @param model ID of the model to use.
855+ // * @param language The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency.
856+ // * @param prompt An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language.
857+ // * @param responseFormat An object specifying the format that the model must output.
858+ // * @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
859+ // * more random, while lower values like 0.2 will make it more focused and deterministic. */
860+ // @JsonInclude(Include.NON_NULL)
861+ // public record TranscriptionRequest (
862+ // @JsonProperty("model") String model,
863+ // @JsonProperty("language") String language,
864+ // @JsonProperty("prompt") String prompt,
865+ // @JsonProperty("response_format") ResponseFormat responseFormat,
866+ // @JsonProperty("temperature") Float temperature) {
867+
868+ // /**
869+ // * Shortcut constructor for a transcription request with the given model and temperature
870+ // *
871+ // * @param model ID of the model to use.
872+ // * @param temperature What sampling temperature to use, between 0 and 1.
873+ // */
874+ // public TranscriptionRequest(String model, Float temperature) {
875+ // this(model, null, null, null, temperature);
876+ // }
877+
878+ // public TranscriptionRequest() {
879+ // this(null, null, null, null, null);
880+ // }
881+
882+ // /**
883+ // * An object specifying the format that the model must output.
884+ // * @param type Must be one of 'text' or 'json_object'.
885+ // */
886+ // @JsonInclude(Include.NON_NULL)
887+ // public record ResponseFormat(
888+ // @JsonProperty("type") String type) {
889+ // }
890+ // }
891+
892+ // /**
893+ // * Creates a model response for the given transcription.
894+ // *
895+ // * @param transcriptionRequest The transcription request.
896+ // * @return Entity response with {@link Transcription} as a body and HTTP status code and headers.
897+ // */
898+ // public ResponseEntity<Transcription> transcriptionEntityJson(MultiValueMap<String, Object> transcriptionRequest) {
899+
900+ // Assert.notNull(transcriptionRequest, "The request body can not be null.");
901+
902+ // return this.multipartRestClient.post()
903+ // .uri("/v1/audio/transcriptions")
904+ // .body(transcriptionRequest)
905+ // .retrieve()
906+ // .toEntity(Transcription.class);
907+ // }
908+
909+ // /**
910+ // * Creates a model response for the given transcription.
911+ // *
912+ // * @param transcriptionRequest The transcription request.
913+ // * @return Entity response with {@link String} as a body and HTTP status code and headers.
914+ // */
915+ // public ResponseEntity<String> transcriptionEntityText(MultiValueMap<String, Object> transcriptionRequest) {
916+
917+ // Assert.notNull(transcriptionRequest, "The request body can not be null.");
918+
919+ // return this.multipartRestClient.post()
920+ // .uri("/v1/audio/transcriptions")
921+ // .body(transcriptionRequest)
922+ // .accept(MediaType.TEXT_PLAIN)
923+ // .retrieve()
924+ // .toEntity(String.class);
925+ // }
827926}
828927// @formatter:on
0 commit comments