Skip to content

Commit 20e4b56

Browse files
didalgolab authored and tzolov committed
OpenAI: Add support for streaming token usage.
- OpenAiApi: add StreamOptions class and ChatCompletionRequest#streamOptions field. - add OpenAiChatOptions.Builder#withStreamUsage(boolean) to set/unset the StreamOptions. - add a boolean (get/set)StreamUsage() to OpenAiChatOptions that internally sets the StreamOptions. Later allows the "spring.ai.openai.chat.options.stream-usage" property. - update the OpenAI property documentation. Co-authored-by: Christian Tzolov <[email protected]>
1 parent 12dbc1e commit 20e4b56

File tree

9 files changed

+138
-16
lines changed

9 files changed

+138
-16
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,12 @@ public Flux<ChatResponse> stream(Prompt prompt) {
222222
return generation;
223223
}).toList();
224224

225-
return new ChatResponse(generations);
225+
if (chatCompletion.usage() != null) {
226+
return new ChatResponse(generations, OpenAiChatResponseMetadata.from(chatCompletion));
227+
}
228+
else {
229+
return new ChatResponse(generations);
230+
}
226231
}
227232
catch (Exception e) {
228233
logger.error("Error processing chat completion", e);
@@ -245,7 +250,7 @@ private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionC
245250
.toList();
246251

247252
return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
248-
chunk.systemFingerprint(), "chat.completion", null);
253+
chunk.systemFingerprint(), "chat.completion", chunk.usage());
249254
}
250255

