Skip to content

Commit 6d38c85

Browse files
committed
Refactor Ollama usage metadata to add embedding support
- Extend the OllamaApi.EmbeddingsResponse with total_duration, load_duration and prompt_eval_count fields - Rename OllamaUsage to OllamaChatUsage for clarity - Add OllamaEmbeddingUsage to track embedding-specific usage metrics - Update OllamaEmbeddingModel to use OllamaEmbeddingUsage - Extend EmbeddingsResponse with additional metadata fields - Update tests to reflect new usage tracking for embeddings Resolves #1536
1 parent fe09b4d commit 6d38c85

File tree

8 files changed

+96
-29
lines changed

8 files changed

+96
-29
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
4949
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
5050
import org.springframework.ai.ollama.api.OllamaOptions;
51-
import org.springframework.ai.ollama.metadata.OllamaUsage;
51+
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
5252
import org.springframework.util.Assert;
5353
import org.springframework.util.CollectionUtils;
5454
import org.springframework.util.StringUtils;
@@ -177,7 +177,7 @@ && isToolCall(response, Set.of("stop"))) {
177177
public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
178178
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
179179
return ChatResponseMetadata.builder()
180-
.withUsage(OllamaUsage.from(response))
180+
.withUsage(OllamaChatUsage.from(response))
181181
.withModel(response.model())
182182
.withKeyValue("created-at", response.createdAt())
183183
.withKeyValue("eval-duration", response.evalDuration())

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.ollama.api.OllamaApi;
3434
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
3535
import org.springframework.ai.ollama.api.OllamaOptions;
36+
import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
3637
import org.springframework.util.Assert;
3738
import org.springframework.util.StringUtils;
3839

@@ -125,7 +126,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
125126
.toList();
126127

127128
EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata(response.model(),
128-
new EmptyUsage());
129+
OllamaEmbeddingUsage.from(response));
129130

130131
EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, embeddingResponseMetadata);
131132

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,11 @@ public record EmbeddingResponse(
767767
@JsonInclude(Include.NON_NULL)
768768
public record EmbeddingsResponse(
769769
@JsonProperty("model") String model,
770-
@JsonProperty("embeddings") List<float[]> embeddings) {
770+
@JsonProperty("embeddings") List<float[]> embeddings,
771+
@JsonProperty("total_duration") Long totalDuration,
772+
@JsonProperty("load_duration") Long loadDuration,
773+
@JsonProperty("prompt_eval_count") Integer promptEvalCount) {
774+
771775
}
772776

773777
/**

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaUsage.java renamed to models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/metadata/OllamaChatUsage.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,18 @@
2626
* @see Usage
2727
* @author Fu Cheng
2828
*/
29-
public class OllamaUsage implements Usage {
29+
public class OllamaChatUsage implements Usage {
3030

3131
protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }";
3232

33-
public static OllamaUsage from(OllamaApi.ChatResponse response) {
33+
public static OllamaChatUsage from(OllamaApi.ChatResponse response) {
3434
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
35-
return new OllamaUsage(response);
35+
return new OllamaChatUsage(response);
3636
}
3737

3838
private final OllamaApi.ChatResponse response;
3939

40-
public OllamaUsage(OllamaApi.ChatResponse response) {
40+
public OllamaChatUsage(OllamaApi.ChatResponse response) {
4141
this.response = response;
4242
}
4343

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.ollama.metadata;
17+
18+
import java.util.Optional;
19+
20+
import org.springframework.ai.chat.metadata.Usage;
21+
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
22+
import org.springframework.util.Assert;
23+
24+
/**
25+
* {@link Usage} implementation for {@literal Ollama} embeddings.
26+
*
27+
* @see Usage
28+
* @author Christian Tzolov
29+
*/
30+
public class OllamaEmbeddingUsage implements Usage {
31+
32+
protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }";
33+
34+
public static OllamaEmbeddingUsage from(EmbeddingsResponse response) {
35+
Assert.notNull(response, "OllamaApi.EmbeddingsResponse must not be null");
36+
return new OllamaEmbeddingUsage(response);
37+
}
38+
39+
private Long promptTokens;
40+
41+
public OllamaEmbeddingUsage(EmbeddingsResponse response) {
42+
this.promptTokens = Optional.ofNullable(response.promptEvalCount()).map(Integer::longValue).orElse(0L);
43+
}
44+
45+
@Override
46+
public Long getPromptTokens() {
47+
return this.promptTokens;
48+
}
49+
50+
@Override
51+
public Long getGenerationTokens() {
52+
return 0L;
53+
}
54+
55+
@Override
56+
public String toString() {
57+
return AI_USAGE_STRING.formatted(getPromptTokens(), getGenerationTokens(), getTotalTokens());
58+
}
59+
60+
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelIT.java

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
*/
1616
package org.springframework.ai.ollama;
1717

18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.io.IOException;
21+
import java.util.List;
22+
1823
import org.apache.commons.logging.Log;
1924
import org.apache.commons.logging.LogFactory;
2025
import org.junit.jupiter.api.BeforeAll;
@@ -24,32 +29,22 @@
2429
import org.springframework.ai.embedding.EmbeddingResponse;
2530
import org.springframework.ai.ollama.api.OllamaApi;
2631
import org.springframework.ai.ollama.api.OllamaApiIT;
27-
import org.springframework.ai.ollama.api.OllamaModel;
2832
import org.springframework.ai.ollama.api.OllamaOptions;
2933
import org.springframework.beans.factory.annotation.Autowired;
3034
import org.springframework.boot.SpringBootConfiguration;
3135
import org.springframework.boot.test.context.SpringBootTest;
3236
import org.springframework.context.annotation.Bean;
3337
import org.testcontainers.junit.jupiter.Testcontainers;
3438

35-
import java.io.IOException;
36-
import java.util.List;
37-
38-
import static org.assertj.core.api.Assertions.assertThat;
39-
4039
@SpringBootTest
4140
@DisabledIf("isDisabled")
4241
@Testcontainers
4342
class OllamaEmbeddingModelIT extends BaseOllamaIT {
4443

45-
private static final String MODEL = OllamaModel.MISTRAL.getName();
44+
private static final String MODEL = "mxbai-embed-large";
4645

4746
private static final Log logger = LogFactory.getLog(OllamaApiIT.class);
4847

49-
// @Container
50-
// static OllamaContainer ollamaContainer = new
51-
// OllamaContainer(OllamaImage.DEFAULT_IMAGE);
52-
5348
static String baseUrl = "http://localhost:11434";
5449

5550
@BeforeAll
@@ -75,8 +70,10 @@ void embeddings() {
7570
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
7671
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
7772
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(MODEL);
73+
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
74+
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);
7875

79-
assertThat(embeddingModel.dimensions()).isEqualTo(4096);
76+
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
8077
}
8178

8279
@SpringBootConfiguration

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingModelTests.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,25 @@ public class OllamaEmbeddingModelTests {
5454
public void options() {
5555

5656
when(ollamaApi.embed(embeddingsRequestCaptor.capture()))
57-
.thenReturn(
58-
new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(new float[]{1f, 2f, 3f}, new float[]{4f, 5f, 6f})))
57+
.thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME",
58+
List.of(new float[] { 1f, 2f, 3f }, new float[] { 4f, 5f, 6f }), 0L, 0L, 0))
5959
.thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2",
60-
List.of(new float[]{7f, 8f, 9f}, new float[]{10f, 11f, 12f})));
60+
List.of(new float[] { 7f, 8f, 9f }, new float[] { 10f, 11f, 12f }), 0L, 0L, 0));
6161

