diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/pom.xml b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/pom.xml new file mode 100644 index 00000000000..cd285955d59 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/pom.xml @@ -0,0 +1,91 @@ + + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 1.1.0-SNAPSHOT + ../../../pom.xml + + spring-ai-autoconfigure-model-replicate + jar + Spring AI Replicate Auto Configuration + Spring AI Replicate Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + + org.springframework.ai + spring-ai-replicate + ${project.parent.version} + + + + + + org.springframework.ai + spring-ai-autoconfigure-retry + ${project.parent.version} + + + + + org.springframework.boot + spring-boot-starter + true + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + org.springframework.boot + spring-boot-autoconfigure-processor + true + + + + + org.springframework.ai + spring-ai-test + ${project.parent.version} + test + + + + org.springframework.boot + spring-boot-starter-test + test + + + + diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatAutoConfiguration.java new file mode 100644 index 00000000000..25a1d7e652f --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatAutoConfiguration.java @@ -0,0 +1,122 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.replicate.autoconfigure; + +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.replicate.ReplicateChatModel; +import org.springframework.ai.replicate.ReplicateMediaModel; +import org.springframework.ai.replicate.ReplicateStringModel; +import org.springframework.ai.replicate.ReplicateStructuredModel; +import org.springframework.ai.replicate.api.ReplicateApi; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; + +/** + * {@link AutoConfiguration Auto-configuration} for Replicate models. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +@AutoConfiguration(after = RestClientAutoConfiguration.class) +@ConditionalOnClass(ReplicateApi.class) +@EnableConfigurationProperties({ ReplicateConnectionProperties.class, ReplicateChatProperties.class, + ReplicateMediaProperties.class, ReplicateStringProperties.class, ReplicateStructuredProperties.class }) +@ConditionalOnProperty(prefix = ReplicateConnectionProperties.CONFIG_PREFIX, name = "api-token") +public class ReplicateChatAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnProperty(prefix = ReplicateConnectionProperties.CONFIG_PREFIX, name = "api-token") + public ReplicateApi replicateApi(ReplicateConnectionProperties connectionProperties, + ObjectProvider restClientBuilderProvider, + ObjectProvider responseErrorHandlerProvider) { + + if (!StringUtils.hasText(connectionProperties.getApiToken())) { + throw new IllegalArgumentException( + "Replicate API token must be configured via spring.ai.replicate.api-token"); + } + + var builder = ReplicateApi.builder() + .apiKey(connectionProperties.getApiToken()) + .baseUrl(connectionProperties.getBaseUrl()); + + RestClient.Builder restClientBuilder = restClientBuilderProvider.getIfAvailable(RestClient::builder); + if (restClientBuilder != null) { + builder.restClientBuilder(restClientBuilder); + } + + ResponseErrorHandler errorHandler = responseErrorHandlerProvider.getIfAvailable(); + if (errorHandler != null) { + builder.responseErrorHandler(errorHandler); + } + + return builder.build(); + } + + @Bean + @ConditionalOnMissingBean + public ReplicateChatModel replicateChatModel(ReplicateApi replicateApi, ReplicateChatProperties chatProperties, + ObjectProvider observationRegistry) { + return ReplicateChatModel.builder() + .replicateApi(replicateApi) + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .defaultOptions(chatProperties.getOptions()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + public ReplicateMediaModel replicateMediaModel(ReplicateApi replicateApi, + ReplicateMediaProperties mediaProperties) { + return ReplicateMediaModel.builder() + .replicateApi(replicateApi) + .defaultOptions(mediaProperties.getOptions()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + public ReplicateStringModel replicateStringModel(ReplicateApi replicateApi, + ReplicateStringProperties stringProperties) { + return ReplicateStringModel.builder() + .replicateApi(replicateApi) + .defaultOptions(stringProperties.getOptions()) + .build(); + } + + @Bean + @ConditionalOnMissingBean + public ReplicateStructuredModel replicateStructuredModel(ReplicateApi replicateApi, + ReplicateStructuredProperties structuredProperties) { + + return ReplicateStructuredModel.builder() + .replicateApi(replicateApi) + .defaultOptions(structuredProperties.getOptions()) + .build(); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatProperties.java new file mode 100644 index 00000000000..3acccf2be4e --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatProperties.java @@ -0,0 +1,45 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.replicate.autoconfigure; + +import org.springframework.ai.replicate.ReplicateChatOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Chat properties for Replicate AI. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +@ConfigurationProperties(ReplicateChatProperties.CONFIG_PREFIX) +public class ReplicateChatProperties { + + public static final String CONFIG_PREFIX = "spring.ai.replicate.chat"; + + @NestedConfigurationProperty + private ReplicateChatOptions options = ReplicateChatOptions.builder().build(); + + public ReplicateChatOptions getOptions() { + return this.options; + } + + public void setOptions(ReplicateChatOptions options) { + this.options = options; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateConnectionProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateConnectionProperties.java new file mode 100644 index 00000000000..deaed7914d8 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateConnectionProperties.java @@ -0,0 +1,54 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.replicate.autoconfigure; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Connection properties for Replicate AI. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +@ConfigurationProperties(ReplicateConnectionProperties.CONFIG_PREFIX) +public class ReplicateConnectionProperties { + + public static final String CONFIG_PREFIX = "spring.ai.replicate"; + + public static final String DEFAULT_BASE_URL = "https://api.replicate.com/v1"; + + private String apiToken; + + private String baseUrl = DEFAULT_BASE_URL; + + public String getApiToken() { + return this.apiToken; + } + + public void setApiToken(String apiToken) { + this.apiToken = apiToken; + } + + public String getBaseUrl() { + return this.baseUrl; + } + + public void setBaseUrl(String baseUrl) { + this.baseUrl = baseUrl; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateMediaProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateMediaProperties.java new file mode 100644 index 00000000000..4a606428401 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateMediaProperties.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.replicate.autoconfigure; + +import org.springframework.ai.replicate.ReplicateOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Media model properties for Replicate AI. Used for image, video, and audio generation + * models. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +@ConfigurationProperties(ReplicateMediaProperties.CONFIG_PREFIX) +public class ReplicateMediaProperties { + + public static final String CONFIG_PREFIX = "spring.ai.replicate.media"; + + @NestedConfigurationProperty + private ReplicateOptions options = ReplicateOptions.builder().build(); + + public ReplicateOptions getOptions() { + return this.options; + } + + public void setOptions(ReplicateOptions options) { + this.options = options; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateStringProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateStringProperties.java new file mode 100644 index 00000000000..416ee1f816f --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateStringProperties.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.replicate.autoconfigure; + +import org.springframework.ai.replicate.ReplicateOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * String model properties for Replicate AI. Used for models that return simple string + * outputs like classifiers and filters. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +@ConfigurationProperties(ReplicateStringProperties.CONFIG_PREFIX) +public class ReplicateStringProperties { + + public static final String CONFIG_PREFIX = "spring.ai.replicate.string"; + + @NestedConfigurationProperty + private ReplicateOptions options = ReplicateOptions.builder().build(); + + public ReplicateOptions getOptions() { + return this.options; + } + + public void setOptions(ReplicateOptions options) { + this.options = options; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateStructuredProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateStructuredProperties.java new file mode 100644 index 00000000000..8eb3f84df6f --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateStructuredProperties.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.replicate.autoconfigure; + +import org.springframework.ai.replicate.ReplicateOptions; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Structured model properties for Replicate AI. Used for models that return structured + * JSON objects with multiple fields. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +@ConfigurationProperties(ReplicateStructuredProperties.CONFIG_PREFIX) +public class ReplicateStructuredProperties { + + public static final String CONFIG_PREFIX = "spring.ai.replicate.structured"; + + @NestedConfigurationProperty + private ReplicateOptions options = ReplicateOptions.builder().build(); + + public ReplicateOptions getOptions() { + return this.options; + } + + public void setOptions(ReplicateOptions options) { + this.options = options; + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000000..bde6423d44e --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1 @@ +org.springframework.ai.model.replicate.autoconfigure.ReplicateChatAutoConfiguration diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/test/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatAutoConfigurationIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/test/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatAutoConfigurationIT.java new file mode 100644 index 00000000000..2c25003f811 --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/test/java/org/springframework/ai/model/replicate/autoconfigure/ReplicateChatAutoConfigurationIT.java @@ -0,0 +1,140 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.replicate.autoconfigure; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.replicate.ReplicateChatModel; +import org.springframework.ai.replicate.ReplicateMediaModel; +import org.springframework.ai.replicate.ReplicateStringModel; +import org.springframework.ai.replicate.ReplicateStructuredModel; +import org.springframework.ai.replicate.api.ReplicateApi; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link ReplicateChatAutoConfiguration}. + * + * @author Rene Maierhofer + */ +@EnabledIfEnvironmentVariable(named = "REPLICATE_API_TOKEN", matches = ".+") +class ReplicateChatAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.replicate.api-token=" + System.getenv("REPLICATE_API_TOKEN")) + .withConfiguration( + AutoConfigurations.of(RestClientAutoConfiguration.class, ReplicateChatAutoConfiguration.class)); + + @Test + void testReplicateApiBean() { + this.contextRunner.run(context -> { + assertThat(context).hasSingleBean(ReplicateApi.class); + ReplicateApi api = context.getBean(ReplicateApi.class); + assertThat(api).isNotNull(); + }); + } + + @Test + void testReplicateChatModelBean() { + this.contextRunner.withPropertyValues("spring.ai.replicate.chat.options.model=meta/meta-llama-3-8b-instruct") + .run(context -> { + assertThat(context).hasSingleBean(ReplicateChatModel.class); + ReplicateChatModel chatModel = context.getBean(ReplicateChatModel.class); + assertThat(chatModel).isNotNull(); + + String response = chatModel.call("Say hello"); + assertThat(response).isNotEmpty(); + }); + } + + @Test + void testReplicateMediaModelBean() { + this.contextRunner.withPropertyValues("spring.ai.replicate.media.options.model=black-forest-labs/flux-schnell") + .run(context -> { + assertThat(context).hasSingleBean(ReplicateMediaModel.class); + ReplicateMediaModel mediaModel = context.getBean(ReplicateMediaModel.class); + assertThat(mediaModel).isNotNull(); + }); + } + + @Test + void testReplicateStringModelBean() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.string.options.model=falcons-ai/nsfw_image_detection") + .run(context -> { + assertThat(context).hasSingleBean(ReplicateStringModel.class); + ReplicateStringModel stringModel = context.getBean(ReplicateStringModel.class); + assertThat(stringModel).isNotNull(); + }); + } + + @Test + void testReplicateStructuredModelBean() { + this.contextRunner.withPropertyValues("spring.ai.replicate.structured.options.model=openai/clip") + .run(context -> { + assertThat(context).hasSingleBean(ReplicateStructuredModel.class); + ReplicateStructuredModel structuredModel = context.getBean(ReplicateStructuredModel.class); + assertThat(structuredModel).isNotNull(); + }); + } + + @Test + void testAllModelBeansCreated() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.chat.options.model=meta/meta-llama-3-8b-instruct", + "spring.ai.replicate.media.options.model=black-forest-labs/flux-schnell", + "spring.ai.replicate.string.options.model=falcons-ai/nsfw_image_detection", + "spring.ai.replicate.structured.options.model=openai/clip") + .run(context -> { + assertThat(context).hasSingleBean(ReplicateApi.class); + assertThat(context).hasSingleBean(ReplicateChatModel.class); + assertThat(context).hasSingleBean(ReplicateMediaModel.class); + assertThat(context).hasSingleBean(ReplicateStringModel.class); + assertThat(context).hasSingleBean(ReplicateStructuredModel.class); + }); + } + + @Test + void testChatInputParameters() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.chat.options.model=meta/meta-llama-3-8b-instruct", + "spring.ai.replicate.chat.options.input.temperature=0.7", + "spring.ai.replicate.chat.options.input.max_tokens=50") + .run(context -> { + ReplicateChatModel chatModel = context.getBean(ReplicateChatModel.class); + assertThat(chatModel).isNotNull(); + Prompt prompt = new Prompt("Write a very long poem."); + ChatResponse response = chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResults()).isNotEmpty(); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + ChatResponseMetadata metadata = response.getMetadata(); + Usage usage = metadata.getUsage(); + assertThat(usage.getCompletionTokens()).isLessThanOrEqualTo(50); + }); + } + +} diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/test/java/org/springframework/ai/model/replicate/autoconfigure/ReplicatePropertiesTests.java b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/test/java/org/springframework/ai/model/replicate/autoconfigure/ReplicatePropertiesTests.java new file mode 100644 index 00000000000..db2e0dcba0b --- /dev/null +++ b/auto-configurations/models/spring-ai-autoconfigure-model-replicate/src/test/java/org/springframework/ai/model/replicate/autoconfigure/ReplicatePropertiesTests.java @@ -0,0 +1,138 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.replicate.autoconfigure; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.replicate.ReplicateChatOptions; +import org.springframework.ai.replicate.ReplicateOptions; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for Replicate configuration properties. + * + * @author Rene Maierhofer + */ +class ReplicatePropertiesTests { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(ReplicateChatAutoConfiguration.class)); + + @Test + void testConnectionPropertiesBinding() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.api-token=test-token", + "spring.ai.replicate.base-url=https://127.0.0.1/v1") + .run(context -> { + ReplicateConnectionProperties properties = context.getBean(ReplicateConnectionProperties.class); + assertThat(properties.getApiToken()).isEqualTo("test-token"); + assertThat(properties.getBaseUrl()).isEqualTo("https://127.0.0.1/v1"); + }); + } + + @Test + void testConnectionPropertiesDefaults() { + this.contextRunner.withPropertyValues("spring.ai.replicate.api-token=test-token").run(context -> { + ReplicateConnectionProperties properties = context.getBean(ReplicateConnectionProperties.class); + assertThat(properties.getBaseUrl()).isEqualTo(ReplicateConnectionProperties.DEFAULT_BASE_URL); + }); + } + + @Test + void testChatPropertiesWithInputParameters() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.api-token=test-token", + "spring.ai.replicate.chat.options.model=meta/meta-llama-3-8b-instruct", + "spring.ai.replicate.chat.options.input.temperature=0.7", + "spring.ai.replicate.chat.options.input.max_tokens=100", + "spring.ai.replicate.chat.options.input.enabled=true") + .run(context -> { + ReplicateChatProperties properties = context.getBean(ReplicateChatProperties.class); + ReplicateChatOptions options = properties.getOptions(); + + assertThat(options.getInput()).isNotEmpty(); + assertThat(options.getInput().get("temperature")).isInstanceOf(Double.class).isEqualTo(0.7); + assertThat(options.getInput().get("max_tokens")).isInstanceOf(Integer.class).isEqualTo(100); + assertThat(options.getInput().get("enabled")).isInstanceOf(Boolean.class).isEqualTo(true); + }); + } + + @Test + void testMediaPropertiesBinding() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.api-token=test-token", + "spring.ai.replicate.media.options.model=black-forest-labs/flux-schnell", + "spring.ai.replicate.media.options.version=media-version") + .run(context -> { + ReplicateMediaProperties properties = context.getBean(ReplicateMediaProperties.class); + ReplicateOptions options = properties.getOptions(); + + assertThat(options).isNotNull(); + assertThat(options.getModel()).isEqualTo("black-forest-labs/flux-schnell"); + assertThat(options.getVersion()).isEqualTo("media-version"); + }); + } + + @Test + void testMediaPropertiesWithInputParameters() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.api-token=test-token", + "spring.ai.replicate.media.options.model=black-forest-labs/flux-schnell", + "spring.ai.replicate.media.options.input.prompt=test prompt", + "spring.ai.replicate.media.options.input.num_outputs=2") + .run(context -> { + ReplicateMediaProperties properties = context.getBean(ReplicateMediaProperties.class); + ReplicateOptions options = properties.getOptions(); + + assertThat(options.getInput()).isNotEmpty(); + assertThat(options.getInput().get("prompt")).isEqualTo("test prompt"); + assertThat(options.getInput().get("num_outputs")).isInstanceOf(Integer.class).isEqualTo(2); + }); + } + + @Test + void testStringPropertiesBinding() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.api-token=test-token", + "spring.ai.replicate.string.options.model=falcons-ai/nsfw_image_detection") + .run(context -> { + ReplicateStringProperties properties = context.getBean(ReplicateStringProperties.class); + ReplicateOptions options = properties.getOptions(); + + assertThat(options).isNotNull(); + assertThat(options.getModel()).isEqualTo("falcons-ai/nsfw_image_detection"); + }); + } + + @Test + void testStructuredPropertiesBinding() { + this.contextRunner + .withPropertyValues("spring.ai.replicate.api-token=test-token", + "spring.ai.replicate.structured.options.model=openai/clip") + .run(context -> { + ReplicateStructuredProperties properties = context.getBean(ReplicateStructuredProperties.class); + ReplicateOptions options = properties.getOptions(); + + assertThat(options).isNotNull(); + assertThat(options.getModel()).isEqualTo("openai/clip"); + }); + } + +} diff --git a/models/spring-ai-replicate/README.md b/models/spring-ai-replicate/README.md new file mode 100644 index 00000000000..e9f5ebc9bda --- /dev/null +++ b/models/spring-ai-replicate/README.md @@ -0,0 +1,4 @@ +# Spring AI Replicate +Spring AI integration for [Replicate](https://replicate.com/) + +Replicate provides access to various models through a unified API. diff --git a/models/spring-ai-replicate/pom.xml b/models/spring-ai-replicate/pom.xml new file mode 100644 index 00000000000..34a4a60de5c --- /dev/null +++ b/models/spring-ai-replicate/pom.xml @@ -0,0 +1,85 @@ + + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + spring-ai-replicate + jar + Spring AI Model - Replicate + Replicate AI models support + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + + org.springframework.ai + spring-ai-model + ${project.parent.version} + + + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + org.springframework + spring-context-support + + + + org.springframework + spring-webflux + + + + org.slf4j + slf4j-api + + + + + org.springframework.ai + spring-ai-test + ${project.version} + test + + + + io.micrometer + micrometer-observation-test + test + + + + + diff --git a/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateChatModel.java b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateChatModel.java new file mode 100644 index 00000000000..017846a7f2a --- /dev/null +++ b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateChatModel.java @@ -0,0 +1,313 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.MessageAggregator; +import org.springframework.ai.chat.observation.ChatModelObservationContext; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.chat.observation.ChatModelObservationDocumentation; +import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.replicate.api.ReplicateApi; +import org.springframework.ai.replicate.api.ReplicateApi.PredictionRequest; +import org.springframework.util.Assert; + +/** + * Replicate Chat Model implementation. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +public class ReplicateChatModel implements ChatModel { + + private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final ReplicateApi replicateApi; + + private final ObservationRegistry observationRegistry; + + private final ReplicateChatOptions defaultOptions; + + private final ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + + public ReplicateChatModel(ReplicateApi replicateApi, ObservationRegistry observationRegistry, + ReplicateChatOptions defaultOptions) { + Assert.notNull(replicateApi, "replicateApi must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + this.replicateApi = replicateApi; + this.observationRegistry = observationRegistry; + this.defaultOptions = defaultOptions; + } + + @Override + public ChatResponse call(Prompt prompt) { + return this.internalCall(prompt); + } + + @Override + public Flux stream(Prompt prompt) { + return this.internalStream(prompt); + } + + private ChatResponse internalCall(Prompt prompt) { + // Replicate does not support conversation history. + assert prompt.getUserMessages().size() == 1; + ReplicateChatOptions promptOptions = (ReplicateChatOptions) prompt.getOptions(); + ReplicateChatOptions requestOptions = mergeOptions(promptOptions); + PredictionRequest request = createRequestWithOptions(prompt, requestOptions, false); + + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(ReplicateApi.PROVIDER_NAME) + .build(); + + return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + ReplicateApi.PredictionResponse predictionResponse = this.replicateApi + .createPredictionAndWait(requestOptions.getModel(), request); + + if (predictionResponse == null) { + logger.warn("No prediction response returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + Map metadata = buildMetadataMap(predictionResponse); + + String content = extractContentFromOutput(predictionResponse.output()); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content(content) + .properties(metadata) + .build(); + Generation generation = new Generation(assistantMessage); + DefaultUsage usage = getDefaultUsage(predictionResponse.metrics()); + ChatResponse chatResponse = new ChatResponse(List.of(generation), from(predictionResponse, usage)); + observationContext.setResponse(chatResponse); + + return chatResponse; + }); + } + + private static ChatResponseMetadata from(ReplicateApi.PredictionResponse result, Usage usage) { + return ChatResponseMetadata.builder() + .id(result.id()) + .model(result.model()) + .usage(usage) + .keyValue("created", result.createdAt()) + .keyValue("version", result.version()) + .build(); + } + + private static DefaultUsage getDefaultUsage(ReplicateApi.Metrics metrics) { + if (metrics == null) { + return new DefaultUsage(0, 0); + } + Integer inputTokens = metrics.inputTokenCount() != null ? metrics.inputTokenCount() : 0; + Integer outputTokens = metrics.outputTokenCount() != null ? metrics.outputTokenCount() : 0; + return new DefaultUsage(inputTokens, outputTokens); + } + + private Flux internalStream(Prompt prompt) { + return Flux.deferContextual(contextView -> { + ReplicateChatOptions promptOptions = (ReplicateChatOptions) prompt.getOptions(); + ReplicateChatOptions requestOptions = mergeOptions(promptOptions); + PredictionRequest request = createRequestWithOptions(prompt, requestOptions, true); + + ChatModelObservationContext observationContext = ChatModelObservationContext.builder() + .prompt(prompt) + .provider(ReplicateApi.PROVIDER_NAME) + .build(); + + Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( + this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry); + + observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); + + Flux responseStream = this.replicateApi + .createPredictionStream(requestOptions.getModel(), request); + + Flux chatResponseFlux = responseStream.map(chunk -> { + String content = extractContentFromOutput(chunk.output()); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content(content) + .properties(buildMetadataMap(chunk)) + .build(); + + Generation generation = new Generation(assistantMessage); + DefaultUsage usage = getDefaultUsage(chunk.metrics()); + return new ChatResponse(List.of(generation), from(chunk, usage)); + }); + + // @formatter:off + return new MessageAggregator() + .aggregate(chatResponseFlux, observationContext::setResponse) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:on + }); + } + + /** + * Merges default options from properties with prompt options. Prompt options take + * precedence + * @param promptOptions Options from the current Prompt + * @return merged Options + */ + private ReplicateChatOptions mergeOptions(ReplicateChatOptions promptOptions) { + if (this.defaultOptions == null) { + return promptOptions != null ? promptOptions : ReplicateChatOptions.builder().build(); + } + if (promptOptions == null) { + return this.defaultOptions; + } + ReplicateChatOptions merged = ReplicateChatOptions.fromOptions(this.defaultOptions); + if (promptOptions.getModel() != null) { + merged.setModel(promptOptions.getModel()); + } + if (promptOptions.getVersion() != null) { + merged.setVersion(promptOptions.getVersion()); + } + if (promptOptions.getWebhook() != null) { + merged.setWebhook(promptOptions.getWebhook()); + } + if (promptOptions.getWebhookEventsFilter() != null) { + merged.setWebhookEventsFilter(promptOptions.getWebhookEventsFilter()); + } + Map mergedInput = new HashMap<>(); + if (this.defaultOptions.getInput() != null) { + mergedInput.putAll(this.defaultOptions.getInput()); + } + if (promptOptions.getInput() != null) { + mergedInput.putAll(promptOptions.getInput()); + } + merged.setInput(mergedInput); + + return merged; + } + + private PredictionRequest createRequestWithOptions(Prompt prompt, ReplicateChatOptions requestOptions, + boolean stream) { + Map input = new HashMap<>(); + if (requestOptions.getInput() != null) { + input.putAll(requestOptions.getInput()); + } + input.put("prompt", prompt.getUserMessage().getText()); + return new PredictionRequest(requestOptions.getVersion(), input, requestOptions.getWebhook(), + requestOptions.getWebhookEventsFilter(), stream); + } + + private Map buildMetadataMap(ReplicateApi.PredictionResponse response) { + Map metadata = new HashMap<>(); + if (response.id() != null) { + metadata.put("id", response.id()); + } + if (response.urls() != null) { + metadata.put("urls", response.urls()); + } + if (response.error() != null) { + metadata.put("error", response.error()); + } + if (response.logs() != null) { + metadata.put("logs", response.logs()); + } + return metadata; + } + + /** + * Extracts content from the output object. The output can be either a String or a + * List of Strings, depending on the model being used. + * @param output The output object from the prediction response + * @return The extracted content as a String, or empty string if null + */ + private static String extractContentFromOutput(Object output) { + if (output == null) { + return ""; + } + if (output instanceof String stringOutput) { + return stringOutput; + } + if (output instanceof List outputList) { + if (outputList.isEmpty()) { + return ""; + } + return outputList.stream().map(Object::toString).reduce("", (a, b) -> a + b); + } + // Fallback to toString for other types + return output.toString(); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private ReplicateApi replicateApi; + + private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; + + private ReplicateChatOptions defaultOptions; + + private Builder() { + } + + public Builder replicateApi(ReplicateApi replicateApi) { + this.replicateApi = replicateApi; + return this; + } + + public Builder observationRegistry(ObservationRegistry observationRegistry) { + this.observationRegistry = observationRegistry; + return this; + } + + public Builder defaultOptions(ReplicateChatOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public ReplicateChatModel build() { + return new ReplicateChatModel(this.replicateApi, this.observationRegistry, this.defaultOptions); + } + + } + +} diff --git a/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateChatOptions.java b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateChatOptions.java new file mode 100644 index 00000000000..53609ba776f --- /dev/null +++ b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateChatOptions.java @@ -0,0 +1,240 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.chat.prompt.ChatOptions; + +/** + * Base options for Replicate models. Contains common fields that apply to all Replicate + * models regardless of type (chat, image, audio, etc.). + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ReplicateChatOptions implements ChatOptions { + + @JsonProperty("model") + protected String model; + + @JsonProperty("version") + protected String version; + + @JsonProperty("input") + protected Map input = new HashMap<>(); + + @JsonProperty("webhook") + protected String webhook; + + @JsonProperty("webhook_events_filter") + protected List webhookEventsFilter; + + public ReplicateChatOptions() { + } + + protected ReplicateChatOptions(Builder builder) { + this.model = builder.model; + this.version = builder.version; + this.input = builder.input != null ? new HashMap<>(builder.input) : new HashMap<>(); + this.webhook = builder.webhook; + this.webhookEventsFilter = builder.webhookEventsFilter; + } + + /** + * Add a custom parameter to the model input + */ + public ReplicateChatOptions withParameter(String key, Object value) { + this.input.put(key, value); + return this; + } + + /** + * Add multiple parameters to the model input + */ + public ReplicateChatOptions withParameters(Map parameters) { + this.input.putAll(parameters); + return this; + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Create a new ReplicateOptions from existing options + */ + public static ReplicateChatOptions fromOptions(ReplicateChatOptions fromOptions) { + return builder().model(fromOptions.getModel()) + .version(fromOptions.getVersion()) + .input(new HashMap<>(fromOptions.getInput())) + .webhook(fromOptions.getWebhook()) + .webhookEventsFilter(fromOptions.getWebhookEventsFilter()) + .build(); + } + + public String getModel() { + return this.model; + } + + @Override + @com.fasterxml.jackson.annotation.JsonIgnore + public Double getFrequencyPenalty() { + return null; + } + + @Override + @com.fasterxml.jackson.annotation.JsonIgnore + public Integer getMaxTokens() { + return null; + } + + @Override + @com.fasterxml.jackson.annotation.JsonIgnore + public Double getPresencePenalty() { + return null; + } + + @Override + @com.fasterxml.jackson.annotation.JsonIgnore + public List getStopSequences() { + return null; + } + + @Override + @com.fasterxml.jackson.annotation.JsonIgnore + public Double getTemperature() { + return null; + } + + @Override + @com.fasterxml.jackson.annotation.JsonIgnore + public Integer getTopK() { + return null; + } + + @Override + @com.fasterxml.jackson.annotation.JsonIgnore + public Double getTopP() { + return null; + } + + @Override + @SuppressWarnings("unchecked") + public T copy() { + return (T) fromOptions(this); + } + + public void setModel(String model) { + this.model = model; + } + + public String getVersion() { + return this.version; + } + + public void setVersion(String version) { + this.version = version; + } + + public Map getInput() { + return this.input; + } + + public void setInput(Map input) { + this.input = ReplicateOptionsUtils.convertMapValues(input); + } + + public String getWebhook() { + return this.webhook; + } + + public void setWebhook(String webhook) { + this.webhook = webhook; + } + + public List getWebhookEventsFilter() { + return this.webhookEventsFilter; + } + + public void setWebhookEventsFilter(List webhookEventsFilter) { + this.webhookEventsFilter = webhookEventsFilter; + } + + public static class Builder { + + protected String model; + + protected String version; + + protected Map input = new HashMap<>(); + + protected String webhook; + + protected List webhookEventsFilter; + + protected Builder() { + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder version(String version) { + this.version = version; + return this; + } + + private Builder input(Map input) { + this.input = input; + return this; + } + + public Builder withParameter(String key, Object value) { + this.input.put(key, value); + return this; + } + + public Builder withParameters(Map params) { + this.input.putAll(params); + return this; + } + + public Builder webhook(String webhook) { + this.webhook = webhook; + return this; + } + + public Builder webhookEventsFilter(List webhookEventsFilter) { + this.webhookEventsFilter = webhookEventsFilter; + return this; + } + + public ReplicateChatOptions build() { + return new ReplicateChatOptions(this); + } + + } + +} diff --git a/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateMediaModel.java b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateMediaModel.java new file mode 100644 index 00000000000..f83ab8fac90 --- /dev/null +++ b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateMediaModel.java @@ -0,0 +1,208 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.replicate.api.ReplicateApi; +import org.springframework.ai.replicate.api.ReplicateApi.PredictionRequest; +import org.springframework.ai.replicate.api.ReplicateApi.PredictionResponse; +import org.springframework.util.Assert; + +/** + * Replicate Media Model implementation for image, video, and audio generation. Handles + * both single URI outputs and multiple URI outputs (arrays). + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +public class ReplicateMediaModel { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final ReplicateApi replicateApi; + + private final ReplicateOptions defaultOptions; + + public ReplicateMediaModel(ReplicateApi replicateApi, ReplicateOptions defaultOptions) { + Assert.notNull(replicateApi, "replicateApi must not be null"); + this.replicateApi = replicateApi; + this.defaultOptions = defaultOptions; + } + + /** + * Generate media (image/video/audio) using the specified model and options. + * @param options The model configuration including model name and input. + * @return Response containing URIs to generated media files + */ + public MediaResponse generate(ReplicateOptions options) { + ReplicateOptions mergedOptions = mergeOptions(options); + Assert.hasText(mergedOptions.getModel(), "model name must not be empty"); + + PredictionRequest request = new PredictionRequest(mergedOptions.getVersion(), mergedOptions.getInput(), + mergedOptions.getWebhook(), mergedOptions.getWebhookEventsFilter(), false); + + PredictionResponse predictionResponse = this.replicateApi.createPredictionAndWait(mergedOptions.getModel(), + request); + + if (predictionResponse == null) { + logger.warn("No prediction response returned for model: {}", mergedOptions.getModel()); + return new MediaResponse(Collections.emptyList(), predictionResponse); + } + + List uris = parseMediaOutput(predictionResponse.output()); + return new MediaResponse(uris, predictionResponse); + } + + /** + * Merges default options from properties with prompt options. Prompt options take + * precedence + * @param providedOptions Options from the current Prompt + * @return merged Options + */ + private ReplicateOptions mergeOptions(ReplicateOptions providedOptions) { + if (this.defaultOptions == null) { + return providedOptions != null ? providedOptions : ReplicateOptions.builder().build(); + } + + if (providedOptions == null) { + return this.defaultOptions; + } + ReplicateOptions merged = ReplicateOptions.fromOptions(this.defaultOptions); + if (providedOptions.getModel() != null) { + merged.setModel(providedOptions.getModel()); + } + if (providedOptions.getVersion() != null) { + merged.setVersion(providedOptions.getVersion()); + } + if (providedOptions.getWebhook() != null) { + merged.setWebhook(providedOptions.getWebhook()); + } + if (providedOptions.getWebhookEventsFilter() != null) { + merged.setWebhookEventsFilter(providedOptions.getWebhookEventsFilter()); + } + Map mergedInput = new HashMap<>(); + if (this.defaultOptions.getInput() != null) { + mergedInput.putAll(this.defaultOptions.getInput()); + } + if (providedOptions.getInput() != null) { + mergedInput.putAll(providedOptions.getInput()); + } + merged.setInput(mergedInput); + + return merged; + } + + /** + * Parse the output field which can be either a single URI string or an array of URI + * strings. + */ + private List parseMediaOutput(Object output) { + if (output == null) { + return Collections.emptyList(); + } + if (output instanceof String outputString) { + return List.of(outputString); + } + if (output instanceof List) { + List list = (List) output; + List uris = new ArrayList<>(); + for (Object item : list) { + if (item instanceof String itemString) { + uris.add(itemString); + } + } + return uris; + } + logger.warn("Unexpected output type: {}", output.getClass().getName()); + return Collections.emptyList(); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Response containing generated media URIs and the raw prediction response. + */ + public static class MediaResponse { + + private final List uris; + + private final PredictionResponse predictionResponse; + + public MediaResponse(List uris, PredictionResponse predictionResponse) { + this.uris = uris != null ? Collections.unmodifiableList(uris) : Collections.emptyList(); + this.predictionResponse = predictionResponse; + } + + /** + * Get the list of URIs pointing to generated media files. + */ + public List getUris() { + return this.uris; + } + + /** + * Get the first URI if available, useful for single-file outputs. + */ + public String getFirstUri() { + return this.uris.isEmpty() ? null : this.uris.get(0); + } + + /** + * Get the raw prediction response from Replicate API. + */ + public PredictionResponse getPredictionResponse() { + return this.predictionResponse; + } + + } + + public static final class Builder { + + private ReplicateApi replicateApi; + + private ReplicateOptions defaultOptions; + + private Builder() { + } + + public Builder replicateApi(ReplicateApi replicateApi) { + this.replicateApi = replicateApi; + return this; + } + + public Builder defaultOptions(ReplicateOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public ReplicateMediaModel build() { + return new ReplicateMediaModel(this.replicateApi, this.defaultOptions); + } + + } + +} diff --git a/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateOptions.java b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateOptions.java new file mode 100644 index 00000000000..a83f23297e4 --- /dev/null +++ b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateOptions.java @@ -0,0 +1,207 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.model.ModelOptions; + +/** + * Base options for Replicate models. Contains common fields that apply to all Replicate + * models regardless of type (chat, image, audio, etc.). + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ReplicateOptions implements ModelOptions { + + /** + * The model identifier in format "owner/model-name" (e.g., "meta/llama-2-70b-chat") + */ + protected String model; + + /** + * The specific version hash of the model to use. Not mandatory for "official" models. + */ + @JsonProperty("version") + protected String version; + + /** + * Flexible input map containing model-specific parameters. This allows support for + * any model on Replicate, regardless of its specific input schema. + */ + @JsonProperty("input") + protected Map input = new HashMap<>(); + + /** + * Optional webhook URL for async notifications + */ + @JsonProperty("webhook") + protected String webhook; + + /** + * Optional webhook events to subscribe to + */ + @JsonProperty("webhook_events_filter") + protected List webhookEventsFilter; + + public ReplicateOptions() { + } + + protected ReplicateOptions(Builder builder) { + this.model = builder.model; + this.version = builder.version; + this.input = builder.input != null ? new HashMap<>(builder.input) : new HashMap<>(); + this.webhook = builder.webhook; + this.webhookEventsFilter = builder.webhookEventsFilter; + } + + /** + * Add a custom parameter to the model input + */ + public ReplicateOptions withParameter(String key, Object value) { + this.input.put(key, value); + return this; + } + + /** + * Add multiple parameters to the model input + */ + public ReplicateOptions withParameters(Map parameters) { + this.input.putAll(parameters); + return this; + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Create a new ReplicateOptions from existing options + */ + public static ReplicateOptions fromOptions(ReplicateOptions fromOptions) { + return builder().model(fromOptions.getModel()) + .version(fromOptions.getVersion()) + .input(new HashMap<>(fromOptions.getInput())) + .webhook(fromOptions.getWebhook()) + .webhookEventsFilter(fromOptions.getWebhookEventsFilter()) + .build(); + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getVersion() { + return this.version; + } + + public void setVersion(String version) { + this.version = version; + } + + public Map getInput() { + return this.input; + } + + public void setInput(Map input) { + this.input = ReplicateOptionsUtils.convertMapValues(input); + } + + public String getWebhook() { + return this.webhook; + } + + public void setWebhook(String webhook) { + this.webhook = webhook; + } + + public List getWebhookEventsFilter() { + return this.webhookEventsFilter; + } + + public void setWebhookEventsFilter(List webhookEventsFilter) { + this.webhookEventsFilter = webhookEventsFilter; + } + + public static class Builder { + + protected String model; + + protected String version; + + protected Map input = new HashMap<>(); + + protected String webhook; + + protected List webhookEventsFilter; + + protected Builder() { + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder version(String version) { + this.version = version; + return this; + } + + private Builder input(Map input) { + this.input = input; + return this; + } + + public Builder withParameter(String key, Object value) { + this.input.put(key, value); + return this; + } + + public Builder withParameters(Map params) { + this.input.putAll(params); + return this; + } + + public Builder webhook(String webhook) { + this.webhook = webhook; + return this; + } + + public Builder webhookEventsFilter(List webhookEventsFilter) { + this.webhookEventsFilter = webhookEventsFilter; + return this; + } + + public ReplicateOptions build() { + return new ReplicateOptions(this); + } + + } + +} diff --git a/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateOptionsUtils.java b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateOptionsUtils.java new file mode 100644 index 00000000000..17c5ea5a2da --- /dev/null +++ b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateOptionsUtils.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.HashMap; +import java.util.Map; + +import org.springframework.lang.Nullable; + +/** + * Utility class for handling Replicate options, when set via application.properties. This + * utility is needed because replicate expects various different Types in the "input" map + * and we cannot automatically infer the type from the properties. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +public abstract class ReplicateOptionsUtils { + + private ReplicateOptionsUtils() { + } + + /** + * Convert all string values in a map to their appropriate types + * @param source the source map with potentially string-typed values + * @return a new map with properly typed values, or an empty map if source is null + */ + public static Map convertMapValues(@Nullable Map source) { + if (source == null) { + return new HashMap<>(); + } + Map result = new HashMap<>(source.size()); + for (Map.Entry entry : source.entrySet()) { + result.put(entry.getKey(), convertValue(entry.getValue())); + } + return result; + } + + /** + * Convert a value to its appropriate type if it's a string representation of a number + * or boolean. Non-string values are returned as-is. + * @param value the value to convert + * @return the converted value with the appropriate type + */ + public static Object convertValue(@Nullable Object value) { + if (!(value instanceof String strValue)) { + return value; + } + if ("true".equalsIgnoreCase(strValue) || "false".equalsIgnoreCase(strValue)) { + return Boolean.parseBoolean(strValue); + } + if (!strValue.contains(".") && !strValue.contains("e") && !strValue.contains("E")) { + try { + return Integer.parseInt(strValue); + } + catch (NumberFormatException ex) { + // Not an integer, continue to next type check + } + } + try { + return Double.parseDouble(strValue); + } + catch (NumberFormatException ex) { + // Not a number, return as string + } + // Return as string + return strValue; + } + +} diff --git a/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateStringModel.java b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateStringModel.java new file mode 100644 index 00000000000..2518b3af4c8 --- /dev/null +++ b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateStringModel.java @@ -0,0 +1,174 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.HashMap; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.replicate.api.ReplicateApi; +import org.springframework.ai.replicate.api.ReplicateApi.PredictionRequest; +import org.springframework.ai.replicate.api.ReplicateApi.PredictionResponse; +import org.springframework.util.Assert; + +/** + * Replicate String Model implementation for models that return simple string outputs. + * Typically used by classifiers, filters, or small utility models. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +public class ReplicateStringModel { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final ReplicateApi replicateApi; + + private final ReplicateOptions defaultOptions; + + public ReplicateStringModel(ReplicateApi replicateApi, ReplicateOptions defaultOptions) { + Assert.notNull(replicateApi, "replicateApi must not be null"); + this.replicateApi = replicateApi; + this.defaultOptions = defaultOptions; + } + + /** + * Generate a string output using the specified model and options. + * @param options The model configuration including model name, and input + * @return Response containing the string output + */ + public StringResponse generate(ReplicateOptions options) { + ReplicateOptions mergedOptions = mergeOptions(options); + Assert.hasText(mergedOptions.getModel(), "model name must not be empty"); + + PredictionRequest request = new PredictionRequest(mergedOptions.getVersion(), mergedOptions.getInput(), + mergedOptions.getWebhook(), mergedOptions.getWebhookEventsFilter(), false); + + PredictionResponse predictionResponse = this.replicateApi.createPredictionAndWait(mergedOptions.getModel(), + request); + + if (predictionResponse == null) { + logger.warn("No prediction response returned for model: {}", mergedOptions.getModel()); + return new StringResponse(null, predictionResponse); + } + + String output = parseStringOutput(predictionResponse.output()); + return new StringResponse(output, predictionResponse); + } + + /** + * Merges default options from properties with prompt options. Prompt options take + * precedence + * @param providedOptions Options from the current Prompt + * @return merged Options + */ + private ReplicateOptions mergeOptions(ReplicateOptions providedOptions) { + if (this.defaultOptions == null) { + return providedOptions != null ? providedOptions : ReplicateOptions.builder().build(); + } + if (providedOptions == null) { + return this.defaultOptions; + } + ReplicateOptions merged = ReplicateOptions.fromOptions(this.defaultOptions); + if (providedOptions.getModel() != null) { + merged.setModel(providedOptions.getModel()); + } + if (providedOptions.getVersion() != null) { + merged.setVersion(providedOptions.getVersion()); + } + if (providedOptions.getWebhook() != null) { + merged.setWebhook(providedOptions.getWebhook()); + } + if (providedOptions.getWebhookEventsFilter() != null) { + merged.setWebhookEventsFilter(providedOptions.getWebhookEventsFilter()); + } + Map mergedInput = new HashMap<>(); + if (this.defaultOptions.getInput() != null) { + mergedInput.putAll(this.defaultOptions.getInput()); + } + if (providedOptions.getInput() != null) { + mergedInput.putAll(providedOptions.getInput()); + } + merged.setInput(mergedInput); + + return merged; + } + + private String parseStringOutput(Object output) { + if (output == null) { + return null; + } + if (output instanceof String outputString) { + return outputString; + } + logger.warn("Unexpected output type for string model: {}, converting to string", output.getClass().getName()); + return output.toString(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class StringResponse { + + private final String output; + + private final PredictionResponse predictionResponse; + + public StringResponse(String output, PredictionResponse predictionResponse) { + this.output = output; + this.predictionResponse = predictionResponse; + } + + public String getOutput() { + return this.output; + } + + public PredictionResponse getPredictionResponse() { + return this.predictionResponse; + } + + } + + public static final class Builder { + + private ReplicateApi replicateApi; + + private ReplicateOptions defaultOptions; + + private Builder() { + } + + public Builder replicateApi(ReplicateApi replicateApi) { + this.replicateApi = replicateApi; + return this; + } + + public Builder defaultOptions(ReplicateOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public ReplicateStringModel build() { + return new ReplicateStringModel(this.replicateApi, this.defaultOptions); + } + + } + +} diff --git a/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateStructuredModel.java b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateStructuredModel.java new file mode 100644 index 00000000000..1375f426aa4 --- /dev/null +++ b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/ReplicateStructuredModel.java @@ -0,0 +1,220 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.replicate.api.ReplicateApi; +import org.springframework.ai.replicate.api.ReplicateApi.PredictionRequest; +import org.springframework.ai.replicate.api.ReplicateApi.PredictionResponse; +import org.springframework.util.Assert; + +/** + * Replicate Structured Model implementation for models that return structured JSON + * objects. + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +public class ReplicateStructuredModel { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final ReplicateApi replicateApi; + + private final ReplicateOptions defaultOptions; + + public ReplicateStructuredModel(ReplicateApi replicateApi, ReplicateOptions defaultOptions) { + Assert.notNull(replicateApi, "replicateApi must not be null"); + this.replicateApi = replicateApi; + this.defaultOptions = defaultOptions; + } + + /** + * Generate structured output using the specified model and options. + * @param options The model configuration including model name and input + * @return Response containing the structured output as a Map + */ + public StructuredResponse generate(ReplicateOptions options) { + ReplicateOptions mergedOptions = mergeOptions(options); + Assert.hasText(mergedOptions.getModel(), "model name must not be empty"); + + PredictionRequest request = new PredictionRequest(mergedOptions.getVersion(), mergedOptions.getInput(), + mergedOptions.getWebhook(), mergedOptions.getWebhookEventsFilter(), false); + + PredictionResponse predictionResponse = this.replicateApi.createPredictionAndWait(mergedOptions.getModel(), + request); + + if (predictionResponse == null) { + logger.warn("No prediction response returned for model: {}", mergedOptions.getModel()); + return new StructuredResponse(Collections.emptyMap(), predictionResponse); + } + + Map structuredOutput = parseStructuredOutput(predictionResponse.output()); + return new StructuredResponse(structuredOutput, predictionResponse); + } + + /** + * Merges default options from properties with prompt options. Prompt options take + * precedence + * @param providedOptions Options from the current Prompt + * @return merged Options + */ + private ReplicateOptions mergeOptions(ReplicateOptions providedOptions) { + if (this.defaultOptions == null) { + return providedOptions != null ? providedOptions : ReplicateOptions.builder().build(); + } + if (providedOptions == null) { + return this.defaultOptions; + } + ReplicateOptions merged = ReplicateOptions.fromOptions(this.defaultOptions); + if (providedOptions.getModel() != null) { + merged.setModel(providedOptions.getModel()); + } + if (providedOptions.getVersion() != null) { + merged.setVersion(providedOptions.getVersion()); + } + if (providedOptions.getWebhook() != null) { + merged.setWebhook(providedOptions.getWebhook()); + } + if (providedOptions.getWebhookEventsFilter() != null) { + merged.setWebhookEventsFilter(providedOptions.getWebhookEventsFilter()); + } + Map mergedInput = new HashMap<>(); + if (this.defaultOptions.getInput() != null) { + mergedInput.putAll(this.defaultOptions.getInput()); + } + if (providedOptions.getInput() != null) { + mergedInput.putAll(providedOptions.getInput()); + } + merged.setInput(mergedInput); + + return merged; + } + + /** + * Parse the output field as a Map. + */ + @SuppressWarnings("unchecked") + private Map parseStructuredOutput(Object output) { + if (output == null) { + return Collections.emptyMap(); + } + + if (output instanceof Map) { + return (Map) output; + } + + logger.warn("Unexpected output type for structured model: {}", output.getClass().getName()); + return Collections.emptyMap(); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Response containing structured output as a Map and the raw prediction response. + */ + public static class StructuredResponse { + + private final Map output; + + private final PredictionResponse predictionResponse; + + public StructuredResponse(Map output, PredictionResponse predictionResponse) { + this.output = output != null ? Collections.unmodifiableMap(output) : Collections.emptyMap(); + this.predictionResponse = predictionResponse; + } + + /** + * Get the structured output as a Map. + */ + public Map getOutput() { + return this.output; + } + + /** + * Get the raw prediction response from Replicate API. + */ + public PredictionResponse getPredictionResponse() { + return this.predictionResponse; + } + + } + + /** + * Response containing structured output converted to a specific type. + */ + public static class TypedStructuredResponse { + + private final T output; + + private final PredictionResponse predictionResponse; + + public TypedStructuredResponse(T output, PredictionResponse predictionResponse) { + this.output = output; + this.predictionResponse = predictionResponse; + } + + /** + * Get the structured output as the specified type. + */ + public T getOutput() { + return this.output; + } + + /** + * Get the raw prediction response from Replicate API. + */ + public PredictionResponse getPredictionResponse() { + return this.predictionResponse; + } + + } + + public static final class Builder { + + private ReplicateApi replicateApi; + + private ReplicateOptions defaultOptions; + + private Builder() { + } + + public Builder replicateApi(ReplicateApi replicateApi) { + this.replicateApi = replicateApi; + return this; + } + + public Builder defaultOptions(ReplicateOptions defaultOptions) { + this.defaultOptions = defaultOptions; + return this; + } + + public ReplicateStructuredModel build() { + return new ReplicateStructuredModel(this.replicateApi, this.defaultOptions); + } + + } + +} diff --git a/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/api/ReplicateApi.java b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/api/ReplicateApi.java new file mode 100644 index 00000000000..1719e040360 --- /dev/null +++ b/models/spring-ai-replicate/src/main/java/org/springframework/ai/replicate/api/ReplicateApi.java @@ -0,0 +1,424 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate.api; + +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +import org.springframework.ai.model.ApiKey; +import org.springframework.ai.model.SimpleApiKey; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.Resource; +import org.springframework.http.ContentDisposition; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.client.MultipartBodyBuilder; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * Client for the Replicate Predictions API + * + * @author Rene Maierhofer + * @since 1.1.0 + */ +public final class ReplicateApi { + + private static final Logger logger = LoggerFactory.getLogger(ReplicateApi.class); + + private static final String DEFAULT_BASE_URL = "https://api.replicate.com/v1"; + + private static final String PREDICTIONS_PATH = "/predictions"; + + private final RestClient restClient; + + private final WebClient webClient; + + private final RetryTemplate retryTemplate = RetryTemplate.builder() + .retryOn(ReplicatePredictionNotFinishedException.class) + .maxAttempts(60) + .fixedBackoff(5000) + .withListener(new RetryListener() { + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + logger.debug("Polling Replicate Prediction: {}/10 attempts.", context.getRetryCount()); + } + }) + .build(); + + public static final String PROVIDER_NAME = AiProvider.REPLICATE.value(); + + private ReplicateApi(String baseUrl, ApiKey apiKey, RestClient.Builder restClientBuilder, + WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { + Consumer headers = h -> { + h.setContentType(MediaType.APPLICATION_JSON); + h.setBearerAuth(apiKey.getValue()); + }; + + this.restClient = restClientBuilder.baseUrl(baseUrl) + .defaultHeaders(headers) + .defaultStatusHandler(responseErrorHandler) + .build(); + + this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(headers).build(); + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Creates a Prediction using the model's endpoint. + * @param modelName The model name in format "owner/name" (e.g., "openai/gpt-5") + * @param request The prediction request + * @return The prediction response + */ + public PredictionResponse createPrediction(String modelName, PredictionRequest request) { + Assert.hasText(modelName, "Model name must not be empty"); + + String uri = "/models/" + modelName + PREDICTIONS_PATH; + ResponseEntity response = this.restClient.post() + .uri(uri) + .body(request) + .retrieve() + .toEntity(PredictionResponse.class); + return response.getBody(); + } + + /** + * Retrieves the current status of the Prediction + * @param predictionId The prediction ID + * @return The prediction response + */ + public PredictionResponse getPrediction(String predictionId) { + Assert.hasText(predictionId, "Prediction ID must not be empty"); + + return this.restClient.get() + .uri(PREDICTIONS_PATH + "/{id}", predictionId) + .retrieve() + .body(PredictionResponse.class); + } + + /** + * Creates a prediction and waits for it to complete by polling the status. Uses the + * configured retry template. + * @param modelName The model name in format "owner/name" + * @param request The prediction request + * @return The completed prediction response + */ + public PredictionResponse createPredictionAndWait(String modelName, PredictionRequest request) { + PredictionResponse prediction = createPrediction(modelName, request); + if (prediction == null || prediction.id == null) { + throw new ReplicatePredictionException("PredictionRequest did not return a valid response."); + } + return waitForCompletion(prediction.id()); + } + + /** + * Waits for the completed Prediction and returns the final Response. + * @param predictionId id of the prediction + * @return the final PredictionResponse + */ + public PredictionResponse waitForCompletion(String predictionId) { + Assert.hasText(predictionId, "Prediction ID must not be empty"); + return this.retryTemplate.execute(context -> pollStatusFromReplicate(predictionId)); + } + + /** + * Polls the prediction status from replicate. + * @param predictionId the Prediction's id + * @return the final Prediction Response + */ + private PredictionResponse pollStatusFromReplicate(String predictionId) { + PredictionResponse prediction = getPrediction(predictionId); + if (prediction == null || prediction.id == null) { + throw new ReplicatePredictionException("Polling for Prediction did not return a valid response."); + } + PredictionStatus status = prediction.status(); + if (status == PredictionStatus.SUCCEEDED) { + return prediction; + } + else if (status == PredictionStatus.PROCESSING || status == PredictionStatus.STARTING) { + throw new ReplicatePredictionNotFinishedException("Prediction not finished yet."); + } + else if (status == PredictionStatus.FAILED) { + String error = prediction.error() != null ? prediction.error() : "Unknown error"; + throw new ReplicatePredictionException("Prediction failed: " + error); + } + else if (status == PredictionStatus.CANCELED || status == PredictionStatus.ABORTED) { + throw new ReplicatePredictionException("Prediction was canceled"); + } + throw new ReplicatePredictionException("Unknown Replicate Prediction Status"); + } + + /** + * Uploads a file to Replicate for usage in a request. Replicate + * Files API + * @param fileResource The file to upload + * @param filename The filename to use for the uploaded file + * @return Upload response containing the URL to later send with a request. + */ + public FileUploadResponse uploadFile(Resource fileResource, String filename) { + Assert.notNull(fileResource, "File resource must not be null"); + Assert.hasText(filename, "Filename must not be empty"); + + MultipartBodyBuilder builder = new MultipartBodyBuilder(); + builder.part("content", fileResource) + .headers(h -> h + .setContentDisposition(ContentDisposition.formData().name("content").filename(filename).build())) + .contentType(MediaType.APPLICATION_OCTET_STREAM); + + return this.webClient.post() + .uri("/files") + .contentType(MediaType.MULTIPART_FORM_DATA) + .bodyValue(builder.build()) + .retrieve() + .bodyToMono(FileUploadResponse.class) + .block(); + } + + /** + * Creates a streaming prediction response. Replicate uses SSE for Streaming. + * Replicate + * Docs + * @param modelName The model name in format "owner/name" + * @param request The prediction request (must have stream=true) + * @return A Flux stream of prediction response events with incremental output + */ + public Flux createPredictionStream(String modelName, PredictionRequest request) { + PredictionResponse initialResponse = createPrediction(modelName, request); + if (initialResponse.urls() == null || initialResponse.urls().stream() == null) { + logger.error("No stream URL in response: {}", initialResponse); + return Flux.error(new ReplicatePredictionException("No stream URL returned from prediction")); + } + String streamUrl = initialResponse.urls().stream(); + ParameterizedTypeReference> typeRef = new ParameterizedTypeReference<>() { + }; + + return this.webClient.get() + .uri(streamUrl) + .accept(MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.CACHE_CONTROL, "no-store") + .retrieve() + .bodyToFlux(typeRef) + .handle((event, sink) -> { + String eventType = event.event(); + if ("error".equals(eventType)) { + String errorMessage = event.data() != null ? event.data() : "Unknown error"; + sink.error(new ReplicatePredictionException("Streaming error: " + errorMessage)); + return; + } + if ("done".equals(eventType)) { + sink.complete(); + return; + } + if ("output".equals(eventType)) { + String dataContent = event.data() != null ? event.data() : ""; + PredictionResponse response = new PredictionResponse(initialResponse.id(), initialResponse.model(), + initialResponse.version(), PredictionStatus.PROCESSING, initialResponse.input(), + dataContent, // The output chunk + null, null, null, initialResponse.urls(), initialResponse.createdAt(), + initialResponse.startedAt(), null); + + sink.next(response); + } + }); + } + + /** + * Request to create a prediction + * + * @param version Optional model version + * @param input The input parameters for the model + * @param webhook Optional webhook URL for async notifications + * @param webhookEventsFilter Optional list of webhook events to subscribe to + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record PredictionRequest(@JsonProperty("version") String version, + @JsonProperty("input") Map input, @JsonProperty("webhook") String webhook, + @JsonProperty("webhook_events_filter") List webhookEventsFilter, + @JsonProperty("stream") boolean stream) { + } + + /** + * Response from Replicate prediction API. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record PredictionResponse(@JsonProperty("id") String id, @JsonProperty("model") String model, + @JsonProperty("version") String version, @JsonProperty("status") PredictionStatus status, + @JsonProperty("input") Map input, @JsonProperty("output") Object output, + @JsonProperty("error") String error, @JsonProperty("logs") String logs, + @JsonProperty("metrics") Metrics metrics, @JsonProperty("urls") Urls urls, + @JsonProperty("created_at") String createdAt, @JsonProperty("started_at") String startedAt, + @JsonProperty("completed_at") String completedAt) { + } + + /** + * Prediction status. + */ + public enum PredictionStatus { + + @JsonProperty("starting") + STARTING, + + @JsonProperty("processing") + PROCESSING, + + @JsonProperty("succeeded") + SUCCEEDED, + + @JsonProperty("failed") + FAILED, + + @JsonProperty("canceled") + CANCELED, + + @JsonProperty("aborted") + ABORTED + + } + + /** + * Metrics from a prediction including token counts and timing. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Metrics(@JsonProperty("predict_time") Double predictTime, + @JsonProperty("input_token_count") Integer inputTokenCount, + @JsonProperty("output_token_count") Integer outputTokenCount) { + } + + /** + * URLs for interacting with a prediction. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Urls(@JsonProperty("get") String get, @JsonProperty("cancel") String cancel, + @JsonProperty("stream") String stream) { + } + + /** + * Response from Replicate file upload API. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record FileUploadResponse(@JsonProperty("id") String id, @JsonProperty("name") String name, + @JsonProperty("content_type") String contentType, @JsonProperty("size") Long size, + @JsonProperty("urls") FileUrls urls, @JsonProperty("created_at") String createdAt, + @JsonProperty("expires_at") String expiresAt) { + } + + /** + * URLs for accessing an uploaded file. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record FileUrls(@JsonProperty("get") String get) { + } + + /** + * Builder to Construct a {@link ReplicateApi} instance + */ + public static final class Builder { + + private String baseUrl = DEFAULT_BASE_URL; + + private ApiKey apiKey; + + private RestClient.Builder restClientBuilder = RestClient.builder(); + + private WebClient.Builder webClientBuilder = WebClient.builder(); + + private ResponseErrorHandler responseErrorHandler = RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER; + + public Builder baseUrl(String baseUrl) { + Assert.hasText(baseUrl, "baseUrl cannot be empty"); + this.baseUrl = baseUrl; + return this; + } + + public Builder apiKey(String apiKey) { + Assert.notNull(apiKey, "ApiKey cannot be null"); + this.apiKey = new SimpleApiKey(apiKey); + return this; + } + + public Builder restClientBuilder(RestClient.Builder restClientBuilder) { + Assert.notNull(restClientBuilder, "restClientBuilder cannot be null"); + this.restClientBuilder = restClientBuilder; + return this; + } + + public Builder webClientBuilder(WebClient.Builder webClientBuilder) { + Assert.notNull(webClientBuilder, "webClientBuilder cannot be null"); + this.webClientBuilder = webClientBuilder; + return this; + } + + public Builder responseErrorHandler(ResponseErrorHandler responseErrorHandler) { + Assert.notNull(responseErrorHandler, "responseErrorHandler cannot be null"); + this.responseErrorHandler = responseErrorHandler; + return this; + } + + public ReplicateApi build() { + Assert.notNull(this.apiKey, "cannot construct instance without apiKey"); + return new ReplicateApi(this.baseUrl, this.apiKey, this.restClientBuilder, this.webClientBuilder, + this.responseErrorHandler); + } + + } + + /** + * Exception thrown when a Replicate prediction fails or times out. + */ + public static class ReplicatePredictionException extends RuntimeException { + + public ReplicatePredictionException(String message) { + super(message); + } + + } + + /** + * Exception thrown when a Replicate prediction has not finished yet. Used for + * RetryTemplate. + */ + public static class ReplicatePredictionNotFinishedException extends RuntimeException { + + public ReplicatePredictionNotFinishedException(String message) { + super(message); + } + + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateChatModelIT.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateChatModelIT.java new file mode 100644 index 00000000000..23108e94b67 --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateChatModelIT.java @@ -0,0 +1,131 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.List; +import java.util.Objects; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link ReplicateChatModel}. + * + * @author Rene Maierhofer + */ +@SpringBootTest(classes = ReplicateTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "REPLICATE_API_TOKEN", matches = ".+") +class ReplicateChatModelIT { + + @Autowired + private ReplicateChatModel chatModel; + + @Test + void testSimpleCall() { + String userMessage = "What is the capital of France? Answer in one word."; + ChatResponse response = this.chatModel.call(new Prompt(userMessage)); + + assertThat(response).isNotNull(); + assertThat(response.getResults()).isNotEmpty(); + + Generation generation = response.getResult(); + assertThat(generation).isNotNull(); + Assertions.assertNotNull(generation.getOutput().getText()); + Assertions.assertFalse(generation.getOutput().getText().isEmpty()); + assertThat(generation.getOutput().getText().toLowerCase()).contains("paris"); + + ChatResponseMetadata metadata = response.getMetadata(); + assertThat(metadata).isNotNull(); + assertThat(metadata.getId()).isNotNull(); + assertThat(metadata.getModel()).isNotNull(); + assertThat(metadata.getUsage()).isNotNull(); + assertThat(metadata.getUsage().getPromptTokens()).isGreaterThanOrEqualTo(0); + assertThat(metadata.getUsage().getCompletionTokens()).isGreaterThanOrEqualTo(0); + } + + @Test + void testStreamingCall() { + String userMessage = "Count from 1 to 500."; + Flux responseFlux = this.chatModel.stream(new Prompt(userMessage)); + List responses = responseFlux.collectList().block(); + assertThat(responses).isNotNull().isNotEmpty().hasSizeGreaterThan(1); + + responses.forEach(response -> { + assertThat(response.getResults()).isNotEmpty(); + assertThat(response.getResult().getOutput().getText()).isNotNull(); + }); + + List chunks = responses.stream() + .flatMap(chatResponse -> chatResponse.getResults().stream()) + .map(generation -> generation.getOutput().getText()) + .filter(Objects::nonNull) + .filter(text -> !text.isEmpty()) + .toList(); + + assertThat(chunks).hasSizeGreaterThan(1); + + String fullContent = String.join("", chunks); + assertThat(fullContent).isNotEmpty(); + + boolean hasMetadata = responses.stream().anyMatch(response -> response.getMetadata().getId() != null); + assertThat(hasMetadata).isTrue(); + } + + @Test + void testCallWithOptions() { + int maxTokens = 10; + ReplicateChatOptions options = ReplicateChatOptions.builder() + .model("meta/meta-llama-3-8b-instruct") + .withParameter("temperature", 0.8) + .withParameter("max_tokens", maxTokens) + .build(); + + Prompt prompt = new Prompt("Write a very long poem.", options); + ChatResponse response = this.chatModel.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getResults()).isNotEmpty(); + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + ChatResponseMetadata metadata = response.getMetadata(); + Usage usage = metadata.getUsage(); + assertThat(usage.getCompletionTokens()).isLessThanOrEqualTo(maxTokens); + } + + @Test + void testMultiTurnConversation_shouldNotWork() { + UserMessage userMessage1 = new UserMessage("My favorite color is blue."); + AssistantMessage assistantMessage = new AssistantMessage("Noted!"); + UserMessage userMessage2 = new UserMessage("What is my favorite color?"); + Prompt prompt = new Prompt(List.of(userMessage1, assistantMessage, userMessage2)); + Assertions.assertThrows(AssertionError.class, () -> this.chatModel.call(prompt)); + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateChatOptionsTests.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateChatOptionsTests.java new file mode 100644 index 00000000000..cc60a098581 --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateChatOptionsTests.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ReplicateChatOptions}. + * + * @author Rene Maierhofer + */ +class ReplicateChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + Map inputParams = new HashMap<>(); + inputParams.put("temperature", 0.7); + inputParams.put("max_tokens", 100); + + List webhookEvents = Arrays.asList("start", "completed"); + + ReplicateChatOptions options = ReplicateChatOptions.builder() + .model("meta/llama-3-8b-instruct") + .version("1234abcd") + .withParameters(inputParams) + .webhook("https://127.0.0.1/webhook") + .webhookEventsFilter(webhookEvents) + .build(); + + assertThat(options).extracting("model", "version", "webhook") + .containsExactly("meta/llama-3-8b-instruct", "1234abcd", "https://127.0.0.1/webhook"); + + assertThat(options.getInput()).containsEntry("temperature", 0.7).containsEntry("max_tokens", 100); + + assertThat(options.getWebhookEventsFilter()).containsExactly("start", "completed"); + } + + @Test + void testWithParameter() { + ReplicateChatOptions options = ReplicateChatOptions.builder().model("test-model").build(); + + options.withParameter("temperature", 0.8); + options.withParameter("max_tokens", 200); + + assertThat(options.getInput()).hasSize(2).containsEntry("temperature", 0.8).containsEntry("max_tokens", 200); + } + + @Test + void testWithParametersMap() { + Map params = new HashMap<>(); + params.put("param1", "value1"); + params.put("param2", 42); + + ReplicateChatOptions options = ReplicateChatOptions.builder().model("test-model").build(); + + options.withParameters(params); + + assertThat(options.getInput()).hasSize(2).containsEntry("param1", "value1").containsEntry("param2", 42); + } + + @Test + void testFromOptions() { + Map inputParams = new HashMap<>(); + inputParams.put("temperature", 0.7); + + List webhookEvents = Arrays.asList("start", "completed"); + + ReplicateChatOptions original = ReplicateChatOptions.builder() + .model("meta/llama-3-8b-instruct") + .version("1234abcd") + .withParameters(inputParams) + .webhook("https://127.0.0.1/webhook") + .webhookEventsFilter(webhookEvents) + .build(); + + ReplicateChatOptions copy = ReplicateChatOptions.fromOptions(original); + + assertThat(copy).isNotSameAs(original); + assertThat(copy.getModel()).isEqualTo(original.getModel()); + assertThat(copy.getVersion()).isEqualTo(original.getVersion()); + assertThat(copy.getWebhook()).isEqualTo(original.getWebhook()); + assertThat(copy.getWebhookEventsFilter()).isEqualTo(original.getWebhookEventsFilter()); + assertThat(copy.getInput()).isNotSameAs(original.getInput()).isEqualTo(original.getInput()); + } + + @Test + void testSetInputConvertsStringValues() { + ReplicateChatOptions options = new ReplicateChatOptions(); + + Map input = new HashMap<>(); + input.put("temperature", "0.7"); + input.put("maxTokens", "100"); + input.put("enabled", "true"); + + options.setInput(input); + + assertThat(options.getInput().get("temperature")).isInstanceOf(Double.class).isEqualTo(0.7); + assertThat(options.getInput().get("maxTokens")).isInstanceOf(Integer.class).isEqualTo(100); + assertThat(options.getInput().get("enabled")).isInstanceOf(Boolean.class).isEqualTo(true); + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateMediaModelIT.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateMediaModelIT.java new file mode 100644 index 00000000000..c8f4dc4ffce --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateMediaModelIT.java @@ -0,0 +1,73 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.replicate.ReplicateMediaModel.MediaResponse; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link ReplicateMediaModel}. + * + * @author Rene Maierhofer + */ +@SpringBootTest(classes = ReplicateTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "REPLICATE_API_TOKEN", matches = ".+") +class ReplicateMediaModelIT { + + @Autowired + private ReplicateMediaModel mediaModel; + + @Test + void testGenerateMultipleImages() { + ReplicateOptions options = ReplicateOptions.builder() + .model("black-forest-labs/flux-schnell") + .withParameter("prompt", "a cat sitting on a laptop") + .withParameter("num_outputs", 2) + .build(); + + MediaResponse response = this.mediaModel.generate(options); + + assertThat(response).isNotNull(); + assertThat(response.getUris()).isNotEmpty(); + response.getUris().forEach(uri -> { + assertThat(uri).isNotEmpty(); + assertThat(uri).startsWith("http"); + }); + assertThat(response.getPredictionResponse()).isNotNull(); + assertThat(response.getPredictionResponse().id()).isNotNull(); + assertThat(response.getPredictionResponse().createdAt()).isNotNull(); + assertThat(response.getPredictionResponse().status()).isNotNull(); + assertThat(response.getPredictionResponse().model()).contains("black-forest-labs/flux-schnell"); + } + + @Test + void testGenerateWithDefaultOptions() { + ReplicateOptions options = ReplicateOptions.builder().withParameter("prompt", "a serene lake").build(); + + MediaResponse response = this.mediaModel.generate(options); + + assertThat(response).isNotNull(); + assertThat(response.getUris()).isNotEmpty(); + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateOptionsTests.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateOptionsTests.java new file mode 100644 index 00000000000..9b41d3a3a64 --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateOptionsTests.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ReplicateOptions}. + * + * @author Rene Maierhofer + */ +class ReplicateOptionsTests { + + @Test + void testBuilderWithAllFields() { + Map inputParams = new HashMap<>(); + inputParams.put("temperature", 0.7); + inputParams.put("max_tokens", 100); + + List webhookEvents = Arrays.asList("start", "completed"); + + ReplicateOptions options = ReplicateOptions.builder() + .model("meta/llama-3-8b-instruct") + .version("1234abcd") + .withParameters(inputParams) + .webhook("https://example.com/webhook") + .webhookEventsFilter(webhookEvents) + .build(); + + assertThat(options).extracting("model", "version", "webhook") + .containsExactly("meta/llama-3-8b-instruct", "1234abcd", "https://example.com/webhook"); + + assertThat(options.getInput()).containsEntry("temperature", 0.7).containsEntry("max_tokens", 100); + + assertThat(options.getWebhookEventsFilter()).containsExactly("start", "completed"); + } + + @Test + void testWithParameter() { + ReplicateOptions options = ReplicateOptions.builder().model("test-model").build(); + + options.withParameter("temperature", 0.8); + options.withParameter("max_tokens", 200); + + assertThat(options.getInput()).hasSize(2).containsEntry("temperature", 0.8).containsEntry("max_tokens", 200); + } + + @Test + void testWithParametersMap() { + Map params = new HashMap<>(); + params.put("param1", "value1"); + params.put("param2", 42); + + ReplicateOptions options = ReplicateOptions.builder().model("test-model").build(); + + options.withParameters(params); + + assertThat(options.getInput()).hasSize(2).containsEntry("param1", "value1").containsEntry("param2", 42); + } + + @Test + void testFromOptions() { + Map inputParams = new HashMap<>(); + inputParams.put("temperature", 0.7); + + List webhookEvents = Arrays.asList("start", "completed"); + + ReplicateOptions original = ReplicateOptions.builder() + .model("meta/llama-3-8b-instruct") + .version("1234abcd") + .withParameters(inputParams) + .webhook("https://example.com/webhook") + .webhookEventsFilter(webhookEvents) + .build(); + + ReplicateOptions copy = ReplicateOptions.fromOptions(original); + + assertThat(copy).isNotSameAs(original); + assertThat(copy.getModel()).isEqualTo(original.getModel()); + assertThat(copy.getVersion()).isEqualTo(original.getVersion()); + assertThat(copy.getWebhook()).isEqualTo(original.getWebhook()); + assertThat(copy.getWebhookEventsFilter()).isEqualTo(original.getWebhookEventsFilter()); + assertThat(copy.getInput()).isNotSameAs(original.getInput()).isEqualTo(original.getInput()); + } + + @Test + void testSetInputConvertsStringValues() { + ReplicateOptions options = new ReplicateOptions(); + + Map input = new HashMap<>(); + input.put("temperature", "0.7"); + input.put("maxTokens", "100"); + input.put("enabled", "true"); + + options.setInput(input); + + assertThat(options.getInput().get("temperature")).isInstanceOf(Double.class).isEqualTo(0.7); + assertThat(options.getInput().get("maxTokens")).isInstanceOf(Integer.class).isEqualTo(100); + assertThat(options.getInput().get("enabled")).isInstanceOf(Boolean.class).isEqualTo(true); + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateOptionsUtilsTests.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateOptionsUtilsTests.java new file mode 100644 index 00000000000..9596394e59e --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateOptionsUtilsTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ReplicateOptionsUtils}. + * + * @author Rene Maierhofer + */ +class ReplicateOptionsUtilsTests { + + @Test + void testConvertValueWithBoolean() { + assertThat(ReplicateOptionsUtils.convertValue("true")).isInstanceOf(Boolean.class).isEqualTo(true); + assertThat(ReplicateOptionsUtils.convertValue("false")).isInstanceOf(Boolean.class).isEqualTo(false); + assertThat(ReplicateOptionsUtils.convertValue("TRUE")).isEqualTo(true); + assertThat(ReplicateOptionsUtils.convertValue("False")).isEqualTo(false); + assertThat(ReplicateOptionsUtils.convertValue("TrUe")).isEqualTo(true); + } + + @Test + void testConvertValueNumeric() { + assertThat(ReplicateOptionsUtils.convertValue("42")).isInstanceOf(Integer.class).isEqualTo(42); + assertThat(ReplicateOptionsUtils.convertValue("3.14")).isInstanceOf(Double.class).isEqualTo(3.14); + assertThat(ReplicateOptionsUtils.convertValue("1.5e10")).isInstanceOf(Double.class).isEqualTo(1.5E10); + } + + @Test + void testConvertValueWithPlainString() { + assertThat(ReplicateOptionsUtils.convertValue("hello world")).isInstanceOf(String.class) + .isEqualTo("hello world"); + } + + @Test + void testConvertValueWithNonStringValue() { + Integer intValue = 100; + Object result = ReplicateOptionsUtils.convertValue(intValue); + assertThat(result).isSameAs(intValue); + + Double doubleValue = 5.5; + result = ReplicateOptionsUtils.convertValue(doubleValue); + assertThat(result).isSameAs(doubleValue); + + Boolean boolValue = true; + result = ReplicateOptionsUtils.convertValue(boolValue); + assertThat(result).isSameAs(boolValue); + } + + @Test + void testConvertMapValuesWithMixedTypes() { + Map source = new HashMap<>(); + source.put("temperature", "0.7"); + source.put("maxTokens", "100"); + source.put("enabled", "true"); + source.put("model", "meta/llama-3"); + source.put("existingInt", 42); + + Map result = ReplicateOptionsUtils.convertMapValues(source); + + assertThat(result).hasSize(5); + assertThat(result.get("temperature")).isInstanceOf(Double.class).isEqualTo(0.7); + assertThat(result.get("maxTokens")).isInstanceOf(Integer.class).isEqualTo(100); + assertThat(result.get("enabled")).isInstanceOf(Boolean.class).isEqualTo(true); + assertThat(result.get("model")).isInstanceOf(String.class).isEqualTo("meta/llama-3"); + assertThat(result.get("existingInt")).isInstanceOf(Integer.class).isEqualTo(42); + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateStringModelIT.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateStringModelIT.java new file mode 100644 index 00000000000..d2d5d3f1329 --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateStringModelIT.java @@ -0,0 +1,102 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Base64; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.replicate.ReplicateStringModel.StringResponse; +import org.springframework.ai.replicate.api.ReplicateApi; +import org.springframework.ai.replicate.api.ReplicateApi.FileUploadResponse; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.FileSystemResource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link ReplicateStringModel}. + * + * @author Rene Maierhofer + */ +@SpringBootTest(classes = ReplicateTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "REPLICATE_API_TOKEN", matches = ".+") +class ReplicateStringModelIT { + + @Autowired + private ReplicateStringModel stringModel; + + @Autowired + private ReplicateApi replicateApi; + + @Test + void testClassifyImageWithFileUpload() { + Path imagePath = Paths.get("src/test/resources/test-image.jpg"); + FileSystemResource fileResource = new FileSystemResource(imagePath); + + FileUploadResponse uploadResponse = this.replicateApi.uploadFile(fileResource, "test-image.jpg"); + + assertThat(uploadResponse).isNotNull(); + assertThat(uploadResponse.urls()).isNotNull(); + assertThat(uploadResponse.urls().get()).isNotEmpty(); + + String imageUrl = uploadResponse.urls().get(); + + ReplicateOptions options = ReplicateOptions.builder() + .model("falcons-ai/nsfw_image_detection") + .withParameter("image", imageUrl) + .build(); + + StringResponse response = this.stringModel.generate(options); + + // Validate output + assertThat(response).isNotNull(); + assertThat(response.getOutput()).isNotNull().isInstanceOf(String.class).isNotEmpty(); + assertThat(response.getOutput().toLowerCase()).isEqualTo("normal"); + + // Validate metadata + assertThat(response.getPredictionResponse()).isNotNull(); + assertThat(response.getPredictionResponse().id()).isNotNull(); + } + + @Test + void testClassifyImageWithBase64() throws IOException { + Path imagePath = Paths.get("src/test/resources/test-image.jpg"); + byte[] imageBytes = Files.readAllBytes(imagePath); + String base64Image = "data:application/octet-stream;base64," + Base64.getEncoder().encodeToString(imageBytes); + + ReplicateOptions options = ReplicateOptions.builder() + .model("falcons-ai/nsfw_image_detection") + .withParameter("image", base64Image) + .build(); + + StringResponse response = this.stringModel.generate(options); + + assertThat(response).isNotNull(); + assertThat(response.getOutput()).isNotNull().isInstanceOf(String.class).isNotEmpty(); + assertThat(response.getOutput().toLowerCase()).isEqualTo("normal"); + assertThat(response.getPredictionResponse()).isNotNull(); + assertThat(response.getPredictionResponse().id()).isNotNull(); + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateStructuredModelIT.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateStructuredModelIT.java new file mode 100644 index 00000000000..e2b9c2b0ae2 --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateStructuredModelIT.java @@ -0,0 +1,73 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import org.springframework.ai.replicate.ReplicateStructuredModel.StructuredResponse; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link ReplicateStructuredModel}. + * + * @author Rene Maierhofer + */ +@SpringBootTest(classes = ReplicateTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "REPLICATE_API_TOKEN", matches = ".+") +class ReplicateStructuredModelIT { + + @Autowired + private ReplicateStructuredModel structuredModel; + + @Test + void testGenerateEmbeddingsWithMetadata() { + ReplicateOptions options = ReplicateOptions.builder() + .model("openai/clip") + .withParameter("text", "spring ai") + .build(); + + StructuredResponse response = this.structuredModel.generate(options); + + assertThat(response).isNotNull(); + assertThat(response.getOutput()).isNotNull().isInstanceOf(Map.class); + Map output = response.getOutput(); + assertThat(output).isNotEmpty(); + assertThat(response.getPredictionResponse()).isNotNull(); + assertThat(response.getPredictionResponse().id()).isNotNull(); + assertThat(response.getPredictionResponse().status()).isNotNull(); + assertThat(response.getPredictionResponse().model()).contains("openai/clip"); + assertThat(response.getPredictionResponse().createdAt()).isNotNull(); + assertThat(response.getPredictionResponse().output()).isNotNull(); + } + + @Test + void testWithDefaultOptions() { + ReplicateOptions options = ReplicateOptions.builder().withParameter("text", "machine learning").build(); + + StructuredResponse response = this.structuredModel.generate(options); + + assertThat(response).isNotNull(); + assertThat(response.getOutput()).isNotNull(); + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateTestConfiguration.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateTestConfiguration.java new file mode 100644 index 00000000000..bd43c454071 --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/ReplicateTestConfiguration.java @@ -0,0 +1,92 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.micrometer.observation.ObservationRegistry; + +import org.springframework.ai.replicate.api.ReplicateApi; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +/** + * Test configuration for Replicate integration tests. + * + * @author Rene Maierhofer + */ +@SpringBootConfiguration +public class ReplicateTestConfiguration { + + @Bean + public ReplicateApi replicateApi() { + return ReplicateApi.builder().apiKey(getApiKey()).build(); + } + + @Bean + public ReplicateChatModel replicateChatModel(ReplicateApi api, ObservationRegistry observationRegistry) { + return ReplicateChatModel.builder() + .replicateApi(api) + .observationRegistry(observationRegistry) + .defaultOptions(ReplicateChatOptions.builder().model("meta/meta-llama-3-8b-instruct").build()) + .build(); + } + + @Bean + public ReplicateMediaModel replicateMediaModel(ReplicateApi api) { + return ReplicateMediaModel.builder() + .replicateApi(api) + .defaultOptions(ReplicateOptions.builder().model("black-forest-labs/flux-schnell").build()) + .build(); + } + + @Bean + public ReplicateStringModel replicateStringModel(ReplicateApi api) { + return ReplicateStringModel.builder() + .replicateApi(api) + .defaultOptions(ReplicateOptions.builder().model("falcons-ai/nsfw_image_detection").build()) + .build(); + } + + @Bean + public ReplicateStructuredModel replicateStructuredModel(ReplicateApi api) { + return ReplicateStructuredModel.builder() + .replicateApi(api) + .defaultOptions(ReplicateOptions.builder().model("openai/clip").build()) + .build(); + } + + @Bean + public ObservationRegistry observationRegistry() { + return ObservationRegistry.create(); + } + + @Bean + public ObjectMapper objectMapper() { + return new ObjectMapper(); + } + + private String getApiKey() { + String apiKey = System.getenv("REPLICATE_API_TOKEN"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide a Replicate API token. Please set the REPLICATE_API_TOKEN environment variable."); + } + return apiKey; + } + +} diff --git a/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/api/ReplicateApiBuilderTests.java b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/api/ReplicateApiBuilderTests.java new file mode 100644 index 00000000000..efa59b3589a --- /dev/null +++ b/models/spring-ai-replicate/src/test/java/org/springframework/ai/replicate/api/ReplicateApiBuilderTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.replicate.api; + +import org.junit.jupiter.api.Test; + +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link ReplicateApi.Builder}. + * + * @author Rene Maierhofer + */ +class ReplicateApiBuilderTests { + + private static final String TEST_API_KEY = "someKey"; + + private static final String TEST_BASE_URL = "http://127.0.0.1"; + + @Test + void testBuilderWithOptions() { + RestClient.Builder restClientBuilder = RestClient.builder(); + WebClient.Builder webClientBuilder = WebClient.builder(); + ResponseErrorHandler errorHandler = mock(ResponseErrorHandler.class); + + ReplicateApi api = ReplicateApi.builder() + .apiKey(TEST_API_KEY) + .baseUrl(TEST_BASE_URL) + .restClientBuilder(restClientBuilder) + .webClientBuilder(webClientBuilder) + .responseErrorHandler(errorHandler) + .build(); + + assertThat(api).isNotNull(); + } + + @Test + void testBuilderWithoutApiKeyThrowsException() { + ReplicateApi.Builder builder = ReplicateApi.builder(); + assertThatThrownBy(builder::build).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("apiKey"); + } + + @Test + void testBuilderWithNullApiKeyThrowsException() { + ReplicateApi.Builder builder = ReplicateApi.builder(); + assertThatThrownBy(() -> builder.apiKey(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ApiKey cannot be null"); + } + + @Test + void testBuilderWithEmptyBaseUrlThrowsException() { + ReplicateApi.Builder builder = ReplicateApi.builder(); + assertThatThrownBy(() -> builder.baseUrl("")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be empty"); + } + + @Test + void testBuilderWithNullBaseUrlThrowsException() { + ReplicateApi.Builder builder = ReplicateApi.builder(); + assertThatThrownBy(() -> builder.baseUrl(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("baseUrl cannot be empty"); + } + +} diff --git a/models/spring-ai-replicate/src/test/resources/test-image.jpg b/models/spring-ai-replicate/src/test/resources/test-image.jpg new file mode 100644 index 00000000000..a297f1869f7 Binary files /dev/null and b/models/spring-ai-replicate/src/test/resources/test-image.jpg differ diff --git a/pom.xml b/pom.xml index f3619e8a8c0..89bae4dc5bd 100644 --- a/pom.xml +++ b/pom.xml @@ -116,6 +116,7 @@ auto-configurations/models/spring-ai-autoconfigure-model-google-genai auto-configurations/models/spring-ai-autoconfigure-model-zhipuai auto-configurations/models/spring-ai-autoconfigure-model-deepseek + auto-configurations/models/spring-ai-autoconfigure-model-replicate auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-common auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-httpclient @@ -187,6 +188,7 @@ models/spring-ai-google-genai-embedding models/spring-ai-zhipuai models/spring-ai-deepseek + models/spring-ai-replicate spring-ai-spring-boot-starters/spring-ai-starter-model-anthropic spring-ai-spring-boot-starters/spring-ai-starter-model-azure-openai diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java b/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java index 88105725a69..461b42a2a9b 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/observation/conventions/AiProvider.java @@ -85,6 +85,11 @@ public enum AiProvider { */ OPENAI("openai"), + /** + * AI system provided by Replicate + */ + REPLICATE("replicate"), + /** * AI system provided by Spring AI. */