Skip to content

Commit 3227311

Browse files
ThomasVitale authored and tzolov committed
Introduce id and model in ChatResponseMetadata
* Extend ChatResponseMetadata for Anthropic (blocking, streaming) * Add ChatResponseMetadata for Mistral AI (blocking) * Extend ChatResponseMetadata for OpenAI (blocking) * Deprecate gpt-4-vision-preview and replace its usage in tests because OpenAI rejects the calls (see: https://platform.openai.com/docs/deprecations) Fixes gh-936 Signed-off-by: Thomas Vitale <[email protected]>
1 parent 25a0372 commit 3227311

File tree

17 files changed

+267
-31
lines changed

17 files changed

+267
-31
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/metadata/AnthropicChatResponseMetadata.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,42 +30,53 @@
3030
* {@link ChatResponseMetadata} implementation for {@literal AnthropicApi}.
3131
*
3232
* @author Christian Tzolov
33+
* @author Thomas Vitale
3334
* @see ChatResponseMetadata
3435
* @see RateLimit
3536
* @see Usage
3637
* @since 1.0.0
3738
*/
3839
public class AnthropicChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
3940

40-
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";
41+
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s, rateLimit: %5$s }";
4142

4243
public static AnthropicChatResponseMetadata from(AnthropicApi.ChatCompletion result) {
4344
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
4445
AnthropicUsage usage = AnthropicUsage.from(result.usage());
45-
return new AnthropicChatResponseMetadata(result.id(), usage);
46+
return new AnthropicChatResponseMetadata(result.id(), result.model(), usage);
4647
}
4748

4849
private final String id;
4950

51+
private final String model;
52+
5053
@Nullable
5154
private RateLimit rateLimit;
5255

5356
private final Usage usage;
5457

55-
protected AnthropicChatResponseMetadata(String id, AnthropicUsage usage) {
56-
this(id, usage, null);
58+
protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage) {
59+
this(id, model, usage, null);
5760
}
5861

59-
protected AnthropicChatResponseMetadata(String id, AnthropicUsage usage, @Nullable AnthropicRateLimit rateLimit) {
62+
protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage,
63+
@Nullable AnthropicRateLimit rateLimit) {
6064
this.id = id;
65+
this.model = model;
6166
this.usage = usage;
6267
this.rateLimit = rateLimit;
6368
}
6469

70+
@Override
6571
public String getId() {
6672
return this.id;
6773
}
6874

75+
@Override
76+
public String getModel() {
77+
return this.model;
78+
}
79+
6980
@Override
7081
@Nullable
7182
public RateLimit getRateLimit() {
@@ -86,7 +97,7 @@ public AnthropicChatResponseMetadata withRateLimit(RateLimit rateLimit) {
8697

8798
@Override
8899
public String toString() {
89-
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getUsage(), getRateLimit());
100+
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage(), getRateLimit());
90101
}
91102

92103
}

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import org.springframework.ai.anthropic.api.AnthropicApi;
3333
import org.springframework.ai.anthropic.api.tool.MockWeatherService;
34+
import org.springframework.ai.chat.client.ChatClient;
3435
import org.springframework.ai.chat.model.ChatModel;
3536
import org.springframework.ai.chat.model.ChatResponse;
3637
import org.springframework.ai.chat.model.Generation;
@@ -241,4 +242,43 @@ void functionCallTest() {
241242
assertThat(generation.getOutput().getContent()).contains("30", "10", "15");
242243
}
243244

245+
@Test
246+
void validateCallResponseMetadata() {
247+
String model = AnthropicApi.ChatModel.CLAUDE_2_1.getModelName();
248+
// @formatter:off
249+
ChatResponse response = ChatClient.create(chatModel).prompt()
250+
.options(AnthropicChatOptions.builder().withModel(model).build())
251+
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
252+
.call()
253+
.chatResponse();
254+
// @formatter:on
255+
256+
logger.info(response.toString());
257+
validateChatResponseMetadata(response, model);
258+
}
259+
260+
@Test
261+
void validateStreamCallResponseMetadata() {
262+
String model = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getModelName();
263+
// @formatter:off
264+
ChatResponse response = ChatClient.create(chatModel).prompt()
265+
.options(AnthropicChatOptions.builder().withModel(model).build())
266+
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
267+
.stream()
268+
.chatResponse()
269+
.blockLast();
270+
// @formatter:on
271+
272+
logger.info(response.toString());
273+
validateChatResponseMetadata(response, model);
274+
}
275+
276+
private static void validateChatResponseMetadata(ChatResponse response, String model) {
277+
assertThat(response.getMetadata().getId()).isNotEmpty();
278+
assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
279+
assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
280+
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
281+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
282+
}
283+
244284
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiChatResponseMetadata.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
* {@literal Microsoft Azure OpenAI Service}.
3030
*
3131
* @author John Blum
32+
* @author Thomas Vitale
3233
* @see ChatResponseMetadata
3334
* @since 0.7.1
3435
*/
@@ -59,6 +60,7 @@ protected AzureOpenAiChatResponseMetadata(String id, AzureOpenAiUsage usage, Pro
5960
this.promptMetadata = promptMetadata;
6061
}
6162

63+
@Override
6264
public String getId() {
6365
return this.id;
6466
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
3131
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
3232
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
33+
import org.springframework.ai.mistralai.metadata.MistralAiChatResponseMetadata;
3334
import org.springframework.ai.model.ModelOptionsUtils;
3435
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
3536
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -119,7 +120,7 @@ public ChatResponse call(Prompt prompt) {
119120
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
120121
.toList();
121122

122-
return new ChatResponse(generations);
123+
return new ChatResponse(generations, MistralAiChatResponseMetadata.from(chatCompletion));
123124
});
124125
}
125126

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package org.springframework.ai.mistralai.metadata;

import java.util.HashMap;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.util.Assert;

/**
 * {@link ChatResponseMetadata} implementation for {@literal Mistral AI}.
 *
 * @author Thomas Vitale
 * @see ChatResponseMetadata
 * @see Usage
 * @since 1.0.0
 */
public class MistralAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {

	protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s }";

	/**
	 * Builds the response metadata from a Mistral AI chat completion.
	 * @param result the API response; must not be {@code null}
	 * @return metadata carrying the completion id, model name and token usage
	 */
	public static MistralAiChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
		Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
		return new MistralAiChatResponseMetadata(result.id(), result.model(), MistralAiUsage.from(result.usage()));
	}

	// Unique identifier of the chat completion, as reported by the API.
	private final String id;

	// Name of the model that produced the completion.
	private final String model;

	// Token usage for this completion; may be null.
	private final Usage usage;

	protected MistralAiChatResponseMetadata(String id, String model, MistralAiUsage usage) {
		this.id = id;
		this.model = model;
		this.usage = usage;
	}

	@Override
	public String getId() {
		return this.id;
	}

	@Override
	public String getModel() {
		return this.model;
	}

	@Override
	public Usage getUsage() {
		// Never return null to callers; fall back to an empty usage instead.
		if (this.usage == null) {
			return new EmptyUsage();
		}
		return this.usage;
	}

	@Override
	public String toString() {
		return String.format(AI_METADATA_STRING, getClass().getName(), getId(), getModel(), getUsage());
	}

}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package org.springframework.ai.mistralai.metadata;
2+
3+
import org.springframework.ai.chat.metadata.Usage;
4+
import org.springframework.ai.mistralai.api.MistralAiApi;
5+
import org.springframework.util.Assert;
6+
7+
/**
8+
* {@link Usage} implementation for {@literal Mistral AI}.
9+
*
10+
* @author Thomas Vitale
11+
* @since 1.0.0
12+
* @see <a href="https://docs.mistral.ai/api/">Chat Completion API</a>
13+
*/
14+
public class MistralAiUsage implements Usage {
15+
16+
public static MistralAiUsage from(MistralAiApi.Usage usage) {
17+
return new MistralAiUsage(usage);
18+
}
19+
20+
private final MistralAiApi.Usage usage;
21+
22+
protected MistralAiUsage(MistralAiApi.Usage usage) {
23+
Assert.notNull(usage, "Mistral AI Usage must not be null");
24+
this.usage = usage;
25+
}
26+
27+
protected MistralAiApi.Usage getUsage() {
28+
return this.usage;
29+
}
30+
31+
@Override
32+
public Long getPromptTokens() {
33+
return getUsage().promptTokens().longValue();
34+
}
35+
36+
@Override
37+
public Long getGenerationTokens() {
38+
return getUsage().completionTokens().longValue();
39+
}
40+
41+
@Override
42+
public Long getTotalTokens() {
43+
return getUsage().totalTokens().longValue();
44+
}
45+
46+
@Override
47+
public String toString() {
48+
return getUsage().toString();
49+
}
50+
51+
}

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,4 +249,23 @@ void streamFunctionCallTest() {
249249
assertThat(content).containsAnyOf("15.0", "15");
250250
}
251251

252+
@Test
253+
void validateCallResponseMetadata() {
254+
String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getModelName();
255+
// @formatter:off
256+
ChatResponse response = ChatClient.create(chatModel).prompt()
257+
.options(MistralAiChatOptions.builder().withModel(model).build())
258+
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
259+
.call()
260+
.chatResponse();
261+
// @formatter:on
262+
263+
logger.info(response.toString());
264+
assertThat(response.getMetadata().getId()).isNotEmpty();
265+
assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
266+
assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
267+
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
268+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
269+
}
270+
252271
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
* @author Hyunjoon Choi
7373
* @author Mariusz Bernacki
7474
* @author luocongqiu
75+
* @author Thomas Vitale
7576
* @see ChatModel
7677
* @see StreamingChatModel
7778
* @see OpenAiApi

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
/**
4242
* @author Christian Tzolov
4343
* @author Mariusz Bernacki
44+
* @author Thomas Vitale
4445
* @since 0.8.0
4546
*/
4647
@JsonInclude(Include.NON_NULL)
@@ -66,8 +67,7 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
6667
private @JsonProperty("logit_bias") Map<String, Integer> logitBias;
6768
/**
6869
* Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities
69-
* of each output token returned in the 'content' of 'message'. This option is currently not available
70-
* on the 'gpt-4-vision-preview' model.
70+
* of each output token returned in the 'content' of 'message'.
7171
*/
7272
private @JsonProperty("logprobs") Boolean logprobs;
7373
/**

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
* @author Christian Tzolov
4949
* @author Michael Lavelle
5050
* @author Mariusz Bernacki
51+
* @author Thomas Vitale
5152
*/
5253
public class OpenAiApi {
5354

@@ -124,7 +125,6 @@ public enum ChatModel implements ModelDescription {
124125
*/
125126
GPT_4_O("gpt-4o"),
126127

127-
128128
/**
129129
* GPT-4 Turbo with Vision
130130
* The latest GPT-4 Turbo model with vision capabilities.
@@ -134,7 +134,7 @@ public enum ChatModel implements ModelDescription {
134134
GPT_4_TURBO("gpt-4-turbo"),
135135

136136
/**
137-
* GPT-4 Turbo with Vision model. Vision requests can now use JSON mode and function calling
137+
* GPT-4 Turbo with Vision model. Vision requests can now use JSON mode and function calling.
138138
*/
139139
GPT_4_TURBO_2204_04_09("gpt-4-turbo-2024-04-09"),
140140

@@ -162,6 +162,7 @@ public enum ChatModel implements ModelDescription {
162162
* Returns a maximum of 4,096 output tokens
163163
* Context window: 128k tokens
164164
*/
165+
@Deprecated(since = "1.0.0-M2", forRemoval = true) // Replaced by GPT_4_O
165166
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
166167

167168
/**
@@ -178,6 +179,7 @@ public enum ChatModel implements ModelDescription {
178179
* function calling support.
179180
* Context window: 32k tokens
180181
*/
182+
@Deprecated(since = "1.0.0-M2", forRemoval = true) // Replaced by GPT_4_O
181183
GPT_4_32K("gpt-4-32k"),
182184

183185
/**
@@ -296,8 +298,7 @@ public Function(String description, String name, String jsonSchema) {
296298
* vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100
297299
* or 100 should result in a ban or exclusive selection of the relevant token.
298300
* @param logprobs Whether to return log probabilities of the output tokens or not. If true, returns the log
299-
* probabilities of each output token returned in the 'content' of 'message'. This option is currently not available
300-
* on the 'gpt-4-vision-preview' model.
301+
* probabilities of each output token returned in the 'content' of 'message'.
301302
* @param topLogprobs An integer between 0 and 5 specifying the number of most likely tokens to return at each token
302303
* position, each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used.
303304
* @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input

0 commit comments

Comments
 (0)