251256
/**
@@ -306,6 +311,12 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
306311
request, ChatCompletionRequest.class);
307312
}
308313

314+
// Remove `streamOptions` from the request if it is not a streaming request
315+
if (request.streamOptions() != null && !stream) {
316+
logger.warn("Removing streamOptions from the request as it is not a streaming request!");
317+
request = request.withStreamOptions(null);
318+
}
319+
309320
return request;
310321
}
311322

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.model.function.FunctionCallingOptions;
2828
import org.springframework.ai.openai.api.OpenAiApi;
2929
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
30+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
3031
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
3132
import org.springframework.ai.openai.api.OpenAiApi.FunctionTool;
3233
import org.springframework.boot.context.properties.NestedConfigurationProperty;
@@ -39,6 +40,7 @@
3940

4041
/**
4142
* @author Christian Tzolov
43+
* @author Mariusz Bernacki
4244
* @since 0.8.0
4345
*/
4446
@JsonInclude(Include.NON_NULL)
@@ -93,6 +95,10 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
9395
* "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
9496
*/
9597
private @JsonProperty("response_format") ResponseFormat responseFormat;
98+
/**
99+
* Options for streaming response. Included in the API only if streaming-mode completion is requested.
100+
*/
101+
private @JsonProperty("stream_options") StreamOptions streamOptions;
96102
/**
97103
* This feature is in Beta. If specified, our system will make a best effort to sample
98104
* deterministically, such that repeated requests with the same seed and parameters should return the same result.
@@ -226,6 +232,13 @@ public Builder withResponseFormat(ResponseFormat responseFormat) {
226232
return this;
227233
}
228234

235+
public Builder withStreamUsage(boolean enableStreamUsage) {
236+
if (enableStreamUsage) {
237+
this.options.streamOptions = (enableStreamUsage) ? StreamOptions.INCLUDE_USAGE : null;
238+
}
239+
return this;
240+
}
241+
229242
public Builder withSeed(Integer seed) {
230243
this.options.seed = seed;
231244
return this;
@@ -284,6 +297,14 @@ public OpenAiChatOptions build() {
284297

285298
}
286299

300+
public Boolean getStreamUsage() {
301+
return this.streamOptions != null;
302+
}
303+
304+
public void setStreamUsage(Boolean enableStreamUsage) {
305+
this.streamOptions = (enableStreamUsage) ? StreamOptions.INCLUDE_USAGE : null;
306+
}
307+
287308
public String getModel() {
288309
return this.model;
289310
}
@@ -356,6 +377,14 @@ public void setResponseFormat(ResponseFormat responseFormat) {
356377
this.responseFormat = responseFormat;
357378
}
358379

380+
public StreamOptions getStreamOptions() {
381+
return streamOptions;
382+
}
383+
384+
public void setStreamOptions(StreamOptions streamOptions) {
385+
this.streamOptions = streamOptions;
386+
}
387+
359388
public Integer getSeed() {
360389
return this.seed;
361390
}
@@ -446,6 +475,7 @@ public int hashCode() {
446475
result = prime * result + ((n == null) ? 0 : n.hashCode());
447476
result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode());
448477
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
478+
result = prime * result + ((streamOptions == null) ? 0 : streamOptions.hashCode());
449479
result = prime * result + ((seed == null) ? 0 : seed.hashCode());
450480
result = prime * result + ((stop == null) ? 0 : stop.hashCode());
451481
result = prime * result + ((temperature == null) ? 0 : temperature.hashCode());
@@ -519,6 +549,12 @@ else if (!this.presencePenalty.equals(other.presencePenalty))
519549
}
520550
else if (!this.responseFormat.equals(other.responseFormat))
521551
return false;
552+
if (this.streamOptions == null) {
553+
if (other.streamOptions != null)
554+
return false;
555+
}
556+
else if (!this.streamOptions.equals(other.streamOptions))
557+
return false;
522558
if (this.seed == null) {
523559
if (other.seed != null)
524560
return false;
@@ -586,6 +622,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
586622
.withN(fromOptions.getN())
587623
.withPresencePenalty(fromOptions.getPresencePenalty())
588624
.withResponseFormat(fromOptions.getResponseFormat())
625+
.withStreamUsage(fromOptions.getStreamUsage())
589626
.withSeed(fromOptions.getSeed())
590627
.withStop(fromOptions.getStop())
591628
.withTemperature(fromOptions.getTemperature())

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
*
4848
* @author Christian Tzolov
4949
* @author Michael Lavelle
50+
* @author Mariusz Bernacki
5051
*/
5152
public class OpenAiApi {
5253

@@ -314,6 +315,7 @@ public Function(String description, String name, String jsonSchema) {
314315
* @param stop Up to 4 sequences where the API will stop generating further tokens.
315316
* @param stream If set, partial message deltas will be sent.Tokens will be sent as data-only server-sent events as
316317
* they become available, with the stream terminated by a data: [DONE] message.
318+
* @param streamOptions Options for streaming response. Only set this when you set stream: true.
317319
* @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
318320
* more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend
319321
* altering this or top_p but not both.
@@ -345,6 +347,7 @@ public record ChatCompletionRequest (
345347
@JsonProperty("seed") Integer seed,
346348
@JsonProperty("stop") List<String> stop,
347349
@JsonProperty("stream") Boolean stream,
350+
@JsonProperty("stream_options") StreamOptions streamOptions,
348351
@JsonProperty("temperature") Float temperature,
349352
@JsonProperty("top_p") Float topP,
350353
@JsonProperty("tools") List<FunctionTool> tools,
@@ -360,7 +363,7 @@ public record ChatCompletionRequest (
360363
*/
361364
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature) {
362365
this(messages, model, null, null, null, null, null, null, null,
363-
null, null, null, false, temperature, null,
366+
null, null, null, false, null, temperature, null,
364367
null, null, null);
365368
}
366369

@@ -375,7 +378,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
375378
*/
376379
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature, boolean stream) {
377380
this(messages, model, null, null, null, null, null, null, null,
378-
null, null, null, stream, temperature, null,
381+
null, null, null, stream, null, temperature, null,
379382
null, null, null);
380383
}
381384

@@ -391,7 +394,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
391394
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
392395
List<FunctionTool> tools, Object toolChoice) {
393396
this(messages, model, null, null, null, null, null, null, null,
394-
null, null, null, false, 0.8f, null,
397+
null, null, null, false, null, 0.8f, null,
395398
tools, toolChoice, null);
396399
}
397400

@@ -404,10 +407,22 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
404407
*/
405408
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
406409
this(messages, null, null, null, null, null, null, null, null,
407-
null, null, null, stream, null, null,
410+
null, null, null, stream, null, null, null,
408411
null, null, null);
409412
}
410413

