6 changes: 6 additions & 0 deletions models/spring-ai-openai/pom.xml
@@ -34,6 +34,12 @@
<version>2.0.4</version>
</dependency>

<!-- NOTE: Required only by the @ConstructorBinding. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot</artifactId>
</dependency>

<dependency>
<groupId>io.rest-assured</groupId>
<artifactId>json-path</artifactId>
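The NOTE above refers to the @ConstructorBinding annotation, which lives in the core spring-boot artifact (org.springframework.boot.context.properties.bind.ConstructorBinding in Spring Boot 3). A minimal sketch of the kind of immutable, constructor-bound properties class that requires it; the class name and field here are illustrative, not the module's actual properties class:

import org.springframework.boot.context.properties.bind.ConstructorBinding;

// Illustrative only: Spring Boot binds this holder through its
// constructor instead of setters, which is what @ConstructorBinding marks.
public class OpenAiConnectionProperties {

    private final String apiKey;

    @ConstructorBinding
    public OpenAiConnectionProperties(String apiKey) {
        this.apiKey = apiKey;
    }

    public String getApiKey() {
        return this.apiKey;
    }
}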
models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
@@ -28,14 +28,16 @@
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
import org.springframework.ai.openai.api.OpenAiChatOptions;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.http.ResponseEntity;
@@ -57,12 +59,13 @@
*/
public class OpenAiChatClient implements ChatClient, StreamingChatClient {

private Double temperature = 0.7;

private String model = "gpt-3.5-turbo";

private final Logger logger = LoggerFactory.getLogger(getClass());

private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
.withModel("gpt-3.5-turbo")
.withTemperature(0.7f)
.build();

public final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(OpenAiApiException.class)
@@ -76,40 +79,46 @@ public OpenAiChatClient(OpenAiApi openAiApi) {
this.openAiApi = openAiApi;
}

public String getModel() {
return this.model;
}

public void setModel(String model) {
this.model = model;
}

public Double getTemperature() {
return this.temperature;
}

public void setTemperature(Double temperature) {
this.temperature = temperature;
public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) {
this.defaultOptions = options;
return this;
}
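The individual model/temperature setters above give way to this single fluent mutator over OpenAiChatOptions. A usage sketch, assuming the single-argument OpenAiApi constructor that takes the API token:

OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
OpenAiChatClient chatClient = new OpenAiChatClient(openAiApi)
    .withDefaultOptions(OpenAiChatOptions.builder()
        .withModel("gpt-3.5-turbo")
        .withTemperature(0.7f)
        .build());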

@Override
public ChatResponse call(Prompt prompt) {

return this.retryTemplate.execute(ctx -> {
List<Message> messages = prompt.getInstructions();

List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
.map(m -> new ChatCompletionMessage(m.getContent(),
ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
.toList();
ChatCompletionRequest request = createRequest(prompt, false);

// List<Message> messages = prompt.getInstructions();

// List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
// .map(m -> new ChatCompletionMessage(m.getContent(),
// ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
// .toList();

// ChatCompletionRequest request =
// ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions,
// false);

ResponseEntity<ChatCompletion> completionEntity = this.openAiApi
.chatCompletionEntity(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
this.temperature.floatValue()));
// if (prompt.getOptions() != null) {
// if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
// request = ModelOptionsUtils.merge(runtimeOptions, request,
// ChatCompletionRequest.class);
// }
// else {
// throw new IllegalArgumentException("Prompt options are not of type
// ChatCompletionRequest:"
// + prompt.getOptions().getClass().getSimpleName());
// }
// }

ResponseEntity<ChatCompletion> completionEntity = this.openAiApi.chatCompletionEntity(request);

var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
logger.warn("No chat completion returned for request: {}", chatCompletionMessages);
logger.warn("No chat completion returned for request: {}", prompt);
return new ChatResponse(List.of());
}

@@ -128,16 +137,32 @@ public ChatResponse call(Prompt prompt) {
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return this.retryTemplate.execute(ctx -> {
List<Message> messages = prompt.getInstructions();
ChatCompletionRequest request = createRequest(prompt, true);

// List<Message> messages = prompt.getInstructions();

// List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
// .map(m -> new ChatCompletionMessage(m.getContent(),
// ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
// .toList();

List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
.map(m -> new ChatCompletionMessage(m.getContent(),
ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
.toList();
// ChatCompletionRequest request =
// ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions,
// true);

Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi
.chatCompletionStream(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
this.temperature.floatValue(), true));
// if (prompt.getOptions() != null) {
// if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
// request = ModelOptionsUtils.merge(runtimeOptions, request,
// ChatCompletionRequest.class);
// }
// else {
// throw new IllegalArgumentException("Prompt options are not of type
// ChatCompletionRequest:"
// + prompt.getOptions().getClass().getSimpleName());
// }
// }

Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request);

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
@@ -161,4 +186,30 @@ public Flux<ChatResponse> stream(Prompt prompt) {
});
}
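Caller-side, streaming now honors per-prompt options as well; a sketch, assuming a Prompt constructor that accepts the prompt text plus ChatOptions:

Flux<ChatResponse> responses = chatClient.stream(
        new Prompt("Tell me a joke",
                OpenAiChatOptions.builder().withTemperature(0.4f).build()));

// Blocking is for demonstration only; compose reactively in real code.
responses.doOnNext(System.out::println).blockLast();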

/**
* Accessible for testing.
*/
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
.stream()
.map(m -> new ChatCompletionMessage(m.getContent(),
ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
.toList();

ChatCompletionRequest request = ChatCompletionRequest.from(chatCompletionMessages, this.defaultOptions, stream);

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type OpenAiChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
}

return request;
}

}
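Because createRequest is deliberately package-visible ("Accessible for testing"), the merge precedence can be pinned down with a test along these lines; the two-argument Prompt constructor and the ChatCompletionRequest record accessors are assumptions, and the assertions use AssertJ:

@Test
void runtimeOptionsOverrideDefaults() {
    OpenAiChatClient client = new OpenAiChatClient(new OpenAiApi("test-key"))
        .withDefaultOptions(OpenAiChatOptions.builder()
            .withModel("gpt-3.5-turbo")
            .withTemperature(0.7f)
            .build());

    // Options carried on the Prompt should override the client defaults.
    Prompt prompt = new Prompt("hello",
            OpenAiChatOptions.builder().withModel("gpt-4").build());

    ChatCompletionRequest request = client.createRequest(prompt, false);

    assertThat(request.model()).isEqualTo("gpt-4");    // overridden at runtime
    assertThat(request.temperature()).isEqualTo(0.7f); // inherited default
}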