@@ -183,7 +183,7 @@ public ChatResponse call(Prompt prompt) {
}

@Override
-public Flux<ChatResponse> generateStream(Prompt prompt) {
+public Flux<ChatResponse> stream(Prompt prompt) {

ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(true);
@@ -151,7 +151,7 @@ void beanStreamOutputParserRecords() {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());

-String generationTextFromStream = chatClient.generateStream(prompt)
+String generationTextFromStream = chatClient.stream(prompt)
.collectList()
.block()
.stream()
@@ -107,7 +107,7 @@ public ChatResponse call(Prompt prompt) {
}

@Override
-public Flux<ChatResponse> generateStream(Prompt prompt) {
+public Flux<ChatResponse> stream(Prompt prompt) {

final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());

@@ -122,7 +122,7 @@ public ChatResponse call(Prompt prompt) {
}

@Override
-public Flux<ChatResponse> generateStream(Prompt prompt) {
+public Flux<ChatResponse> stream(Prompt prompt) {
return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(g -> {
if (g.isFinished()) {
String finishReason = g.finishReason().name();
@@ -85,7 +85,7 @@ public ChatResponse call(Prompt prompt) {
}

@Override
-public Flux<ChatResponse> generateStream(Prompt prompt) {
+public Flux<ChatResponse> stream(Prompt prompt) {

final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());

@@ -84,7 +84,7 @@ public ChatResponse call(Prompt prompt) {
}

@Override
-public Flux<ChatResponse> generateStream(Prompt prompt) {
+public Flux<ChatResponse> stream(Prompt prompt) {
return this.chatApi.chatCompletionStream(this.createRequest(prompt, true)).map(chunk -> {

Generation generation = new Generation(chunk.outputText());
@@ -134,7 +134,7 @@ void beanStreamOutputParserRecords() {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());

-String generationTextFromStream = client.generateStream(prompt)
+String generationTextFromStream = client.stream(prompt)
.collectList()
.block()
.stream()
@@ -134,7 +134,7 @@ void beanStreamOutputParserRecords() {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());

-String generationTextFromStream = client.generateStream(prompt)
+String generationTextFromStream = client.stream(prompt)
.collectList()
.block()
.stream()
@@ -139,7 +139,7 @@ void beanStreamOutputParserRecords() {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());

-String generationTextFromStream = client.generateStream(prompt)
+String generationTextFromStream = client.stream(prompt)
.collectList()
.block()
.stream()
@@ -140,7 +140,7 @@ void beanStreamOutputParserRecords() {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());

-String generationTextFromStream = client.generateStream(prompt)
+String generationTextFromStream = client.stream(prompt)
.collectList()
.block()
.stream()
@@ -97,7 +97,7 @@ public ChatResponse call(Prompt prompt) {
}

@Override
-public Flux<ChatResponse> generateStream(Prompt prompt) {
+public Flux<ChatResponse> stream(Prompt prompt) {

Flux<OllamaApi.ChatResponse> response = this.chatApi.streamingChat(request(prompt, this.model, true));

@@ -167,7 +167,7 @@ void beanStreamOutputParserRecords() {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());

-String generationTextFromStream = client.generateStream(prompt)
+String generationTextFromStream = client.stream(prompt)
.collectList()
.block()
.stream()
@@ -127,7 +127,7 @@ public ChatResponse call(Prompt prompt) {
}

@Override
-public Flux<ChatResponse> generateStream(Prompt prompt) {
+public Flux<ChatResponse> stream(Prompt prompt) {
return this.retryTemplate.execute(ctx -> {
List<Message> messages = prompt.getInstructions();

@@ -140,7 +140,7 @@ void beanStreamOutputParserRecords() {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());

-String generationTextFromStream = openStreamingChatClient.generateStream(prompt)
+String generationTextFromStream = openStreamingChatClient.stream(prompt)
.collectList()
.block()
.stream()
@@ -16,13 +16,10 @@

package org.springframework.ai.chat;

-import reactor.core.publisher.Flux;

import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.StreamingModelClient;

@FunctionalInterface
-public interface StreamingChatClient {

-Flux<ChatResponse> generateStream(Prompt prompt);
+public interface StreamingChatClient extends StreamingModelClient<Prompt, ChatResponse> {

}
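For context, callers of StreamingChatClient switch from generateStream(prompt) to stream(prompt); the reactive return type is unchanged. A minimal caller sketch under that assumption (the class, method, and wiring below are illustrative and not part of this PR; import paths follow this branch's packages):

// Hypothetical caller of the renamed streaming API.
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;

class StreamingCallerExample {

	String streamAsText(StreamingChatClient chatClient) {
		Prompt prompt = new Prompt(new UserMessage("Tell me a joke"));

		// Before this change the call was chatClient.generateStream(prompt).
		Flux<ChatResponse> flux = chatClient.stream(prompt);

		// Collect the streamed chunks into one string, mirroring the tests in this PR.
		return flux.collectList()
			.block()
			.stream()
			.map(response -> response.getResults().get(0).getOutput().getContent())
			.collect(Collectors.joining());
	}

}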
@@ -0,0 +1,43 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.model;

import reactor.core.publisher.Flux;

/**
* The StreamingModelClient interface provides a generic API for invoking AI models with a
* streaming response. It abstracts the process of sending requests and receiving
* streaming responses. The interface uses Java generics to accommodate different types of
* requests and responses, enhancing flexibility and adaptability across different AI
* model implementations.
*
* @param <TReq> the generic type of the request to the AI model
* @param <TResChunk> the generic type of a single item in the streaming response from the
* AI model
* @author Christian Tzolov
* @since 0.8.0
*/
public interface StreamingModelClient<TReq extends ModelRequest<?>, TResChunk extends ModelResponse<?>> {

/**
* Executes a streaming method call to the AI model.
* @param request the request object to be sent to the AI model
* @return the streaming response from the AI model
*/
Flux<TResChunk> stream(TReq request);

}
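Because the type parameters are only bounded by ModelRequest and ModelResponse, utility code can be written once against this contract and reused by every streaming client; StreamingChatClient above simply binds them to Prompt and ChatResponse. A minimal sketch of such a generic helper (the class and method names are made up for illustration and are not part of this PR):

// Hypothetical helper working against any StreamingModelClient implementation.
import reactor.core.publisher.Flux;

import org.springframework.ai.model.ModelRequest;
import org.springframework.ai.model.ModelResponse;
import org.springframework.ai.model.StreamingModelClient;

class StreamingClientSupport {

	// Forwards a request to any streaming client and counts the emitted chunks.
	static <TReq extends ModelRequest<?>, TRes extends ModelResponse<?>> long countChunks(
			StreamingModelClient<TReq, TRes> client, TReq request) {

		Flux<TRes> chunks = client.stream(request);
		// block() keeps the sketch short; real callers would normally stay reactive.
		Long count = chunks.count().block();
		return count == null ? 0L : count;
	}

}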
@@ -90,7 +90,7 @@ public void chatCompletionStreaming() {

AzureOpenAiChatClient chatClient = context.getBean(AzureOpenAiChatClient.class);

-Flux<ChatResponse> response = chatClient.generateStream(new Prompt(List.of(userMessage, systemMessage)));
+Flux<ChatResponse> response = chatClient.stream(new Prompt(List.of(userMessage, systemMessage)));

List<ChatResponse> responses = response.collectList().block();
assertThat(responses.size()).isGreaterThan(1);
@@ -83,7 +83,7 @@ public void chatCompletionStreaming() {
BedrockAnthropicChatClient anthropicChatClient = context.getBean(BedrockAnthropicChatClient.class);

Flux<ChatResponse> response = anthropicChatClient
-.generateStream(new Prompt(List.of(userMessage, systemMessage)));
+.stream(new Prompt(List.of(userMessage, systemMessage)));

List<ChatResponse> responses = response.collectList().block();
assertThat(responses.size()).isGreaterThan(2);
@@ -85,7 +85,7 @@ public void chatCompletionStreaming() {
BedrockCohereChatClient cohereChatClient = context.getBean(BedrockCohereChatClient.class);

Flux<ChatResponse> response = cohereChatClient
-.generateStream(new Prompt(List.of(userMessage, systemMessage)));
+.stream(new Prompt(List.of(userMessage, systemMessage)));

List<ChatResponse> responses = response.collectList().block();
assertThat(responses.size()).isGreaterThan(2);
@@ -83,7 +83,7 @@ public void chatCompletionStreaming() {
BedrockLlama2ChatClient llama2ChatClient = context.getBean(BedrockLlama2ChatClient.class);

Flux<ChatResponse> response = llama2ChatClient
-.generateStream(new Prompt(List.of(userMessage, systemMessage)));
+.stream(new Prompt(List.of(userMessage, systemMessage)));

List<ChatResponse> responses = response.collectList().block();
assertThat(responses.size()).isGreaterThan(2);
@@ -82,7 +82,7 @@ public void chatCompletionStreaming() {

BedrockTitanChatClient chatClient = context.getBean(BedrockTitanChatClient.class);

-Flux<ChatResponse> response = chatClient.generateStream(new Prompt(List.of(userMessage, systemMessage)));
+Flux<ChatResponse> response = chatClient.stream(new Prompt(List.of(userMessage, systemMessage)));

List<ChatResponse> responses = response.collectList().block();
assertThat(responses.size()).isGreaterThan(1);
@@ -100,7 +100,7 @@ public void chatCompletionStreaming() {

OllamaChatClient chatClient = context.getBean(OllamaChatClient.class);

-Flux<ChatResponse> response = chatClient.generateStream(new Prompt(List.of(userMessage, systemMessage)));
+Flux<ChatResponse> response = chatClient.stream(new Prompt(List.of(userMessage, systemMessage)));

List<ChatResponse> responses = response.collectList().block();
assertThat(responses.size()).isGreaterThan(1);
@@ -59,7 +59,7 @@ void generate() {
void generateStreaming() {
contextRunner.run(context -> {
OpenAiChatClient client = context.getBean(OpenAiChatClient.class);
-Flux<ChatResponse> responseFlux = client.generateStream(new Prompt(new UserMessage("Hello")));
+Flux<ChatResponse> responseFlux = client.stream(new Prompt(new UserMessage("Hello")));
String response = responseFlux.collectList().block().stream().map(chatResponse -> {
return chatResponse.getResults().get(0).getOutput().getContent();
}).collect(Collectors.joining());