414+
/**
415+
* Sets the {@link StreamOptions} for this request.
416+
*
417+
* @param streamOptions The new stream options to use.
418+
* @return A new {@link ChatCompletionRequest} with the specified stream options.
419+
*/
420+
public ChatCompletionRequest withStreamOptions(StreamOptions streamOptions) {
421+
return new ChatCompletionRequest(messages, model, frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, n, presencePenalty,
422+
responseFormat, seed, stop, stream, streamOptions, temperature, topP,
423+
tools, toolChoice, user);
424+
}
425+
411426
/**
412427
* Helper factory that creates a tool_choice of type 'none', 'auto' or selected function by name.
413428
*/
@@ -437,6 +452,20 @@ public static Object FUNCTION(String functionName) {
437452
public record ResponseFormat(
438453
@JsonProperty("type") String type) {
439454
}
455+
456+
/**
457+
* @param includeUsage If set, an additional chunk will be streamed
458+
* before the data: [DONE] message. The usage field on this chunk
459+
* shows the token usage statistics for the entire request, and
460+
* the choices field will always be an empty array. All other chunks
461+
* will also include a usage field, but with a null value.
462+
*/
463+
@JsonInclude(Include.NON_NULL)
464+
public record StreamOptions(
465+
@JsonProperty("include_usage") Boolean includeUsage) {
466+
467+
public static StreamOptions INCLUDE_USAGE = new StreamOptions(true);
468+
}
440469
}
441470

442471
/**
@@ -742,7 +771,8 @@ public record ChatCompletionChunk(
742771
@JsonProperty("created") Long created,
743772
@JsonProperty("model") String model,
744773
@JsonProperty("system_fingerprint") String systemFingerprint,
745-
@JsonProperty("object") String object) {
774+
@JsonProperty("object") String object,
775+
@JsonProperty("usage") Usage usage) {
746776

747777
/**
748778
* Chat completion choice.
@@ -825,7 +855,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
825855
// Flux<Flux<ChatCompletionChunk>> -> Flux<Mono<ChatCompletionChunk>>
826856
.concatMapIterable(window -> {
827857
Mono<ChatCompletionChunk> monoChunk = window.reduce(
828-
new ChatCompletionChunk(null, null, null, null, null, null),
858+
new ChatCompletionChunk(null, null, null, null, null, null, null),
829859
(previous, current) -> this.chunkMerger.merge(previous, current));
830860
return List.of(monoChunk);
831861
})

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCompletionFunction;
2929
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
3030
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
31+
import org.springframework.ai.openai.api.OpenAiApi.Usage;
3132
import org.springframework.util.CollectionUtils;
3233

3334
/**
@@ -58,13 +59,14 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu
5859
String systemFingerprint = (current.systemFingerprint() != null ? current.systemFingerprint()
5960
: previous.systemFingerprint());
6061
String object = (current.object() != null ? current.object() : previous.object());
62+
Usage usage = (current.usage() != null ? current.usage() : previous.usage());
6163

6264
ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0));
6365
ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0));
6466

6567
ChunkChoice choice = merge(previousChoice0, currentChoice0);
6668
List<ChunkChoice> chunkChoices = choice == null ? List.of() : List.of(choice);
67-
return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object);
69+
return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object, usage);
6870
}
6971

7072
private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public class MessageTypeContentTests {
6161
ArgumentCaptor<ChatCompletionRequest> promptCaptor;
6262

6363
Flux<ChatCompletionChunk> fluxResponse = Flux
64-
.generate(() -> new ChatCompletionChunk("id", List.of(), 0l, "model", "fp", "object"), (state, sink) -> {
64+
.generate(() -> new ChatCompletionChunk("id", List.of(), 0l, "model", "fp", "object", null), (state, sink) -> {
6565
sink.next(state);
6666
sink.complete();
6767
return state;

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
*/
1616
package org.springframework.ai.openai.chat;
1717

