diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java index 5decb221b0b..1dc941c7759 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java @@ -183,7 +183,7 @@ public ChatResponse call(Prompt prompt) { } @Override - public Flux generateStream(Prompt prompt) { + public Flux stream(Prompt prompt) { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); options.setStream(true); diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java index d82db85c31e..a65df0c955d 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java @@ -151,7 +151,7 @@ void beanStreamOutputParserRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = chatClient.generateStream(prompt) + String generationTextFromStream = chatClient.stream(prompt) .collectList() .block() .stream() diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java index a8b67702648..c9e87d34fef 100644 --- 
a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java @@ -107,7 +107,7 @@ public ChatResponse call(Prompt prompt) { } @Override - public Flux generateStream(Prompt prompt) { + public Flux stream(Prompt prompt) { final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java index a5c1e3e188c..bbe2c800ac1 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java @@ -122,7 +122,7 @@ public ChatResponse call(Prompt prompt) { } @Override - public Flux generateStream(Prompt prompt) { + public Flux stream(Prompt prompt) { return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(g -> { if (g.isFinished()) { String finishReason = g.finishReason().name(); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java index 38e157394af..d1641ca58ff 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java @@ -85,7 +85,7 @@ public ChatResponse call(Prompt prompt) { } @Override - public Flux generateStream(Prompt prompt) { + public Flux stream(Prompt prompt) { final String promptValue = 
MessageToPromptConverter.create().toPrompt(prompt.getInstructions()); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java index b2279ce836b..c1a78104936 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java @@ -84,7 +84,7 @@ public ChatResponse call(Prompt prompt) { } @Override - public Flux generateStream(Prompt prompt) { + public Flux stream(Prompt prompt) { return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(chunk -> { Generation generation = new Generation(chunk.outputText()); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java index fd342892766..8eaa3d2bc7f 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClientIT.java @@ -134,7 +134,7 @@ void beanStreamOutputParserRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = client.generateStream(prompt) + String generationTextFromStream = client.stream(prompt) .collectList() .block() .stream() diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java index 
9fc97ee2c6f..73b88766bce 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClientIT.java @@ -134,7 +134,7 @@ void beanStreamOutputParserRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = client.generateStream(prompt) + String generationTextFromStream = client.stream(prompt) .collectList() .block() .stream() diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java index b1ec0458e17..951d2240672 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClientIT.java @@ -139,7 +139,7 @@ void beanStreamOutputParserRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = client.generateStream(prompt) + String generationTextFromStream = client.stream(prompt) .collectList() .block() .stream() diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java index 1c07073cc0e..9d98681a68c 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClientIT.java @@ -140,7 +140,7 @@ 
void beanStreamOutputParserRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = client.generateStream(prompt) + String generationTextFromStream = client.stream(prompt) .collectList() .block() .stream() diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java index e1835c67ac1..8d8d7623921 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java @@ -97,7 +97,7 @@ public ChatResponse call(Prompt prompt) { } @Override - public Flux generateStream(Prompt prompt) { + public Flux stream(Prompt prompt) { Flux response = this.chatApi.streamingChat(request(prompt, this.model, true)); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java index 84617fea0bc..dfff01a4528 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java @@ -167,7 +167,7 @@ void beanStreamOutputParserRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = client.generateStream(prompt) + String generationTextFromStream = client.stream(prompt) .collectList() .block() .stream() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java 
index 64ebc1c45d0..b6da086f348 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -127,7 +127,7 @@ public ChatResponse call(Prompt prompt) { } @Override - public Flux<ChatResponse> generateStream(Prompt prompt) { + public Flux<ChatResponse> stream(Prompt prompt) { return this.retryTemplate.execute(ctx -> { List<Message> messages = prompt.getInstructions(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java index aa35ba9e04b..a29617c5ceb 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java @@ -140,7 +140,7 @@ void beanStreamOutputParserRecords() { PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = openStreamingChatClient.generateStream(prompt) + String generationTextFromStream = openStreamingChatClient.stream(prompt) .collectList() .block() .stream() diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java index beb67ae6a3c..a650c9d68d8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/StreamingChatClient.java @@ -16,13 +16,10 @@ package org.springframework.ai.chat; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.StreamingModelClient; @FunctionalInterface -public interface StreamingChatClient { - - Flux<ChatResponse>
generateStream(Prompt prompt); +public interface StreamingChatClient extends StreamingModelClient<Prompt, ChatResponse> { } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModelClient.java b/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModelClient.java new file mode 100644 index 00000000000..da1db15048b --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/StreamingModelClient.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model; + +import reactor.core.publisher.Flux; + +/** + * The StreamingModelClient interface provides a generic API for invoking AI models with + * streaming responses. It abstracts the process of sending requests and receiving + * streaming responses. The interface uses Java generics to accommodate different types of + * requests and responses, enhancing flexibility and adaptability across different AI + * model implementations. + * + * @param <TReq> the generic type of the request to the AI model + * @param <TResChunk> the generic type of a single item in the streaming response from the + * AI model + * @author Christian Tzolov + * @since 0.8.0 + */ +public interface StreamingModelClient<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> { + + /** + * Executes a method call to the AI model.
+ * @param request the request object to be sent to the AI model + * @return the streaming response from the AI model + */ + Flux<TResChunk> stream(TReq request); + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 4184358ab7d..b195f9ee0f3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -90,7 +90,7 @@ public void chatCompletionStreaming() { AzureOpenAiChatClient chatClient = context.getBean(AzureOpenAiChatClient.class); - Flux<ChatResponse> response = chatClient.generateStream(new Prompt(List.of(userMessage, systemMessage))); + Flux<ChatResponse> response = chatClient.stream(new Prompt(List.of(userMessage, systemMessage))); List<ChatResponse> responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java index 8f2a2eaff84..e855b096f95 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/anthropic/BedrockAnthropicChatAutoConfigurationIT.java @@ -83,7 +83,7 @@ public void chatCompletionStreaming() { BedrockAnthropicChatClient anthropicChatClient = context.getBean(BedrockAnthropicChatClient.class); Flux<ChatResponse> response =
anthropicChatClient - .generateStream(new Prompt(List.of(userMessage, systemMessage))); + .stream(new Prompt(List.of(userMessage, systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java index b40d337afba..a022d6a6782 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/cohere/BedrockCohereChatAutoConfigurationIT.java @@ -85,7 +85,7 @@ public void chatCompletionStreaming() { BedrockCohereChatClient cohereChatClient = context.getBean(BedrockCohereChatClient.class); Flux response = cohereChatClient - .generateStream(new Prompt(List.of(userMessage, systemMessage))); + .stream(new Prompt(List.of(userMessage, systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java index 97435c61d50..e37171b0421 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/llama2/BedrockLlama2ChatAutoConfigurationIT.java @@ -83,7 +83,7 @@ public void chatCompletionStreaming() { 
BedrockLlama2ChatClient llama2ChatClient = context.getBean(BedrockLlama2ChatClient.class); Flux response = llama2ChatClient - .generateStream(new Prompt(List.of(userMessage, systemMessage))); + .stream(new Prompt(List.of(userMessage, systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(2); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java index 44e7ef97503..20ac6cdfb0a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/titan/BedrockTitanChatAutoConfigurationIT.java @@ -82,7 +82,7 @@ public void chatCompletionStreaming() { BedrockTitanChatClient chatClient = context.getBean(BedrockTitanChatClient.class); - Flux response = chatClient.generateStream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = chatClient.stream(new Prompt(List.of(userMessage, systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java index 0679da52dca..eda449afefb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfigurationIT.java @@ -100,7 +100,7 @@ public void 
chatCompletionStreaming() { OllamaChatClient chatClient = context.getBean(OllamaChatClient.class); - Flux response = chatClient.generateStream(new Prompt(List.of(userMessage, systemMessage))); + Flux response = chatClient.stream(new Prompt(List.of(userMessage, systemMessage))); List responses = response.collectList().block(); assertThat(responses.size()).isGreaterThan(1); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 748bc822a5d..d7e815296d2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -59,7 +59,7 @@ void generate() { void generateStreaming() { contextRunner.run(context -> { OpenAiChatClient client = context.getBean(OpenAiChatClient.class); - Flux responseFlux = client.generateStream(new Prompt(new UserMessage("Hello"))); + Flux responseFlux = client.stream(new Prompt(new UserMessage("Hello"))); String response = responseFlux.collectList().block().stream().map(chatResponse -> { return chatResponse.getResults().get(0).getOutput().getContent(); }).collect(Collectors.joining());