diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java new file mode 100644 index 00000000000..fa043db373b --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiFileApi.java @@ -0,0 +1,385 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.api; + +import java.util.List; +import java.util.function.Consumer; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.NoopApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.openai.api.common.OpenAiApiConstants; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.ResponseEntity; +import org.springframework.util.Assert; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.util.UriBuilder; + +/** + * OpenAI File API. + * + * @author Sun Yuhan + * @see Files API + */ +public class OpenAiFileApi { + + private final RestClient restClient; + + public OpenAiFileApi(String baseUrl, ApiKey apiKey, MultiValueMap headers, + RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { + Consumer authHeaders = h -> h.addAll(headers); + + this.restClient = restClientBuilder.clone() + .baseUrl(baseUrl) + .defaultHeaders(authHeaders) + .defaultStatusHandler(responseErrorHandler) + .defaultRequest(requestHeadersSpec -> { + if (!(apiKey instanceof NoopApiKey)) { + requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue()); + } + }) + .build(); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Upload a file that can be used across various endpoints + * @param uploadFileRequest The request body + * @return Response entity containing the file object + */ + public ResponseEntity uploadFile(UploadFileRequest uploadFileRequest) { + MultiValueMap multipartBody = new LinkedMultiValueMap<>(); + multipartBody.add("file", new ByteArrayResource(uploadFileRequest.file()) { + @Override + public String getFilename() { + return uploadFileRequest.fileName(); + } + }); + multipartBody.add("purpose", uploadFileRequest.purpose()); + + return this.restClient.post().uri("/v1/files").body(multipartBody).retrieve().toEntity(FileObject.class); + } + + /** + * Returns a list of files + * @param listFileRequest The request body + * @return Response entity containing the files + */ + public ResponseEntity listFiles(ListFileRequest listFileRequest) { + return this.restClient.get().uri(uriBuilder -> { + UriBuilder builder = uriBuilder.path("/v1/files"); + if (null != listFileRequest.after()) { + builder = builder.queryParam("after", listFileRequest.after()); + } + if (null != listFileRequest.limit()) { + builder = builder.queryParam("limit", listFileRequest.limit()); + } + if (null != listFileRequest.order()) { + builder = builder.queryParam("order", listFileRequest.order()); + } + if (null != listFileRequest.purpose()) { + builder = builder.queryParam("purpose", listFileRequest.purpose()); + } + return builder.build(); + }).retrieve().toEntity(FileObjectResponse.class); + } + + /** + * Returns information about a specific file + * @param fileId The file id + * @return Response entity containing the file object + */ + public ResponseEntity retrieveFile(String fileId) { + return this.restClient.get().uri("/v1/files/%s".formatted(fileId)).retrieve().toEntity(FileObject.class); + } + + /** + * Delete a file + * @param fileId The file id + * @return Response entity containing the deletion status + */ + public ResponseEntity deleteFile(String fileId) { + return this.restClient.delete() + .uri("/v1/files/%s".formatted(fileId)) + .retrieve() + .toEntity(DeleteFileResponse.class); + } + + /** + * Returns the contents of the specified file + * @param fileId The file id + * @return Response entity containing the file content + */ + public ResponseEntity retrieveFileContent(String fileId) { + return this.restClient.get().uri("/v1/files/%s/content".formatted(fileId)).retrieve().toEntity(String.class); + } + + /** + * The intended purpose of the uploaded file + */ + public enum Purpose { + + // @formatter:off + /** + * Used in the Assistants API + */ + @JsonProperty("assistants") + ASSISTANTS("assistants"), + /** + * Used in the Batch API + */ + @JsonProperty("batch") + BATCH("batch"), + /** + * Used for fine-tuning + */ + @JsonProperty("fine-tune") + FINE_TUNE("fine-tune"), + /** + * Images used for vision fine-tuning + */ + @JsonProperty("vision") + VISION("vision"), + /** + * Flexible file type for any purpose + */ + @JsonProperty("user_data") + USER_DATA("user_data"), + /** + * Used for eval data sets + */ + @JsonProperty("evals") + EVALS("evals"); + // @formatter:on + + private final String value; + + Purpose(String value) { + this.value = value; + } + + public String getValue() { + return this.value; + } + + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record UploadFileRequest( + // @formatter:off + @JsonProperty("file") byte[] file, + @JsonProperty("fileName") String fileName, + @JsonProperty("purpose") String purpose) { + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private byte[] file; + + private String fileName; + + private String purpose; + + public Builder file(byte[] file) { + this.file = file; + return this; + } + + public Builder fileName(String fileName) { + this.fileName = fileName; + return this; + } + + public Builder purpose(String purpose) { + this.purpose = purpose; + return this; + } + + public Builder purpose(Purpose purpose) { + this.purpose = purpose.getValue(); + return this; + } + + public UploadFileRequest build() { + Assert.notNull(this.file, "file must not be empty"); + Assert.notNull(this.fileName, "fileName must not be empty"); + Assert.notNull(this.purpose, "purpose must not be empty"); + + return new UploadFileRequest(this.file, this.fileName, this.purpose); + } + + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ListFileRequest( + // @formatter:off + @JsonProperty("after") String after, + @JsonProperty("limit") Integer limit, + @JsonProperty("order") String order, + @JsonProperty("purpose") String purpose) { + // @formatter:on + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String after; + + private Integer limit; + + private String order; + + private String purpose; + + public Builder after(String after) { + this.after = after; + return this; + } + + public Builder limit(Integer limit) { + this.limit = limit; + return this; + } + + public Builder order(String order) { + this.order = order; + return this; + } + + public Builder purpose(String purpose) { + this.purpose = purpose; + return this; + } + + public Builder purpose(Purpose purpose) { + this.purpose = purpose.getValue(); + return this; + } + + public ListFileRequest build() { + return new ListFileRequest(this.after, this.limit, this.order, this.purpose); + } + + } + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record FileObject( + // @formatter:off + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("bytes") Integer bytes, + @JsonProperty("created_at") Integer createdAt, + @JsonProperty("expires_at") Integer expiresAt, + @JsonProperty("filename") String filename, + @JsonProperty("purpose") String purpose) { + // @formatter:on + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record FileObjectResponse( + // @formatter:off + @JsonProperty("data") List data, + @JsonProperty("object") String object + // @formatter:on + ) { + } + + @JsonInclude(JsonInclude.Include.NON_NULL) + public record DeleteFileResponse( + // @formatter:off + @JsonProperty("id") String id, + @JsonProperty("object") String object, + @JsonProperty("deleted") Boolean deleted) { + // @formatter:on + } + + public static class Builder { + + private String baseUrl = OpenAiApiConstants.DEFAULT_BASE_URL; + + private ApiKey apiKey; + + private MultiValueMap headers = new LinkedMultiValueMap<>(); + + private RestClient.Builder restClientBuilder = RestClient.builder(); + + private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER; + + public Builder baseUrl(String baseUrl) { + Assert.hasText(baseUrl, "baseUrl cannot be null or empty"); + this.baseUrl = baseUrl; + return this; + } + + public Builder apiKey(ApiKey apiKey) { + Assert.notNull(apiKey, "apiKey cannot be null"); + this.apiKey = apiKey; + return this; + } + + public Builder apiKey(String simpleApiKey) { + Assert.notNull(simpleApiKey, "simpleApiKey cannot be null"); + this.apiKey = new SimpleApiKey(simpleApiKey); + return this; + } + + public Builder headers(MultiValueMap headers) { + Assert.notNull(headers, "headers cannot be null"); + this.headers = headers; + return this; + } + + public Builder restClientBuilder(RestClient.Builder restClientBuilder) { + Assert.notNull(restClientBuilder, "restClientBuilder cannot be null"); + this.restClientBuilder = restClientBuilder; + return this; + } + + public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { + Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null"); + this.responseErrorHandler = responseErrorHandler; + return this; + } + + public OpenAiFileApi build() { + Assert.notNull(this.apiKey, "apiKey must be set"); + return new OpenAiFileApi(this.baseUrl, this.apiKey, this.headers, this.restClientBuilder, + this.responseErrorHandler); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiBuilderTests.java new file mode 100644 index 00000000000..143fd9eaa68 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiBuilderTests.java @@ -0,0 +1,177 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.api; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link OpenAiFileApi}. + * + * @author Sun Yuhan + */ +class OpenAiFileApiBuilderTests { + + private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); + + private static final String TEST_BASE_URL = "https://test.openai.com"; + + @Test + void testMinimalBuilder() { + OpenAiFileApi api = OpenAiFileApi.builder().apiKey(TEST_API_KEY).build(); + + assertThat(api).isNotNull(); + } + + @Test + void testFullBuilder() { + MultiValueMap headers = new LinkedMultiValueMap<>(); + headers.add("Custom-Header", "test-value"); + RestClient.Builder restClientBuilder = RestClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + OpenAiFileApi api = OpenAiFileApi.builder() + .baseUrl(TEST_BASE_URL) + .apiKey(TEST_API_KEY) + .headers(headers) + .restClientBuilder(restClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testMissingApiKey() { + assertThatThrownBy(() -> OpenAiFileApi.builder().build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("apiKey must be set"); + } + + @Test + void testInvalidBaseUrl() { + assertThatThrownBy(() -> OpenAiFileApi.builder().baseUrl("").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + + assertThatThrownBy(() -> OpenAiFileApi.builder().baseUrl(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be null or empty"); + } + + @Test + void testInvalidHeaders() { + assertThatThrownBy(() -> OpenAiFileApi.builder().headers(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("headers cannot be null"); + } + + @Test + void testInvalidRestClientBuilder() { + assertThatThrownBy(() -> OpenAiFileApi.builder().restClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("restClientBuilder cannot be null"); + } + + @Test + void testInvalidResponseErrorHandler() { + assertThatThrownBy(() -> OpenAiFileApi.builder().responseErrorHandler(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("responseErrorHandler cannot be null"); + } + + @Nested + class MockRequests { + + MockWebServer mockWebServer; + + @BeforeEach + void setUp() throws IOException { + this.mockWebServer = new MockWebServer(); + this.mockWebServer.start(); + } + + @AfterEach + void tearDown() throws IOException { + this.mockWebServer.shutdown(); + } + + @Test + void dynamicApiKeyRestClient() throws InterruptedException { + Queue apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2"))); + OpenAiFileApi api = OpenAiFileApi.builder() + .apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue()) + .baseUrl(this.mockWebServer.url("/").toString()) + .build(); + + MockResponse mockResponse = new MockResponse().setResponseCode(200) + .addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(""" + { + "id": "file-abc123", + "object": "file", + "bytes": 120000, + "created_at": 1677610602, + "filename": "mydata.jsonl", + "purpose": "fine-tune" + } + """); + this.mockWebServer.enqueue(mockResponse); + this.mockWebServer.enqueue(mockResponse); + + OpenAiFileApi.UploadFileRequest request = new OpenAiFileApi.UploadFileRequest(new byte[] {}, "mydata.jsonl", + OpenAiFileApi.Purpose.USER_DATA.getValue()); + ResponseEntity response = api.uploadFile(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + RecordedRequest recordedRequest = this.mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1"); + + response = api.uploadFile(request); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + recordedRequest = this.mockWebServer.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2"); + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiIT.java new file mode 100644 index 00000000000..79e4ad4711b --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiFileApiIT.java @@ -0,0 +1,102 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.api; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.UUID; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.core.io.ByteArrayResource; +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OpenAiFileApi}. + * + * @author Sun Yuhan + */ +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +class OpenAiFileApiIT { + + OpenAiFileApi fileApi = OpenAiFileApi.builder().apiKey(new SimpleApiKey(System.getenv("OPENAI_API_KEY"))).build(); + + @Test + void testOperationFileCompleteProcess() throws IOException { + String fileContent = "{\"key\":\"value\"}"; + Resource resource = new ByteArrayResource(fileContent.getBytes(StandardCharsets.UTF_8)); + String fileName = "test%s.jsonl".formatted(UUID.randomUUID().toString()); + OpenAiFileApi.Purpose purpose = OpenAiFileApi.Purpose.EVALS; + + // upload file + OpenAiFileApi.FileObject fileObject; + fileObject = this.fileApi + .uploadFile(OpenAiFileApi.UploadFileRequest.builder() + .file(toBytes(resource)) + .fileName(fileName) + .purpose(purpose) + .build()) + .getBody(); + + assertThat(fileObject).isNotNull(); + assertThat(fileObject.filename()).isEqualTo(fileName); + assertThat(fileObject.purpose()).isEqualTo(purpose.getValue()); + assertThat(fileObject.id()).isNotEmpty(); + + // list files + OpenAiFileApi.FileObjectResponse listFileResponse = this.fileApi + .listFiles(OpenAiFileApi.ListFileRequest.builder().purpose(purpose).build()) + .getBody(); + + assertThat(listFileResponse).isNotNull(); + assertThat(listFileResponse.data()).isNotEmpty(); + assertThat(listFileResponse.data().stream().map(OpenAiFileApi.FileObject::filename).toList()) + .contains(fileName); + + // retrieve file + OpenAiFileApi.FileObject object = this.fileApi.retrieveFile(fileObject.id()).getBody(); + + assertThat(object).isNotNull(); + assertThat(object.filename()).isEqualTo(fileName); + + // retrieve file content + String retrieveFileContent = this.fileApi.retrieveFileContent(fileObject.id()).getBody(); + + assertThat(retrieveFileContent).isNotNull(); + assertThat(retrieveFileContent).isEqualTo(fileContent); + + // delete file + OpenAiFileApi.DeleteFileResponse deleteResponse = this.fileApi.deleteFile(fileObject.id()).getBody(); + + assertThat(deleteResponse).isNotNull(); + assertThat(deleteResponse.deleted()).isEqualTo(true); + } + + private byte[] toBytes(Resource resource) { + try { + return resource.getInputStream().readAllBytes(); + } + catch (Exception e) { + throw new IllegalArgumentException("Failed to read resource: " + resource, e); + } + } + +}