28
28
import org .springframework .ai .chat .ChatResponse ;
29
29
import org .springframework .ai .chat .Generation ;
30
30
import org .springframework .ai .chat .StreamingChatClient ;
31
- import org .springframework .ai .chat .messages .Message ;
32
31
import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
33
32
import org .springframework .ai .chat .metadata .RateLimit ;
34
33
import org .springframework .ai .chat .prompt .Prompt ;
34
+ import org .springframework .ai .model .ModelOptionsUtils ;
35
35
import org .springframework .ai .openai .api .OpenAiApi ;
36
36
import org .springframework .ai .openai .api .OpenAiApi .ChatCompletion ;
37
37
import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage ;
38
+ import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionRequest ;
38
39
import org .springframework .ai .openai .api .OpenAiApi .OpenAiApiException ;
40
+ import org .springframework .ai .openai .api .OpenAiChatOptions ;
39
41
import org .springframework .ai .openai .metadata .OpenAiChatResponseMetadata ;
40
42
import org .springframework .ai .openai .metadata .support .OpenAiResponseHeaderExtractor ;
41
43
import org .springframework .http .ResponseEntity ;
57
59
*/
58
60
public class OpenAiChatClient implements ChatClient , StreamingChatClient {
59
61
60
- private Double temperature = 0.7 ;
62
+ private final Logger logger = LoggerFactory . getLogger ( getClass ()) ;
61
63
62
- private String model = "gpt-3.5-turbo" ;
64
+ private static final List <String > REQUEST_JSON_FIELD_NAMES = ModelOptionsUtils
65
+ .getJsonPropertyValues (ChatCompletionRequest .class );
63
66
64
- private final Logger logger = LoggerFactory .getLogger (getClass ());
67
+ private OpenAiChatOptions defaultOptions = OpenAiChatOptions .builder ()
68
+ .withModel ("gpt-3.5-turbo" )
69
+ .withTemperature (0.7f )
70
+ .build ();
65
71
66
72
public final RetryTemplate retryTemplate = RetryTemplate .builder ()
67
73
.maxAttempts (10 )
@@ -76,40 +82,23 @@ public OpenAiChatClient(OpenAiApi openAiApi) {
76
82
this .openAiApi = openAiApi ;
77
83
}
78
84
79
- public String getModel () {
80
- return this .model ;
81
- }
82
-
83
- public void setModel (String model ) {
84
- this .model = model ;
85
- }
86
-
87
- public Double getTemperature () {
88
- return this .temperature ;
89
- }
90
-
91
- public void setTemperature (Double temperature ) {
92
- this .temperature = temperature ;
85
+ public OpenAiChatClient withDefaultOptions (OpenAiChatOptions options ) {
86
+ this .defaultOptions = options ;
87
+ return this ;
93
88
}
94
89
95
90
@ Override
96
91
public ChatResponse call (Prompt prompt ) {
97
92
98
93
return this .retryTemplate .execute (ctx -> {
99
- List <Message > messages = prompt .getInstructions ();
100
94
101
- List <ChatCompletionMessage > chatCompletionMessages = messages .stream ()
102
- .map (m -> new ChatCompletionMessage (m .getContent (),
103
- ChatCompletionMessage .Role .valueOf (m .getMessageType ().name ())))
104
- .toList ();
95
+ ChatCompletionRequest request = createRequest (prompt , false );
105
96
106
- ResponseEntity <ChatCompletion > completionEntity = this .openAiApi
107
- .chatCompletionEntity (new OpenAiApi .ChatCompletionRequest (chatCompletionMessages , this .model ,
108
- this .temperature .floatValue ()));
97
+ ResponseEntity <ChatCompletion > completionEntity = this .openAiApi .chatCompletionEntity (request );
109
98
110
99
var chatCompletion = completionEntity .getBody ();
111
100
if (chatCompletion == null ) {
112
- logger .warn ("No chat completion returned for request: {}" , chatCompletionMessages );
101
+ logger .warn ("No chat completion returned for request: {}" , prompt );
113
102
return new ChatResponse (List .of ());
114
103
}
115
104
@@ -128,16 +117,9 @@ public ChatResponse call(Prompt prompt) {
128
117
@ Override
129
118
public Flux <ChatResponse > stream (Prompt prompt ) {
130
119
return this .retryTemplate .execute (ctx -> {
131
- List <Message > messages = prompt .getInstructions ();
132
-
133
- List <ChatCompletionMessage > chatCompletionMessages = messages .stream ()
134
- .map (m -> new ChatCompletionMessage (m .getContent (),
135
- ChatCompletionMessage .Role .valueOf (m .getMessageType ().name ())))
136
- .toList ();
120
+ ChatCompletionRequest request = createRequest (prompt , true );
137
121
138
- Flux <OpenAiApi .ChatCompletionChunk > completionChunks = this .openAiApi
139
- .chatCompletionStream (new OpenAiApi .ChatCompletionRequest (chatCompletionMessages , this .model ,
140
- this .temperature .floatValue (), true ));
122
+ Flux <OpenAiApi .ChatCompletionChunk > completionChunks = this .openAiApi .chatCompletionStream (request );
141
123
142
124
// For chunked responses, only the first chunk contains the choice role.
143
125
// The rest of the chunks with same ID share the same role.
@@ -161,4 +143,36 @@ public Flux<ChatResponse> stream(Prompt prompt) {
161
143
});
162
144
}
163
145
146
+ /**
147
+ * Accessible for testing.
148
+ */
149
+ ChatCompletionRequest createRequest (Prompt prompt , boolean stream ) {
150
+
151
+ List <ChatCompletionMessage > chatCompletionMessages = prompt .getInstructions ()
152
+ .stream ()
153
+ .map (m -> new ChatCompletionMessage (m .getContent (),
154
+ ChatCompletionMessage .Role .valueOf (m .getMessageType ().name ())))
155
+ .toList ();
156
+
157
+ ChatCompletionRequest request = new ChatCompletionRequest (chatCompletionMessages , stream );
158
+
159
+ if (this .defaultOptions != null ) {
160
+ request = ModelOptionsUtils .merge (request , this .defaultOptions , ChatCompletionRequest .class ,
161
+ REQUEST_JSON_FIELD_NAMES );
162
+ }
163
+
164
+ if (prompt .getOptions () != null ) {
165
+ if (prompt .getOptions () instanceof OpenAiChatOptions runtimeOptions ) {
166
+ request = ModelOptionsUtils .merge (runtimeOptions , request , ChatCompletionRequest .class ,
167
+ REQUEST_JSON_FIELD_NAMES );
168
+ }
169
+ else {
170
+ throw new IllegalArgumentException ("Prompt options are not of type ChatCompletionRequest:"
171
+ + prompt .getOptions ().getClass ().getSimpleName ());
172
+ }
173
+ }
174
+
175
+ return request ;
176
+ }
177
+
164
178
}
0 commit comments