6262
// Tests default options
6363
var defaultOptions = OllamaOptions.builder().withModel("DEFAULT_MODEL").build();
6464

6565
var embeddingModel = new OllamaEmbeddingModel(ollamaApi, defaultOptions);
6666

67-
EmbeddingResponse response = embeddingModel
68-
.call(new EmbeddingRequest(List.of("Input1", "Input2", "Input3"), EmbeddingOptionsBuilder.builder().build()));
67+
EmbeddingResponse response = embeddingModel.call(
68+
new EmbeddingRequest(List.of("Input1", "Input2", "Input3"), EmbeddingOptionsBuilder.builder().build()));
6969

7070
assertThat(response.getResults()).hasSize(2);
7171
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
72-
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[]{1f, 2f, 3f});
72+
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 1f, 2f, 3f });
7373
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
7474
assertThat(response.getResults().get(1).getIndex()).isEqualTo(1);
75-
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[]{4f, 5f, 6f});
75+
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[] { 4f, 5f, 6f });
7676
assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
7777
assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME");
7878

@@ -94,10 +94,10 @@ public void options() {
9494

9595
assertThat(response.getResults()).hasSize(2);
9696
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
97-
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[]{7f, 8f, 9f});
97+
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 7f, 8f, 9f });
9898
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
9999
assertThat(response.getResults().get(1).getIndex()).isEqualTo(1);
100-
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[]{10f, 11f, 12f});
100+
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[] { 10f, 11f, 12f });
101101
assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
102102
assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME2");
103103

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/api/OllamaApiIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ public void embedText() {
146146
assertThat(response).isNotNull();
147147
assertThat(response.embeddings()).hasSize(1);
148148
assertThat(response.embeddings().get(0)).hasSize(3200);
149+
assertThat(response.model()).isEqualTo(MODEL);
150+
assertThat(response.promptEvalCount()).isEqualTo(5);
151+
assertThat(response.loadDuration()).isGreaterThan(1);
152+
assertThat(response.totalDuration()).isGreaterThan(1);
153+
149154
}
150155

151156
}

0 commit comments

Comments
 (0)