Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,6 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vertex.api.VertexAiApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.aot.hint.TypeReference;
Expand All @@ -30,7 +17,9 @@
import java.util.stream.Collectors;

/***
* Native hints
* AOT hints (for GraalVM native images) for resources common to multiple modules across
* different dependencies. For integration-specific hints, see the respective auto
* configurations.
*
* @author Josh Long
*/
Expand All @@ -41,14 +30,13 @@ public class NativeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {

for (var h : Set.of(new BedrockAiHints(), new VertexAiHints(), new OpenAiHints(), new PdfReaderHints(),
new KnuddelsHints(), new OllamaHints()))
for (var h : Set.of(new PdfReaderHints(), new KnuddelsHints()))
h.registerHints(hints, classLoader);

hints.resources().registerResource(new ClassPathResource("embedding/embedding-model-dimensions.properties"));
}

private static Set<TypeReference> findJsonAnnotatedClasses(Class<?> packageClass) {
public static Set<TypeReference> findJsonAnnotatedClasses(Class<?> packageClass) {
var packageName = packageClass.getPackageName();
var classPathScanningCandidateComponentProvider = new ClassPathScanningCandidateComponentProvider(false);
classPathScanningCandidateComponentProvider.addIncludeFilter(new AnnotationTypeFilter(JsonInclude.class));
Expand All @@ -62,69 +50,16 @@ private static Set<TypeReference> findJsonAnnotatedClasses(Class<?> packageClass
.collect(Collectors.toUnmodifiableSet());
}

static class VertexAiHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClasses(VertexAiApi.class))
hints.reflection().registerType(tr, mcs);
}

}

static class OllamaHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClasses(OllamaApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(OllamaOptions.class))
hints.reflection().registerType(tr, mcs);
}

}

static class BedrockAiHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClasses(Ai21Jurassic2ChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(CohereChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(CohereEmbeddingBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(Llama2ChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(TitanChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(TitanEmbeddingBedrockApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(AnthropicChatBedrockApi.class))
hints.reflection().registerType(tr, mcs);
}

}

static class OpenAiHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClasses(OpenAiApi.class))
hints.reflection().registerType(tr, mcs);
}

}

static class KnuddelsHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
hints.resources().registerResource(new ClassPathResource("/com/knuddels/jtokkit/cl100k_base.tiktoken"));
try {
hints.resources().registerResource(new ClassPathResource("/com/knuddels/jtokkit/cl100k_base.tiktoken"));
}
catch (Exception e) {
throw new RuntimeException(e);
}
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.springframework.ai.autoconfigure.bedrock;

import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Set;

@Configuration
@ConditionalOnClass({ Ai21Jurassic2ChatBedrockApi.class, CohereChatBedrockApi.class, CohereEmbeddingBedrockApi.class,
Llama2ChatBedrockApi.class, TitanChatBedrockApi.class, TitanEmbeddingBedrockApi.class,
AnthropicChatBedrockApi.class })
class BedrockAotAutoConfiguration {

@Bean
static BedrockAiHints bedrockAiHints() {
return new BedrockAiHints();
}

static class BedrockAiHints implements BeanRegistrationAotProcessor {

@Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
return (generationContext, beanRegistrationCode) -> {
var hints = generationContext.getRuntimeHints();
var mcs = MemberCategory.values();
for (var c : Set.of(Ai21Jurassic2ChatBedrockApi.class, CohereChatBedrockApi.class,
CohereEmbeddingBedrockApi.class, Llama2ChatBedrockApi.class, TitanChatBedrockApi.class,
TitanEmbeddingBedrockApi.class, AnthropicChatBedrockApi.class))
hints.reflection().registerType(c, mcs);

};
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,27 @@
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.ollama.OllamaEmbeddingClient;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ImportRuntimeHints;

import static org.springframework.ai.autoconfigure.NativeHints.findJsonAnnotatedClasses;

/**
* {@link AutoConfiguration Auto-configuration} for Ollama Chat Client.
*
* @author Christian Tzolov
* @author Josh Long
* @since 0.8.0
*/
@AutoConfiguration
Expand All @@ -39,6 +49,28 @@
@ImportRuntimeHints(NativeHints.class)
public class OllamaAutoConfiguration {

@Bean
static OllamaHints ollamaHints() {
return new OllamaHints();
}

static class OllamaHints implements BeanRegistrationAotProcessor {

@Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {

return (generationContext, beanRegistrationCode) -> {
var mcs = MemberCategory.values();
var hints = generationContext.getRuntimeHints();
for (var tr : findJsonAnnotatedClasses(OllamaApi.class))
hints.reflection().registerType(tr, mcs);
for (var tr : findJsonAnnotatedClasses(OllamaOptions.class))
hints.reflection().registerType(tr, mcs);
};
}

}

@Bean
@ConditionalOnMissingBean
public OllamaApi ollamaApi(OllamaConnectionProperties properties) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiEmbeddingClient;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
Expand All @@ -31,13 +37,34 @@
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;

import static org.springframework.ai.autoconfigure.NativeHints.findJsonAnnotatedClasses;

@AutoConfiguration
@ConditionalOnClass(OpenAiApi.class)
@EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class,
OpenAiEmbeddingProperties.class })
@ImportRuntimeHints(NativeHints.class)
public class OpenAiAutoConfiguration {

@Bean
static OpenAiHints openAiHints() {
return new OpenAiHints();
}

static class OpenAiHints implements BeanRegistrationAotProcessor {

@Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
return (generationContext, beanRegistrationCode) -> {
var mcs = MemberCategory.values();
var hints = generationContext.getRuntimeHints();
for (var tr : findJsonAnnotatedClasses(OpenAiApi.class))
hints.reflection().registerType(tr, mcs);
};
}

}

@Bean
@ConditionalOnMissingBean
public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProperties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
package org.springframework.ai.autoconfigure.vertexai;

import org.springframework.ai.autoconfigure.NativeHints;
import org.springframework.ai.vertex.api.VertexAiApi;
import org.springframework.ai.vertex.VertexAiEmbeddingClient;
import org.springframework.ai.vertex.VertexAiChatClient;
import org.springframework.ai.vertex.VertexAiEmbeddingClient;
import org.springframework.ai.vertex.api.VertexAiApi;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
Expand All @@ -35,6 +39,25 @@
VertexAiEmbeddingProperties.class })
public class VertexAiAutoConfiguration {

@Bean
static VertexAiHints vertexAiHints() {
return new VertexAiHints();
}

static class VertexAiHints implements BeanRegistrationAotProcessor {

@Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
return (generationContext, beanRegistrationCode) -> {
var hints = generationContext.getRuntimeHints();
var mcs = MemberCategory.values();
for (var tr : NativeHints.findJsonAnnotatedClasses(VertexAiApi.class))
hints.reflection().registerType(tr, mcs);
};
}

}

@Bean
@ConditionalOnMissingBean
public VertexAiChatClient vertexAiChatClient(VertexAiApi vertexAiApi, VertexAiChatProperties chatProperties) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ org.springframework.ai.autoconfigure.vertexai.VertexAiAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.llama2.BedrockLlama2ChatAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereChatAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.cohere.BedrockCohereEmbeddingAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.BedrockAotAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.anthropic.BedrockAnthropicChatAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.titan.BedrockTitanChatAutoConfiguration
org.springframework.ai.autoconfigure.bedrock.titan.BedrockTitanEmbeddingAutoConfiguration
Expand Down