18+
import static org.assertj.core.api.Assertions.assertThat;
19+
1820
import java.io.IOException;
1921
import java.net.URL;
2022
import java.util.ArrayList;
@@ -29,14 +31,12 @@
2931
import org.junit.jupiter.params.provider.ValueSource;
3032
import org.slf4j.Logger;
3133
import org.slf4j.LoggerFactory;
32-
import reactor.core.publisher.Flux;
33-
34-
import org.springframework.ai.chat.model.ChatResponse;
35-
import org.springframework.ai.chat.model.Generation;
3634
import org.springframework.ai.chat.messages.AssistantMessage;
3735
import org.springframework.ai.chat.messages.Media;
3836
import org.springframework.ai.chat.messages.Message;
3937
import org.springframework.ai.chat.messages.UserMessage;
38+
import org.springframework.ai.chat.model.ChatResponse;
39+
import org.springframework.ai.chat.model.Generation;
4040
import org.springframework.ai.chat.prompt.Prompt;
4141
import org.springframework.ai.chat.prompt.PromptTemplate;
4242
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
@@ -56,7 +56,7 @@
5656
import org.springframework.core.io.Resource;
5757
import org.springframework.util.MimeTypeUtils;
5858

59-
import static org.assertj.core.api.Assertions.assertThat;
59+
import reactor.core.publisher.Flux;
6060

6161
@SpringBootTest(classes = OpenAiTestConfiguration.class)
6262
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@@ -103,6 +103,24 @@ void streamRoleTest() {
103103

104104
}
105105

106+
@Test
107+
void streamingWithTokenUsage() {
108+
var promptOptions = OpenAiChatOptions.builder().withStreamUsage(true).withSeed(1).build();
109+
110+
var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
111+
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
112+
var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();
113+
114+
assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
115+
assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0);
116+
assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
117+
118+
assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
119+
assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens());
120+
assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens());
121+
122+
}
123+
106124
@Test
107125
void listOutputConverter() {
108126
DefaultConversionService conversionService = new DefaultConversionService();

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ public void openAiChatStreamTransientError() {
167167
var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
168168
new ChatCompletionMessage("Response", Role.ASSISTANT), null);
169169
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null,
170-
null);
170+
null, null);
171171

172172
when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
173173
.thenThrow(new TransientAiException("Transient Error 1"))

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ The prefix `spring.ai.openai.chat` is the property prefix that lets you configur
105105
| spring.ai.openai.chat.options.toolChoice | Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces the model to call that function. none is the default when no functions are present. auto is the default if functions are present. | -
106106
| spring.ai.openai.chat.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | -
107107
| spring.ai.openai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | -
108+
| spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array and all other chunks will also include a usage field, but with a null value. | false
108109
|====
109110

110111
NOTE: You can override the common `spring.ai.openai.base-url` and `spring.ai.openai.api-key` for the `ChatModel` and `EmbeddingModel` implementations.

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.junit.jupiter.api.Test;
2727
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2828
import org.springframework.ai.chat.messages.UserMessage;
29+
import org.springframework.ai.chat.metadata.Usage;
2930
import org.springframework.ai.chat.model.ChatResponse;
3031
import org.springframework.ai.chat.prompt.Prompt;
3132
import org.springframework.ai.embedding.EmbeddingResponse;
@@ -113,6 +114,28 @@ void generateStreaming() {
113114
});
114115
}
115116

117+
@Test
118+
void streamingWithTokenUsage() {
119+
contextRunner.withPropertyValues("spring.ai.openai.chat.options.stream-usage=true").run(context -> {
120+
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);
121+
122+
Flux<ChatResponse> responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello")));
123+
124+
Usage[] streamingTokenUsage = new Usage[1];
125+
String response = responseFlux.collectList().block().stream().map(chatResponse -> {
126+
streamingTokenUsage[0] = chatResponse.getMetadata().getUsage();
127+
return (chatResponse.getResult() != null) ? chatResponse.getResult().getOutput().getContent() : "";
128+
}).collect(Collectors.joining());
129+
130+
assertThat(streamingTokenUsage[0].getPromptTokens()).isGreaterThan(0);
131+
assertThat(streamingTokenUsage[0].getGenerationTokens()).isGreaterThan(0);
132+
assertThat(streamingTokenUsage[0].getTotalTokens()).isGreaterThan(0);
133+
134+
assertThat(response).isNotEmpty();
135+
logger.info("Response: " + response);
136+
});
137+
}
138+
116139
@Test
117140
void embedding() {
118141
contextRunner.run(context -> {

0 commit comments

Comments
 (0)