Skip to content

Commit a964887

Browse files
joshlongilayaperumalg
authored andcommitted
first cut of aot improvements
Signed-off-by: Josh Long <[email protected]>
1 parent 76ca66d commit a964887

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616

1717
package org.springframework.ai.openai.aot;
1818

19+
import org.springframework.ai.openai.OpenAiChatOptions;
20+
import org.springframework.ai.openai.api.OpenAiApi;
21+
import org.springframework.ai.openai.api.OpenAiAudioApi;
22+
import org.springframework.ai.openai.api.OpenAiImageApi;
1923
import org.springframework.aot.hint.MemberCategory;
2024
import org.springframework.aot.hint.RuntimeHints;
2125
import org.springframework.aot.hint.RuntimeHintsRegistrar;
@@ -37,10 +41,19 @@ public class OpenAiRuntimeHints implements RuntimeHintsRegistrar {
3741
@Override
3842
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
3943
var mcs = MemberCategory.values();
40-
41-
for (var tr : findJsonAnnotatedClassesInPackage("org.springframework.ai.openai")) {
44+
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiChatOptions.class))) {
45+
hints.reflection().registerType(tr, mcs);
46+
}
47+
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiApi.class))) {
48+
hints.reflection().registerType(tr, mcs);
49+
}
50+
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class))) {
51+
hints.reflection().registerType(tr, mcs);
52+
}
53+
for (var tr : findJsonAnnotatedClassesInPackage(OpenAiImageApi.class)) {
4254
hints.reflection().registerType(tr, mcs);
4355
}
56+
4457
}
4558

4659
}

spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@
1616

1717
package org.springframework.ai.embedding;
1818

19-
import java.io.IOException;
20-
import java.util.Map;
21-
import java.util.Properties;
22-
import java.util.concurrent.atomic.AtomicInteger;
23-
import java.util.stream.Collectors;
24-
2519
import org.springframework.aot.hint.RuntimeHints;
2620
import org.springframework.aot.hint.RuntimeHintsRegistrar;
2721
import org.springframework.context.annotation.ImportRuntimeHints;
2822
import org.springframework.core.io.ClassPathResource;
2923
import org.springframework.core.io.Resource;
3024
import org.springframework.util.Assert;
3125

26+
import java.io.IOException;
27+
import java.util.Map;
28+
import java.util.Properties;
29+
import java.util.concurrent.atomic.AtomicInteger;
30+
import java.util.stream.Collectors;
31+
3232
/**
3333
* Abstract implementation of the {@link EmbeddingModel} interface that provides
3434
* dimensions calculation caching.
@@ -44,6 +44,15 @@ public abstract class AbstractEmbeddingModel implements EmbeddingModel {
4444

4545
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions();
4646

47+
static class Hints implements RuntimeHintsRegistrar {
48+
49+
@Override
50+
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
51+
hints.resources().registerResource(EMBEDDING_MODEL_DIMENSIONS_PROPERTIES);
52+
}
53+
54+
}
55+
4756
/**
4857
* Cached embedding dimensions.
4958
*/
@@ -97,13 +106,4 @@ public int dimensions() {
97106
return this.embeddingDimensions.get();
98107
}
99108

100-
static class Hints implements RuntimeHintsRegistrar {
101-
102-
@Override
103-
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
104-
hints.resources().registerResource(EMBEDDING_MODEL_DIMENSIONS_PROPERTIES);
105-
}
106-
107-
}
108-
109109
}

0 commit comments

Comments
 (0)