diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java index fe5e8e52e6e..52b341fec8e 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java @@ -55,7 +55,8 @@ public abstract class MimeTypeDetector { /** * List of all MIME types supported by the Vertex Gemini API. */ - private static final Map GEMINI_MIME_TYPES = new HashMap<>(); + // exposed for testing purposes + static final Map GEMINI_MIME_TYPES = new HashMap<>(); public static MimeType getMimeType(URL url) { return getMimeType(url.getFile()); @@ -70,7 +71,7 @@ public static MimeType getMimeType(File file) { } public static MimeType getMimeType(Path path) { - return getMimeType(path.getFileName()); + return getMimeType(path.toUri()); } public static MimeType getMimeType(Resource resource) { diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/MimeTypeDetectorTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/MimeTypeDetectorTests.java new file mode 100644 index 00000000000..543c0fbd540 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/MimeTypeDetectorTests.java @@ -0,0 +1,76 @@ +package org.springframework.ai.vertexai.gemini; + +import java.io.File; +import java.net.MalformedURLException; +import java.net.URI; +import java.nio.file.Path; +import java.util.stream.Stream; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.core.io.PathResource; +import org.springframework.util.MimeType; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.vertexai.gemini.MimeTypeDetector.GEMINI_MIME_TYPES; + +/** + * @author YunKui Lu + */ +class MimeTypeDetectorTests { + + private static Stream provideMimeTypes() { + return GEMINI_MIME_TYPES.entrySet().stream().map(entry -> Arguments.of(entry.getKey(), entry.getValue())); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByURLPath(String extension, MimeType expectedMimeType) throws MalformedURLException { + String path = "https://testhost/test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(URI.create(path).toURL()); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByURI(String extension, MimeType expectedMimeType) { + String path = "https://testhost/test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(URI.create(path)); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByFile(String extension, MimeType expectedMimeType) { + String path = "test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(new File(path)); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByPath(String extension, MimeType expectedMimeType) { + String path = "test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(Path.of(path)); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByResource(String extension, MimeType expectedMimeType) { + String path = "test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(new PathResource(path)); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByString(String extension, MimeType expectedMimeType) { + String path = "test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(path); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + +}