6 changes: 6 additions & 0 deletions models/spring-ai-openai/pom.xml
@@ -34,6 +34,12 @@
 			<version>2.0.4</version>
 		</dependency>
 
+		<!-- NOTE: Required only by the @ConstructorBinding annotation. -->
+		<dependency>
+			<groupId>org.springframework.boot</groupId>
+			<artifactId>spring-boot</artifactId>
+		</dependency>
+
 		<dependency>
 			<groupId>io.rest-assured</groupId>
 			<artifactId>json-path</artifactId>
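For context, @ConstructorBinding ships in the spring-boot artifact, which is why the module needs the dependency above. A minimal, hypothetical constructor-bound properties class illustrating the annotation (class and property names here are illustrative assumptions, not code from this PR; the import path is the Spring Boot 3.x one):

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.bind.ConstructorBinding;

// Hypothetical example: any constructor-bound @ConfigurationProperties class
// in this module compiles against the spring-boot artifact.
@ConfigurationProperties("spring.ai.openai.chat")
public class ChatProperties {

	private final String model;

	// Marks the constructor Spring Boot should bind configuration values through.
	@ConstructorBinding
	public ChatProperties(String model) {
		this.model = model;
	}

	public String getModel() {
		return this.model;
	}
}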
OpenAiChatClient.java (models/spring-ai-openai)
@@ -28,13 +28,14 @@
 import org.springframework.ai.chat.ChatResponse;
 import org.springframework.ai.chat.Generation;
 import org.springframework.ai.chat.StreamingChatClient;
-import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.RateLimit;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
 import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
 import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
 import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
@@ -57,11 +58,15 @@
  */
 public class OpenAiChatClient implements ChatClient, StreamingChatClient {
 
-	private Double temperature = 0.7;
+	private final Logger logger = LoggerFactory.getLogger(getClass());
 
-	private String model = "gpt-3.5-turbo";
+	private static final List<String> REQUEST_JSON_FIELD_NAMES = ModelOptionsUtils
+		.getJsonPropertyValues(ChatCompletionRequest.class);
 
-	private final Logger logger = LoggerFactory.getLogger(getClass());
+	private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
+		.withModel("gpt-3.5-turbo")
+		.withTemperature(0.7f)
+		.build();
 
 	public final RetryTemplate retryTemplate = RetryTemplate.builder()
 		.maxAttempts(10)
@@ -76,40 +81,23 @@ public OpenAiChatClient(OpenAiApi openAiApi) {
 		this.openAiApi = openAiApi;
 	}
 
-	public String getModel() {
-		return this.model;
-	}
-
-	public void setModel(String model) {
-		this.model = model;
-	}
-
-	public Double getTemperature() {
-		return this.temperature;
-	}
-
-	public void setTemperature(Double temperature) {
-		this.temperature = temperature;
+	public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) {
+		this.defaultOptions = options;
+		return this;
 	}
 
 	@Override
 	public ChatResponse call(Prompt prompt) {
 
 		return this.retryTemplate.execute(ctx -> {
-			List<Message> messages = prompt.getInstructions();
-
-			List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
-				.map(m -> new ChatCompletionMessage(m.getContent(),
-						ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
-				.toList();
+			ChatCompletionRequest request = createRequest(prompt, false);
 
-			ResponseEntity<ChatCompletion> completionEntity = this.openAiApi
-				.chatCompletionEntity(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
-						this.temperature.floatValue()));
+			ResponseEntity<ChatCompletion> completionEntity = this.openAiApi.chatCompletionEntity(request);
 
 			var chatCompletion = completionEntity.getBody();
 			if (chatCompletion == null) {
-				logger.warn("No chat completion returned for request: {}", chatCompletionMessages);
+				logger.warn("No chat completion returned for request: {}", prompt);
 				return new ChatResponse(List.of());
 			}
 
@@ -128,16 +116,9 @@ public ChatResponse call(Prompt prompt) {
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
 		return this.retryTemplate.execute(ctx -> {
-			List<Message> messages = prompt.getInstructions();
-
-			List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
-				.map(m -> new ChatCompletionMessage(m.getContent(),
-						ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
-				.toList();
+			ChatCompletionRequest request = createRequest(prompt, true);
 
-			Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi
-				.chatCompletionStream(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
-						this.temperature.floatValue(), true));
+			Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request);
 
 			// For chunked responses, only the first chunk contains the choice role.
 			// The rest of the chunks with the same ID share the same role.
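As an aside, the bookkeeping those two comments describe could be sketched roughly as follows. This is an illustration with invented names, not the PR's implementation; the concrete chunk-handling code sits further down this method and outside the diff:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Illustrative sketch: remember the role announced by the first chunk of each
// completion id so later chunks, whose delta omits the role, can reuse it.
class ChunkRoleTracker {

	private final Map<String, String> roleByCompletionId = new ConcurrentHashMap<>();

	String resolveRole(String completionId, String roleFromDelta) {
		if (roleFromDelta != null) {
			this.roleByCompletionId.putIfAbsent(completionId, roleFromDelta);
		}
		return this.roleByCompletionId.get(completionId);
	}
}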
@@ -161,4 +142,36 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 		});
 	}
 
+	/**
+	 * Accessible for testing.
+	 */
+	ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
+
+		List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
+			.stream()
+			.map(m -> new ChatCompletionMessage(m.getContent(),
+					ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
+			.toList();
+
+		ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
+
+		if (this.defaultOptions != null) {
+			request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class,
+					REQUEST_JSON_FIELD_NAMES);
+		}
+
+		if (prompt.getOptions() != null) {
+			if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
+				request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class,
+						REQUEST_JSON_FIELD_NAMES);
+			}
+			else {
+				throw new IllegalArgumentException("Prompt options are not of type OpenAiChatOptions: "
+						+ prompt.getOptions().getClass().getSimpleName());
+			}
+		}
+
+		return request;
+	}
+
 }
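Taken together, createRequest gives per-prompt options precedence over the client defaults: the defaults are merged into the bare request first, and any OpenAiChatOptions attached to the Prompt are merged on top of that. A short usage sketch under that reading; the Prompt(String, ChatOptions) and OpenAiApi(String) constructors used here are assumptions, as neither appears in this diff:

// Client-wide defaults, overridable per prompt.
OpenAiChatClient client = new OpenAiChatClient(new OpenAiApi(System.getenv("OPENAI_API_KEY")))
	.withDefaultOptions(OpenAiChatOptions.builder()
		.withModel("gpt-4")
		.withTemperature(0.4f)
		.build());

// This call runs at temperature 0.9 but keeps the default model, since only
// the non-null fields of the runtime options win the merge.
Prompt prompt = new Prompt("Tell me a joke",
		OpenAiChatOptions.builder().withTemperature(0.9f).build());
ChatResponse response = client.call(prompt);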