
Commit 0947fcc

Add OpenAi Chat Options (2)
- Add OpenAiChatOptions that implements ChatOptions and exposes all OpenAi request options, except messages and stream.
- Add an OpenAiChatOptions field (as defaultOptions) to OpenAiChatClient. Implement start-up/runtime options merging on chat request creation.
- Add an OpenAiChatOptions options field to OpenAiChatProperties. The latter is set as OpenAiChatClient#defaultOptions. Use the spring.ai.openai.chat.options.* prefix to set the options.
- Add tests for properties and options merging.

Part of #228
1 parent 9381ab6 commit 0947fcc
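
The commit message states that the new default options can be set with the spring.ai.openai.chat.options.* property prefix. A minimal, illustrative application.properties sketch follows; the exact keys are assumptions derived from the OpenAiChatOptions field names and relaxed binding, not taken from this commit:

# Hypothetical defaults, bound to OpenAiChatProperties#options and applied as
# OpenAiChatClient#defaultOptions at start-up.
spring.ai.openai.chat.options.model=gpt-3.5-turbo
spring.ai.openai.chat.options.temperature=0.7
spring.ai.openai.chat.options.max-tokens=256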

File tree

11 files changed (+754 / -69 lines changed)

models/spring-ai-openai/pom.xml

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,12 @@
 			<version>2.0.4</version>
 		</dependency>
 
+		<!-- NOTE: Required only by the @ConstructorBinding. -->
+		<dependency>
+			<groupId>org.springframework.boot</groupId>
+			<artifactId>spring-boot</artifactId>
+		</dependency>
+
 		<dependency>
 			<groupId>io.rest-assured</groupId>
 			<artifactId>json-path</artifactId>

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 50 additions & 36 deletions
@@ -28,14 +28,16 @@
 import org.springframework.ai.chat.ChatResponse;
 import org.springframework.ai.chat.Generation;
 import org.springframework.ai.chat.StreamingChatClient;
-import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.RateLimit;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
 import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
+import org.springframework.ai.openai.api.OpenAiChatOptions;
 import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
 import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
 import org.springframework.http.ResponseEntity;
@@ -57,11 +59,15 @@
  */
 public class OpenAiChatClient implements ChatClient, StreamingChatClient {
 
-	private Double temperature = 0.7;
+	private final Logger logger = LoggerFactory.getLogger(getClass());
 
-	private String model = "gpt-3.5-turbo";
+	private static final List<String> REQUEST_JSON_FIELD_NAMES = ModelOptionsUtils
+		.getJsonPropertyValues(ChatCompletionRequest.class);
 
-	private final Logger logger = LoggerFactory.getLogger(getClass());
+	private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
+		.withModel("gpt-3.5-turbo")
+		.withTemperature(0.7f)
+		.build();
 
 	public final RetryTemplate retryTemplate = RetryTemplate.builder()
 		.maxAttempts(10)
@@ -76,40 +82,23 @@ public OpenAiChatClient(OpenAiApi openAiApi) {
 		this.openAiApi = openAiApi;
 	}
 
-	public String getModel() {
-		return this.model;
-	}
-
-	public void setModel(String model) {
-		this.model = model;
-	}
-
-	public Double getTemperature() {
-		return this.temperature;
-	}
-
-	public void setTemperature(Double temperature) {
-		this.temperature = temperature;
+	public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) {
+		this.defaultOptions = options;
+		return this;
 	}
 
 	@Override
 	public ChatResponse call(Prompt prompt) {
 
 		return this.retryTemplate.execute(ctx -> {
-			List<Message> messages = prompt.getInstructions();
 
-			List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
-				.map(m -> new ChatCompletionMessage(m.getContent(),
-						ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
-				.toList();
+			ChatCompletionRequest request = createRequest(prompt, false);
 
-			ResponseEntity<ChatCompletion> completionEntity = this.openAiApi
-				.chatCompletionEntity(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
-						this.temperature.floatValue()));
+			ResponseEntity<ChatCompletion> completionEntity = this.openAiApi.chatCompletionEntity(request);
 
 			var chatCompletion = completionEntity.getBody();
 			if (chatCompletion == null) {
-				logger.warn("No chat completion returned for request: {}", chatCompletionMessages);
+				logger.warn("No chat completion returned for request: {}", prompt);
 				return new ChatResponse(List.of());
 			}
 
@@ -128,16 +117,9 @@ public ChatResponse call(Prompt prompt) {
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
 		return this.retryTemplate.execute(ctx -> {
-			List<Message> messages = prompt.getInstructions();
-
-			List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
-				.map(m -> new ChatCompletionMessage(m.getContent(),
-						ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
-				.toList();
+			ChatCompletionRequest request = createRequest(prompt, true);
 
-			Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi
-				.chatCompletionStream(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
-						this.temperature.floatValue(), true));
+			Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request);
 
 			// For chunked responses, only the first chunk contains the choice role.
 			// The rest of the chunks with same ID share the same role.
@@ -161,4 +143,36 @@ public Flux<ChatResponse> stream(Prompt prompt) {
 		});
 	}
 
+	/**
+	 * Accessible for testing.
+	 */
+	ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
+
+		List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
+			.stream()
+			.map(m -> new ChatCompletionMessage(m.getContent(),
+					ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
+			.toList();
+
+		ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
+
+		if (this.defaultOptions != null) {
+			request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class,
+					REQUEST_JSON_FIELD_NAMES);
+		}
+
+		if (prompt.getOptions() != null) {
+			if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
+				request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class,
+						REQUEST_JSON_FIELD_NAMES);
+			}
+			else {
+				throw new IllegalArgumentException("Prompt options are not of type ChatCompletionRequest:"
						+ prompt.getOptions().getClass().getSimpleName());
+			}
+		}
+
+		return request;
+	}
+
 }
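
A minimal usage sketch (not part of this diff) of how the merged options are intended to work: defaults are configured on the client at start-up and can be overridden per call by options carried on the Prompt. Given an existing OpenAiApi instance, and assuming a Prompt constructor that accepts ChatOptions next to the message text:

// Start-up defaults, replacing the removed model/temperature setters.
OpenAiChatClient chatClient = new OpenAiChatClient(openAiApi)
	.withDefaultOptions(OpenAiChatOptions.builder()
		.withModel("gpt-3.5-turbo")
		.withTemperature(0.7f)
		.build());

// Runtime options attached to the Prompt take precedence when
// createRequest(prompt, stream) merges them into the ChatCompletionRequest.
ChatResponse response = chatClient.call(
		new Prompt("Tell me a joke", OpenAiChatOptions.builder().withTemperature(0.2f).build()));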

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 22 additions & 4 deletions
@@ -31,6 +31,7 @@
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
+import org.springframework.boot.context.properties.bind.ConstructorBinding;
 import org.springframework.core.ParameterizedTypeReference;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
@@ -183,6 +184,7 @@ public record FunctionTool(
 		 * Create a tool of type 'function' and the given function definition.
 		 * @param function function definition.
 		 */
+		@ConstructorBinding
 		public FunctionTool(Function function) {
 			this(Type.FUNCTION, function);
 		}
@@ -219,13 +221,14 @@ public record Function(
 			 * @param name tool function name.
 			 * @param jsonSchema tool function schema as json.
 			 */
+			@ConstructorBinding
 			public Function(String description, String name, String jsonSchema) {
 				this(description, name, parseJson(jsonSchema));
 			}
 		}
 	}
 
-	/**
+	/**
 	 * Creates a model response for the given chat conversation.
 	 *
 	 * @param messages A list of messages comprising the conversation so far.
@@ -269,17 +272,17 @@ public Function(String description, String name, String jsonSchema) {
 	 *
 	 */
 	@JsonInclude(Include.NON_NULL)
-	public record ChatCompletionRequest(
+	public record ChatCompletionRequest (
 			@JsonProperty("messages") List<ChatCompletionMessage> messages,
 			@JsonProperty("model") String model,
 			@JsonProperty("frequency_penalty") Float frequencyPenalty,
-			@JsonProperty("logit_bias") Map<String, Object> logitBias,
+			@JsonProperty("logit_bias") Map<String, Integer> logitBias,
 			@JsonProperty("max_tokens") Integer maxTokens,
 			@JsonProperty("n") Integer n,
 			@JsonProperty("presence_penalty") Float presencePenalty,
 			@JsonProperty("response_format") ResponseFormat responseFormat,
 			@JsonProperty("seed") Integer seed,
-			@JsonProperty("stop") String stop,
+			@JsonProperty("stop") List<String> stop,
 			@JsonProperty("stream") Boolean stream,
 			@JsonProperty("temperature") Float temperature,
 			@JsonProperty("top_p") Float topP,
@@ -331,6 +334,20 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
 					tools, toolChoice, null);
 		}
 
+		/**
+		 * Shortcut constructor for a chat completion request with the given messages, model, tools and tool choice.
+		 * Streaming is set to false, temperature to 0.8 and all other parameters are null.
+		 *
+		 * @param messages A list of messages comprising the conversation so far.
+		 * @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events
+		 * as they become available, with the stream terminated by a data: [DONE] message.
+		 */
+		public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
+			this(messages, null, null, null, null, null, null,
+					null, null, null, stream, null, null,
+					null, null, null);
+		}
+
 		/**
		 * Specifies a tool the model should use. Use to force the model to call a specific function.
		 *
@@ -346,6 +363,7 @@ public record ToolChoice(
 		 * Create a tool choice of type 'function' and name 'functionName'.
 		 * @param functionName Function name of the tool.
 		 */
+		@ConstructorBinding
 		public ToolChoice(String functionName) {
 			this("function", Map.of("name", functionName));
 		}
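
For illustration, a short sketch of the new two-argument ChatCompletionRequest shortcut used by OpenAiChatClient#createRequest: only the messages and the stream flag are populated, leaving every other field null so the later options merge can fill them in. The Role.USER constant is assumed from the existing ChatCompletionMessage.Role enum.

List<ChatCompletionMessage> messages = List.of(
		new ChatCompletionMessage("Hello world", ChatCompletionMessage.Role.USER));

// model, temperature, stop, etc. remain null here; they are supplied later via ModelOptionsUtils.merge.
ChatCompletionRequest request = new ChatCompletionRequest(messages, false);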
