Skip to content

Commit 2129547

Browse files
alexcheng1982 authored and tzolov committed
Fix Usage in Ollama ChatResponse
The Usage for the Ollama ChatResponse was being stored in ChatGenerationMetadata, in the slot intended for content filter metadata. It belongs in ChatResponseMetadata instead.
1 parent 9236913 commit 2129547

File tree

4 files changed

+132
-22
lines changed

4 files changed

+132
-22
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.Base64;
1919
import java.util.List;
2020

21+
import org.springframework.ai.ollama.metadata.OllamaChatResponseMetadata;
2122
import reactor.core.publisher.Flux;
2223

2324
import org.springframework.ai.chat.ChatClient;
@@ -27,7 +28,6 @@
2728
import org.springframework.ai.chat.messages.Message;
2829
import org.springframework.ai.chat.messages.MessageType;
2930
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
30-
import org.springframework.ai.chat.metadata.Usage;
3131
import org.springframework.ai.chat.prompt.ChatOptions;
3232
import org.springframework.ai.chat.prompt.Prompt;
3333
import org.springframework.ai.model.ModelOptionsUtils;
@@ -99,10 +99,9 @@ public ChatResponse call(Prompt prompt) {
9999

100100
var generator = new Generation(response.message().content());
101101
if (response.promptEvalCount() != null && response.evalCount() != null) {
102-
generator = generator
103-
.withGenerationMetadata(ChatGenerationMetadata.from("unknown", extractUsage(response)));
102+
generator = generator.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
104103
}
105-
return new ChatResponse(List.of(generator));
104+
return new ChatResponse(List.of(generator), OllamaChatResponseMetadata.from(response));
106105
}
107106

108107
@Override
@@ -114,28 +113,12 @@ public Flux<ChatResponse> stream(Prompt prompt) {
114113
Generation generation = (chunk.message() != null) ? new Generation(chunk.message().content())
115114
: new Generation("");
116115
if (Boolean.TRUE.equals(chunk.done())) {
117-
generation = generation
118-
.withGenerationMetadata(ChatGenerationMetadata.from("unknown", extractUsage(chunk)));
116+
generation = generation.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
119117
}
120-
return new ChatResponse(List.of(generation));
118+
return new ChatResponse(List.of(generation), OllamaChatResponseMetadata.from(chunk));
121119
});
122120
}
123121

124-
private Usage extractUsage(OllamaApi.ChatResponse response) {
125-
return new Usage() {
126-
127-
@Override
128-
public Long getPromptTokens() {
129-
return response.promptEvalCount().longValue();
130-
}
131-
132-
@Override
133-
public Long getGenerationTokens() {
134-
return response.evalCount().longValue();
135-
}
136-
};
137-
}
138-
139122
/**
140123
* Package access for testing.
141124
*/
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.ollama.metadata;
17+
18+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
19+
import org.springframework.ai.chat.metadata.Usage;
20+
import org.springframework.ai.ollama.api.OllamaApi;
21+
import org.springframework.util.Assert;
22+
23+
/**
24+
* {@link ChatResponseMetadata} implementation for {@literal Ollama}
25+
*
26+
* @see ChatResponseMetadata
27+
* @author Fu Cheng
28+
*/
29+
public class OllamaChatResponseMetadata implements ChatResponseMetadata {
30+
31+
protected static final String AI_METADATA_STRING = "{ @type: %1$s, usage: %2$s, rateLimit: %3$s }";
32+
33+
public static OllamaChatResponseMetadata from(OllamaApi.ChatResponse response) {
34+
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
35+
Usage usage = OllamaUsage.from(response);
36+
return new OllamaChatResponseMetadata(usage);
37+
}
38+
39+
private final Usage usage;
40+
41+
protected OllamaChatResponseMetadata(Usage usage) {
42+
this.usage = usage;
43+
}
44+
45+
@Override
46+
public Usage getUsage() {
47+
return this.usage;
48+
}
49+
50+
@Override
51+
public String toString() {
52+
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getUsage(), getRateLimit());
53+
}
54+
55+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.ollama.metadata;
17+
18+
import java.util.Optional;
19+
import org.springframework.ai.chat.metadata.Usage;
20+
import org.springframework.ai.ollama.api.OllamaApi;
21+
import org.springframework.util.Assert;
22+
23+
/**
24+
* {@link Usage} implementation for {@literal Ollama}
25+
*
26+
* @see Usage
27+
* @author Fu Cheng
28+
*/
29+
public class OllamaUsage implements Usage {
30+
31+
protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }";
32+
33+
public static OllamaUsage from(OllamaApi.ChatResponse response) {
34+
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
35+
return new OllamaUsage(response);
36+
}
37+
38+
private final OllamaApi.ChatResponse response;
39+
40+
public OllamaUsage(OllamaApi.ChatResponse response) {
41+
this.response = response;
42+
}
43+
44+
@Override
45+
public Long getPromptTokens() {
46+
return Optional.ofNullable(response.promptEvalCount()).map(Integer::longValue).orElse(0L);
47+
}
48+
49+
@Override
50+
public Long getGenerationTokens() {
51+
return Optional.ofNullable(response.evalCount()).map(Integer::longValue).orElse(0L);
52+
}
53+
54+
@Override
55+
public String toString() {
56+
return AI_USAGE_STRING.formatted(getPromptTokens(), getGenerationTokens(), getTotalTokens());
57+
}
58+
59+
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatClientIT.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.junit.jupiter.api.BeforeAll;
2727
import org.junit.jupiter.api.Disabled;
2828
import org.junit.jupiter.api.Test;
29+
import org.springframework.ai.chat.metadata.Usage;
2930
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
3031
import org.springframework.ai.chat.messages.AssistantMessage;
3132
import org.testcontainers.containers.GenericContainer;
@@ -105,6 +106,18 @@ void roleTest() {
105106

106107
}
107108

109+
@Test
110+
void usageTest() {
111+
Prompt prompt = new Prompt("Tell me a joke");
112+
ChatResponse response = client.call(prompt);
113+
Usage usage = response.getMetadata().getUsage();
114+
115+
assertThat(usage).isNotNull();
116+
assertThat(usage.getPromptTokens()).isPositive();
117+
assertThat(usage.getGenerationTokens()).isPositive();
118+
assertThat(usage.getTotalTokens()).isPositive();
119+
}
120+
108121
@Test
109122
void outputParser() {
110123
DefaultConversionService conversionService = new DefaultConversionService();

0 commit comments

Comments
 (0)