From b9a0addd4e8a48a2f8a4e598cd531acb1f80b84d Mon Sep 17 00:00:00 2001 From: YunKui Lu Date: Wed, 18 Jun 2025 00:03:31 +0800 Subject: [PATCH] Fix infinite recursion in `getMimeType(Path)` method Signed-off-by: YunKui Lu --- .../ai/vertexai/gemini/MimeTypeDetector.java | 5 +- .../gemini/MimeTypeDetectorTests.java | 76 +++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/MimeTypeDetectorTests.java diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java index fe5e8e52e6e..52b341fec8e 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/MimeTypeDetector.java @@ -55,7 +55,8 @@ public abstract class MimeTypeDetector { /** * List of all MIME types supported by the Vertex Gemini API. */ - private static final Map GEMINI_MIME_TYPES = new HashMap<>(); + // exposed for testing purposes + static final Map GEMINI_MIME_TYPES = new HashMap<>(); public static MimeType getMimeType(URL url) { return getMimeType(url.getFile()); @@ -70,7 +71,7 @@ public static MimeType getMimeType(File file) { } public static MimeType getMimeType(Path path) { - return getMimeType(path.getFileName()); + return getMimeType(path.toUri()); } public static MimeType getMimeType(Resource resource) { diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/MimeTypeDetectorTests.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/MimeTypeDetectorTests.java new file mode 100644 index 00000000000..543c0fbd540 --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/MimeTypeDetectorTests.java @@ -0,0 +1,76 @@ +package org.springframework.ai.vertexai.gemini; + +import java.io.File; +import java.net.MalformedURLException; +import java.net.URI; +import java.nio.file.Path; +import java.util.stream.Stream; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.core.io.PathResource; +import org.springframework.util.MimeType; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.vertexai.gemini.MimeTypeDetector.GEMINI_MIME_TYPES; + +/** + * @author YunKui Lu + */ +class MimeTypeDetectorTests { + + private static Stream provideMimeTypes() { + return GEMINI_MIME_TYPES.entrySet().stream().map(entry -> Arguments.of(entry.getKey(), entry.getValue())); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByURLPath(String extension, MimeType expectedMimeType) throws MalformedURLException { + String path = "https://testhost/test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(URI.create(path).toURL()); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByURI(String extension, MimeType expectedMimeType) { + String path = "https://testhost/test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(URI.create(path)); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByFile(String extension, MimeType expectedMimeType) { + String path = "test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(new File(path)); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByPath(String extension, MimeType expectedMimeType) { + String path = "test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(Path.of(path)); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByResource(String extension, MimeType expectedMimeType) { + String path = "test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(new PathResource(path)); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + + @ParameterizedTest + @MethodSource("provideMimeTypes") + void getMimeTypeByString(String extension, MimeType expectedMimeType) { + String path = "test." + extension; + MimeType mimeType = MimeTypeDetector.getMimeType(path); + assertThat(mimeType).isEqualTo(expectedMimeType); + } + +}