diff --git a/pom.xml b/pom.xml index cc5c8607948..53b7f671ed4 100644 --- a/pom.xml +++ b/pom.xml @@ -79,7 +79,7 @@ 17 17 - + 3.1.3 4.0.2 0.16.0 @@ -87,7 +87,7 @@ 0.6.1 4.31.1 - + 3.0.0 0.1.3 42.6.0 @@ -97,7 +97,7 @@ 2.0.42 11.6.0 - + 1.19.0 @@ -409,4 +409,4 @@ - \ No newline at end of file + diff --git a/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java index 6e4f7569919..955dccabf28 100644 --- a/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java +++ b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java @@ -24,9 +24,12 @@ import com.azure.ai.openai.models.ChatRole; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.azure.openai.metadata.AzureOpenAiGenerationMetadata; import org.springframework.ai.client.AiClient; import org.springframework.ai.client.AiResponse; import org.springframework.ai.client.Generation; +import org.springframework.ai.metadata.PromptMetadata; +import org.springframework.ai.metadata.PromptMetadata.PromptFilterMetadata; import org.springframework.ai.prompt.Prompt; import org.springframework.ai.prompt.messages.Message; import org.springframework.util.Assert; @@ -124,13 +127,22 @@ public AiResponse generate(Prompt prompt) { for (ChatChoice choice : chatCompletions.getChoices()) { ChatMessage choiceMessage = choice.getMessage(); - // TODO investigate mapping of additional metadata/runtime info to the general - // model. Generation generation = new Generation(choiceMessage.getContent()); generations.add(generation); } - return new AiResponse(generations); + return new AiResponse(generations, AzureOpenAiGenerationMetadata.from(chatCompletions)) + .withPromptMetadata(generatePromptMetadata(chatCompletions)); + } + + private PromptMetadata generatePromptMetadata(ChatCompletions chatCompletions) { + + return PromptMetadata.of(chatCompletions.getPromptFilterResults() + .stream() + .map(promptFilterResult -> PromptFilterMetadata.from(promptFilterResult.getPromptIndex(), + promptFilterResult.getContentFilterResults())) + .toList()); + } } diff --git a/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java new file mode 100644 index 00000000000..51208526da8 --- /dev/null +++ b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiGenerationMetadata.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai.metadata; + +import com.azure.ai.openai.models.ChatCompletions; + +import org.springframework.ai.metadata.GenerationMetadata; +import org.springframework.ai.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link GenerationMetadata} implementation for + * {@literal Microsoft Azure OpenAI Service}. + * + * @author John Blum + * @see org.springframework.ai.metadata.GenerationMetadata + * @since 0.7.1 + */ +public class AzureOpenAiGenerationMetadata implements GenerationMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; + + @SuppressWarnings("all") + public static AzureOpenAiGenerationMetadata from(ChatCompletions chatCompletions) { + Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null"); + String id = chatCompletions.getId(); + AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions); + AzureOpenAiGenerationMetadata generationMetadata = new AzureOpenAiGenerationMetadata(id, usage); + return generationMetadata; + } + + private final String id; + + private final Usage usage; + + protected AzureOpenAiGenerationMetadata(String id, AzureOpenAiUsage usage) { + this.id = id; + this.usage = usage; + } + + public String getId() { + return this.id; + } + + @Override + public Usage getUsage() { + return this.usage; + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit()); + } + +} diff --git a/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java new file mode 100644 index 00000000000..928d4b71a92 --- /dev/null +++ b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/metadata/AzureOpenAiUsage.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai.metadata; + +import com.azure.ai.openai.models.ChatCompletions; +import com.azure.ai.openai.models.CompletionsUsage; + +import org.springframework.ai.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link Usage} implementation for {@literal Microsoft Azure OpenAI Service}. + * + * @author John Blum + * @see com.azure.ai.openai.models.CompletionsUsage + * @since 0.7.0 + */ +public class AzureOpenAiUsage implements Usage { + + public static AzureOpenAiUsage from(ChatCompletions chatCompletions) { + Assert.notNull(chatCompletions, "ChatCompletions must not be null"); + return from(chatCompletions.getUsage()); + } + + public static AzureOpenAiUsage from(CompletionsUsage usage) { + return new AzureOpenAiUsage(usage); + } + + private final CompletionsUsage usage; + + public AzureOpenAiUsage(CompletionsUsage usage) { + Assert.notNull(usage, "CompletionUsage must not be null"); + this.usage = usage; + } + + protected CompletionsUsage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return Integer.valueOf(getUsage().getPromptTokens()).longValue(); + } + + @Override + public Long getGenerationTokens() { + return Integer.valueOf(getUsage().getCompletionTokens()).longValue(); + } + + @Override + public Long getTotalTokens() { + return Integer.valueOf(getUsage().getTotalTokens()).longValue(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java b/spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java index 3e65c6fd644..9cd74ce9e24 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/client/AiResponse.java @@ -19,53 +19,108 @@ import java.util.List; import java.util.Map; +import org.springframework.ai.metadata.GenerationMetadata; +import org.springframework.ai.metadata.PromptMetadata; +import org.springframework.lang.Nullable; + public class AiResponse { + private final GenerationMetadata metadata; + private final List generations; - private Map providerOutput; + private final Map providerOutput; + + private final Map runInfo; - private Map runInfo; + private PromptMetadata promptMetadata; public AiResponse(List generations) { - this(generations, Collections.emptyMap(), Collections.emptyMap()); + this(generations, Collections.emptyMap(), Collections.emptyMap(), GenerationMetadata.NULL); + } + + public AiResponse(List generations, GenerationMetadata metadata) { + this(generations, Collections.emptyMap(), Collections.emptyMap(), metadata); } public AiResponse(List generations, Map providerOutput) { - this(generations, providerOutput, Collections.emptyMap()); + this(generations, providerOutput, Collections.emptyMap(), GenerationMetadata.NULL); } public AiResponse(List generations, Map providerOutput, Map runInfo) { + this(generations, providerOutput, runInfo, GenerationMetadata.NULL); + } + + public AiResponse(List generations, Map providerOutput, Map runInfo, + GenerationMetadata metadata) { + + this.metadata = metadata; this.generations = List.copyOf(generations); this.providerOutput = Map.copyOf(providerOutput); this.runInfo = Map.copyOf(runInfo); } /** - * The list of generated outputs. It is a list of lists because the Prompt could - * request multiple output generations. - * @return + * The {@link List} of {@link Generation generated outputs}. + *

+ * It is a {@link List} of {@link List lists} because the Prompt could request + * multiple output {@link Generation generations}. + * @return the {@link List} of {@link Generation generated outputs}. */ public List getGenerations() { - return Collections.unmodifiableList(generations); + return this.generations; } public Generation getGeneration() { return this.generations.get(0); } + /** + * Returns {@link GenerationMetadata} containing information about the use of the AI + * provider's API. + * @return {@link GenerationMetadata} containing information about the use of the AI + * provider's API. + */ + public GenerationMetadata getGenerationMetadata() { + return this.metadata; + } + + /** + * Returns {@link PromptMetadata} containing information on prompt processing by the + * AI. + * @return {@link PromptMetadata} containing information on prompt processing by the + * AI. + */ + public PromptMetadata getPromptMetadata() { + PromptMetadata promptMetadata = this.promptMetadata; + return promptMetadata != null ? promptMetadata : PromptMetadata.empty(); + } + /** * Arbitrary model provider specific output */ public Map getProviderOutput() { - return Collections.unmodifiableMap(providerOutput); + return this.providerOutput; } /** * The run metadata information */ public Map getRunInfo() { - return Collections.unmodifiableMap(runInfo); + return this.runInfo; + } + + /** + * Builder method used to include {@link PromptMetadata} returned in the AI response + * when processing the prompt. + * @param promptMetadata {@link PromptMetadata} returned by the AI in the response + * when processing the prompt. + * @return this {@link AiResponse}. + * @see #getPromptMetadata() + */ + public AiResponse withPromptMetadata(@Nullable PromptMetadata promptMetadata) { + this.promptMetadata = promptMetadata; + return this; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java index 73be362bb38..06507b3845a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java @@ -19,6 +19,9 @@ import java.util.Collections; import java.util.Map; +import org.springframework.ai.metadata.ChoiceMetadata; +import org.springframework.lang.Nullable; + public class Generation { // Just text for now @@ -26,13 +29,15 @@ public class Generation { private Map info; + private ChoiceMetadata choiceMetadata; + public Generation(String text) { this(text, Collections.emptyMap()); } public Generation(String text, Map info) { this.text = text; - this.info = info; + this.info = Map.copyOf(info); } public String getText() { @@ -40,7 +45,17 @@ public String getText() { } public Map getInfo() { - return Collections.unmodifiableMap(this.info); + return this.info; + } + + public ChoiceMetadata getChoiceMetadata() { + ChoiceMetadata choiceMetadata = this.choiceMetadata; + return choiceMetadata != null ? choiceMetadata : ChoiceMetadata.NULL; + } + + public Generation withChoiceMetadata(@Nullable ChoiceMetadata choiceMetadata) { + this.choiceMetadata = choiceMetadata; + return this; } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java new file mode 100644 index 00000000000..9b0b3a80315 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractRateLimit.java @@ -0,0 +1,59 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.metadata; + +import java.time.Duration; + +/** + * Abstract base class used as a foundation for implementing {@link RateLimit}. + * + * @author John Blum + * @since 0.7.0 + */ +public abstract class AbstractRateLimit implements RateLimit { + + @Override + public Long getRequestsLimit() { + return 0L; + } + + @Override + public Long getRequestsRemaining() { + return 0L; + } + + @Override + public Duration getRequestsReset() { + return Duration.ZERO; + } + + @Override + public Long getTokensLimit() { + return 0L; + } + + @Override + public Long getTokensRemaining() { + return 0L; + } + + @Override + public Duration getTokensReset() { + return Duration.ZERO; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java new file mode 100644 index 00000000000..2adad4b9de3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/metadata/AbstractUsage.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.metadata; + +/** + * Abstract base class used as a foundation for implementing {@link Usage}. + * + * @author John Blum + * @since 0.7.0 + */ +public abstract class AbstractUsage implements Usage { + + @Override + public Long getPromptTokens() { + return 0L; + } + + @Override + public Long getGenerationTokens() { + return 0L; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java new file mode 100644 index 00000000000..9c15fbdb226 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/metadata/ChoiceMetadata.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.metadata; + +import org.springframework.lang.Nullable; + +/** + * Abstract Data Type (ADT) encapsulating information on the completion choices in the AI + * response. + * + * @author John Blum + * @since 0.7.0 + */ +public interface ChoiceMetadata { + + ChoiceMetadata NULL = ChoiceMetadata.from(null, null); + + /** + * Factory method used to construct a new {@link ChoiceMetadata} from the given + * {@link String finish reason} and content filter metadata. + * @param finishReason {@link String} contain the reason for the choice completion. + * @param contentFilterMetadata underlying AI provider metadata for filtering applied + * to generation content. + * @return a new {@link ChoiceMetadata} from the given {@link String finish reason} + * and content filter metadata. + */ + static ChoiceMetadata from(String finishReason, Object contentFilterMetadata) { + return new ChoiceMetadata() { + + @Override + @SuppressWarnings("unchecked") + public T getContentFilterMetadata() { + return (T) contentFilterMetadata; + } + + @Override + public String getFinishReason() { + return finishReason; + } + }; + } + + /** + * Returns the underlying AI provider metadata for filtering applied to generation + * content. + * @param {@link Class Type} used to cast the filtered content metadata into the + * AI provider-specific type. + * @return the underlying AI provider metadata for filtering applied to generation + * content. + */ + @Nullable + T getContentFilterMetadata(); + + /** + * Get the {@link String reason} this choice completed for the generation. + * @return the {@link String reason} this choice completed for the generation. + */ + String getFinishReason(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java new file mode 100644 index 00000000000..3bd641eca24 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/metadata/GenerationMetadata.java @@ -0,0 +1,49 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.metadata; + +/** + * Abstract Data Type (ADT) modeling common AI provider metadata returned in an AI + * response. + * + * @author John Blum + * @since 0.7.0 + */ +public interface GenerationMetadata { + + GenerationMetadata NULL = new GenerationMetadata() { + }; + + /** + * Returns AI provider specific metadata on rate limits. + * @return AI provider specific metadata on rate limits. + * @see RateLimit + */ + default RateLimit getRateLimit() { + return RateLimit.NULL; + } + + /** + * Returns AI provider specific metadata on API usage. + * @return AI provider specific metadata on API usage. + * @see Usage + */ + default Usage getUsage() { + return Usage.NULL; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java new file mode 100644 index 00000000000..302481b6e38 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/metadata/PromptMetadata.java @@ -0,0 +1,136 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.metadata; + +import java.util.Arrays; +import java.util.Optional; +import java.util.stream.StreamSupport; + +import org.springframework.util.Assert; + +/** + * Abstract Data Type (ADT) modeling metadata gathered by the AI during request + * processing. + * + * @author John Blum + * @since 0.7.0 + */ +@FunctionalInterface +public interface PromptMetadata extends Iterable { + + /** + * Factory method used to create empty {@link PromptMetadata} when the information is + * not supplied by the AI provider. + * @return empty {@link PromptMetadata}. + */ + static PromptMetadata empty() { + return of(); + } + + /** + * Factory method used to create a new {@link PromptMetadata} composed of an array of + * {@link PromptFilterMetadata}. + * @param array array of {@link PromptFilterMetadata} used to compose the + * {@link PromptMetadata}. + * @return a new {@link PromptMetadata} composed of an array of + * {@link PromptFilterMetadata}. + */ + static PromptMetadata of(PromptFilterMetadata... array) { + return of(Arrays.asList(array)); + } + + /** + * Factory method used to create a new {@link PromptMetadata} composed of an + * {@link Iterable} of {@link PromptFilterMetadata}. + * @param iterable {@link Iterable} of {@link PromptFilterMetadata} used to compose + * the {@link PromptMetadata}. + * @return a new {@link PromptMetadata} composed of an {@link Iterable} of + * {@link PromptFilterMetadata}. + */ + static PromptMetadata of(Iterable iterable) { + Assert.notNull(iterable, "An Iterable of PromptFilterMetadata must not be null"); + return iterable::iterator; + } + + /** + * Returns an {@link Optional} {@link PromptFilterMetadata} at the given index. + * @param promptIndex index of the {@link PromptFilterMetadata} contained in this + * {@link PromptMetadata}. + * @return {@link Optional} {@link PromptFilterMetadata} at the given index. + * @throws IllegalArgumentException if the prompt index is less than 0. + */ + default Optional findByPromptIndex(int promptIndex) { + + Assert.isTrue(promptIndex > -1, "Prompt index [%d] must be greater than equal to 0".formatted(promptIndex)); + + return StreamSupport.stream(this.spliterator(), false) + .filter(promptFilterMetadata -> promptFilterMetadata.getPromptIndex() == promptIndex) + .findFirst(); + } + + /** + * Abstract Data Type (ADT) modeling filter metadata for all prompts sent during an AI + * request. + */ + interface PromptFilterMetadata { + + /** + * Factory method used to construct a new {@link PromptFilterMetadata} with the + * given prompt index and content filter metadata. + * @param promptIndex index of the prompt filter metadata contained in the AI + * response. + * @param contentFilterMetadata underlying AI provider metadata for filtering + * applied to prompt content. + * @return a new instance of {@link PromptFilterMetadata} with the given prompt + * index and content filter metadata. + */ + static PromptFilterMetadata from(int promptIndex, Object contentFilterMetadata) { + + return new PromptFilterMetadata() { + + @Override + public int getPromptIndex() { + return promptIndex; + } + + @Override + @SuppressWarnings("unchecked") + public T getContentFilterMetadata() { + return (T) contentFilterMetadata; + } + }; + } + + /** + * Index of the prompt filter metadata contained in the AI response. + * @return an {@link Integer index} fo the prompt filter metadata contained in the + * AI response. + */ + int getPromptIndex(); + + /** + * Returns the underlying AI provider metadata for filtering applied to prompt + * content. + * @param {@link Class Type} used to cast the filtered content metadata into + * the AI provider-specific type. + * @return the underlying AI provider metadata for filtering applied to prompt + * content. + */ + T getContentFilterMetadata(); + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java b/spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java new file mode 100644 index 00000000000..40bb933d03b --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/metadata/RateLimit.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.metadata; + +import java.time.Duration; + +/** + * Abstract Data Type (ADT) encapsulating metadata from an AI provider's API rate limits + * granted to the API key in use and the API key's current balance. + * + * @author John Blum + * @since 0.7.0 + */ +public interface RateLimit { + + RateLimit NULL = new AbstractRateLimit() { + }; + + /** + * Returns the maximum number of requests that are permitted before exhausting the + * rate limit. + * @return an {@link Long} with the maximum number of requests that are permitted + * before exhausting the rate limit. + * @see #getRequestsRemaining() + */ + Long getRequestsLimit(); + + /** + * Returns the remaining number of requests that are permitted before exhausting the + * {@link #getRequestsLimit() rate limit}. + * @return an {@link Long} with the remaining number of requests that are permitted + * before exhausting the {@link #getRequestsLimit() rate limit}. + * @see #getRequestsLimit() + */ + Long getRequestsRemaining(); + + /** + * Returns the {@link Duration time} until the rate limit (based on requests) resets + * to its {@link #getRequestsLimit() initial state}. + * @return a {@link Duration} representing the time until the rate limit (based on + * requests) resets to its {@link #getRequestsLimit() initial state}. + * @see #getRequestsLimit() + */ + Duration getRequestsReset(); + + /** + * Returns the maximum number of tokens that are permitted before exhausting the rate + * limit. + * @return an {@link Long} with the maximum number of tokens that are permitted before + * exhausting the rate limit. + * @see #getTokensRemaining() + */ + Long getTokensLimit(); + + /** + * Returns the remaining number of tokens that are permitted before exhausting the + * {@link #getTokensLimit() rate limit}. + * @return an {@link Long} with the remaining number of tokens that are permitted + * before exhausting the {@link #getTokensLimit() rate limit}. + * @see #getTokensLimit() + */ + Long getTokensRemaining(); + + /** + * Returns the {@link Duration time} until the rate limit (based on tokens) resets to + * its {@link #getTokensLimit() initial state}. + * @return a {@link Duration} with the time until the rate limit (based on tokens) + * resets to its {@link #getTokensLimit() initial state}. + * @see #getTokensLimit() + */ + Duration getTokensReset(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java b/spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java new file mode 100644 index 00000000000..7179d5e127e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/metadata/Usage.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.metadata; + +/** + * Abstract Data Type (ADT) encapsulating metadata on the usage of an AI provider's API + * per AI request. + * + * @author John Blum + * @since 0.7.0 + */ +public interface Usage { + + Usage NULL = new AbstractUsage() { + }; + + /** + * Returns the number of tokens used in the {@literal prompt} of the AI request. + * @return an {@link Long} with the number of tokens used in the {@literal prompt} of + * the AI request. + * @see #getGenerationTokens() + */ + Long getPromptTokens(); + + /** + * Returns the number of tokens returned in the {@literal generation (aka completion)} + * of the AI's response. + * @return an {@link Long} with the number of tokens returned in the + * {@literal generation (aka completion)} of the AI's response. + * @see #getPromptTokens() + */ + Long getGenerationTokens(); + + /** + * Return the total number of tokens from both the {@literal prompt} of an AI request + * and {@literal generation} of the AI's response. + * @return the total number of tokens from both the {@literal prompt} of an AI request + * and {@literal generation} of the AI's response. + * @see #getPromptTokens() + * @see #getGenerationTokens() + */ + default Long getTotalTokens() { + Long promptTokens = getPromptTokens(); + promptTokens = promptTokens != null ? promptTokens : 0; + Long completionTokens = getGenerationTokens(); + completionTokens = completionTokens != null ? completionTokens : 0; + return promptTokens + completionTokens; + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java new file mode 100644 index 00000000000..792b093842a --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/PromptMetadataTests.java @@ -0,0 +1,107 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.metadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.metadata.PromptMetadata.PromptFilterMetadata; + +/** + * Unit Tests for {@link PromptMetadata}. + * + * @author John Blum + * @since 0.7.0 + */ +public class PromptMetadataTests { + + private PromptFilterMetadata mockPromptFilterMetadata(int index) { + PromptFilterMetadata mockPromptFilterMetadata = mock(PromptFilterMetadata.class); + doReturn(index).when(mockPromptFilterMetadata).getPromptIndex(); + return mockPromptFilterMetadata; + } + + @Test + void emptyPromptMetadata() { + + PromptMetadata empty = PromptMetadata.empty(); + + assertThat(empty).isNotNull(); + assertThat(empty).isEmpty(); + } + + @Test + void promptMetadataWithOneFilter() { + + PromptFilterMetadata mockPromptFilterMetadata = mockPromptFilterMetadata(0); + PromptMetadata promptMetadata = PromptMetadata.of(mockPromptFilterMetadata); + + assertThat(promptMetadata).isNotNull(); + assertThat(promptMetadata).containsExactly(mockPromptFilterMetadata); + } + + @Test + void promptMetadataWithTwoFilters() { + + PromptFilterMetadata mockPromptFilterMetadataOne = mockPromptFilterMetadata(0); + PromptFilterMetadata mockPromptFilterMetadataTwo = mockPromptFilterMetadata(1); + PromptMetadata promptMetadata = PromptMetadata.of(mockPromptFilterMetadataOne, mockPromptFilterMetadataTwo); + + assertThat(promptMetadata).isNotNull(); + assertThat(promptMetadata).containsExactly(mockPromptFilterMetadataOne, mockPromptFilterMetadataTwo); + } + + @Test + void findByPromptIndex() { + + PromptFilterMetadata mockPromptFilterMetadataOne = mockPromptFilterMetadata(0); + PromptFilterMetadata mockPromptFilterMetadataTwo = mockPromptFilterMetadata(1); + PromptMetadata promptMetadata = PromptMetadata.of(mockPromptFilterMetadataOne, mockPromptFilterMetadataTwo); + + assertThat(promptMetadata).isNotNull(); + assertThat(promptMetadata).containsExactly(mockPromptFilterMetadataOne, mockPromptFilterMetadataTwo); + assertThat(promptMetadata.findByPromptIndex(1).orElse(null)).isEqualTo(mockPromptFilterMetadataTwo); + assertThat(promptMetadata.findByPromptIndex(0).orElse(null)).isEqualTo(mockPromptFilterMetadataOne); + } + + @Test + void findByPromptIndexWithNoFilters() { + assertThat(PromptMetadata.empty().findByPromptIndex(0)).isNotPresent(); + } + + @Test + void findByInvalidPromptIndex() { + + assertThatIllegalArgumentException().isThrownBy(() -> PromptMetadata.empty().findByPromptIndex(-1)) + .withMessage("Prompt index [-1] must be greater than equal to 0") + .withNoCause(); + } + + @Test + void fromPromptIndexAndContentFilterMetadata() { + + PromptFilterMetadata promptFilterMetadata = PromptFilterMetadata.from(1, "{ content-sentiment: 'SAFE' }"); + + assertThat(promptFilterMetadata).isNotNull(); + assertThat(promptFilterMetadata.getPromptIndex()).isOne(); + assertThat(promptFilterMetadata.getContentFilterMetadata()).isEqualTo("{ content-sentiment: 'SAFE' }"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java new file mode 100644 index 00000000000..0507d4c7f16 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/metadata/UsageTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.metadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.doCallRealMethod; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import org.junit.jupiter.api.Test; + +/** + * Unit Tests for {@link Usage}. + * + * @author John Blum + * @since 0.7.0 + */ +public class UsageTests { + + private Usage mockUsage(Long promptTokens, Long generationTokens) { + Usage mockUsage = mock(Usage.class); + doReturn(promptTokens).when(mockUsage).getPromptTokens(); + doReturn(generationTokens).when(mockUsage).getGenerationTokens(); + doCallRealMethod().when(mockUsage).getTotalTokens(); + return mockUsage; + } + + private void verifyUsage(Usage usage) { + verify(usage, times(1)).getTotalTokens(); + verify(usage, times(1)).getPromptTokens(); + verify(usage, times(1)).getGenerationTokens(); + verifyNoMoreInteractions(usage); + } + + @Test + void totalTokensIsZeroWhenNoPromptOrGenerationMetadataPresent() { + + Usage usage = mockUsage(null, null); + + assertThat(usage.getTotalTokens()).isZero(); + verifyUsage(usage); + } + + @Test + void totalTokensEqualsPromptTokens() { + + Usage usage = mockUsage(10L, null); + + assertThat(usage.getTotalTokens()).isEqualTo(10L); + verifyUsage(usage); + } + + @Test + void totalTokensEqualsGenerationTokens() { + + Usage usage = mockUsage(null, 15L); + + assertThat(usage.getTotalTokens()).isEqualTo(15L); + verifyUsage(usage); + } + + @Test + void totalTokensEqualsPromptTokensPlusGenerationTokens() { + + Usage usage = mockUsage(10L, 15L); + + assertThat(usage.getTotalTokens()).isEqualTo(25L); + verifyUsage(usage); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc new file mode 100644 index 00000000000..fe105a06786 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/aimetadata.adoc @@ -0,0 +1,139 @@ +[[AiMetadata]] += AI metadata + +Use of an AI, such as OpenAI's ChatGPT, consumes resources and generates metrics returned by the AI provider based on the usage and requests made to the AI through the API. +Consumption is typically in the form of requests made or tokens used in a given timeframe, such as monthly, that AI providers use to measure this consumption and reset limits. +Your rate limits are directly determined by your plan when you signed up with your AI provider. For instance, you can review details on OpenAI's https://platform.openai.com/docs/guides/rate-limits?context=tier-free[rate limits] and https://openai.com/pricing#language-models[plans] by following the links. + +To help garner insight into your AI (model) consumption and general usage, Spring AI provides an API to introspect the metadata that is returned by AI providers in their APIs. + +Spring AI defines 3 primary interfaces to examine these metrics: `GenerationMetadata`, `RateLimit` and `Usage`. All of these interface can be accessed programmatically from the `AiResponse` returned and initiated by an AI request. + +[[AiMetadata-GenerationMetadata]] +== `GenerationMetadata` interface + +The `GenerationMetadata` interface is defined as: + +.GenerationMetadata interface +[source,java] +---- +interface GenerationMetadata { + + default RateLimit getRateLimit() { + return RateLimit.NULL; + } + + default Usage getUsage() { + return Usage.NULL; + } + +} +---- + +An instance of `GenerationMetadata` is automatically created by Spring AI when an AI request is made through the AI provider's API and an AI response is returned. You can get access to the AI provider metadata from the `AiResponse` using: + +.Get access to `GenerationMetadata` from `AiResponse` +[source,java] +---- +@Service +class MyService { + + ApplicationObjectType askTheAi(ServiceRequest request) { + + Prompt prompt = createPrompt(request); + + AiResponse response = aiClient.generate(prompt) + + // Process the AI response + + GenerationMetadata metadata = response.getMetadata(); + + // Inspect the AI metadata returned in the AI response of the AI providers API + + Long totalTokensUsedInAiPromptAndResponse = metadata.getUsage().getTotalTokens(); + + // Act on this information somehow + } +} +---- + +You might imagine that you can rate limit your own Spring applications using AI, or restrict `Prompt` sizes, which affect your token usage, in an automated, intelligent and realtime manner. + +Minimally, you can simply gather these metrics to monitor and report on your consumption. + +[[AiMetadata-RateLimit]] +== RateLimit + +The `RateLimit` interface provides access to actual information returned by an AI provider on your API usage when making AI requests. + +.`RateLimit` interface +[source,java] +---- +interface RateLimit { + + Long getRequestsLimit(); + + Long getRequestsRemaining(); + + Duration getRequestsReset(); + + Long getTokensLimit(); + + Long getTokensRemaining(); + + Duration getTokensReset(); + +} +---- + +`requestsLimit` and `requestsRemaining` let you know how many AI requests, based on the AI provider plan you chose when you signed up, that you can make in total along with your remaining balance within the given timeframe. `requestsReset` returns a `Duration` of time before the timeframe expires and your limits reset based on your chosen plan. + +The methods for `tokensLimit`, `tokensRemaining` and `tokensReset` are similar to the methods for requests, but focus on token limits, balance and resets instead. + +The `RateLimit` instance can be acquired from the `GenerationMetadata`, like so: + +.Get access to `RateLimit` from `GenerationMetadata` +[source,java] +---- +RateLimit rateLimit = generationMetadata.getRateLimit(); + +Long tokensRemaining = rateLimit.getTokensRemaining(); + +// do something interesting with the RateLimit metadata +---- + +For AI providers like OpenAI, the rate limit metadata is returned in https://platform.openai.com/docs/guides/rate-limits/rate-limits-in-headers[HTTP headers] from their (REST) API accessible through HTTP clients, like OkHttp. + +Because this can be potentially a costly operation, the collection of rate limit AI metadata must be explicitly enabled. You can enable this collection with a Spring AI property in Spring Boot application.properties; for example: + +.Enable API rate limit collection from AI metadata +[source,properties] +---- +# Spring Boot application.properties +spring.ai.openai.metadata.rate-limit-metrics-enabled=true +---- + +[[AiMetadata-Usage]] +== Usage + +As shown <>, `Usage` data can be obtained from the `GenerationMetadata` object. The `Usage` interface is defined as: + +.`Usage` interface +[source,java] +---- +interface Usage { + + Long getPromptTokens(); + + Long getGenerationTokens(); + + default Long getTotalTokens() { + return getPromptTokens() + getGenerationTokens(); + } + +} +---- + +The method names are self-explanatory, but tells you the tokens that the AI required to process the `Prompt` and generate a response. + +`totalTokens` is the sum of `promptTokens` and `generationTokens`. Spring AI computes this by default, but the information is returned in the AI response from OpenAI. diff --git a/spring-ai-openai/pom.xml b/spring-ai-openai/pom.xml index be1a8f2a3df..7aa476df606 100644 --- a/spring-ai-openai/pom.xml +++ b/spring-ai-openai/pom.xml @@ -27,6 +27,16 @@ ${project.parent.version} + + io.rest-assured + json-path + + + + com.squareup.okhttp3 + okhttp + + com.theokanning.openai-gpt3-java service @@ -63,6 +73,24 @@ test + + jakarta.servlet + jakarta.servlet-api + test + + + + org.springframework + spring-webmvc + test + + + + com.squareup.okhttp3 + mockwebserver + test + + diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java index 630685e1219..e9c38828bad 100644 --- a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java @@ -16,20 +16,24 @@ package org.springframework.ai.openai.client; -import com.theokanning.openai.completion.chat.ChatCompletionChoice; -import com.theokanning.openai.completion.chat.ChatCompletionRequest; -import com.theokanning.openai.completion.chat.ChatMessage; -import com.theokanning.openai.service.OpenAiService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.client.AiClient; import org.springframework.ai.client.AiResponse; import org.springframework.ai.client.Generation; +import org.springframework.ai.metadata.ChoiceMetadata; +import org.springframework.ai.openai.metadata.OpenAiGenerationMetadata; import org.springframework.ai.prompt.Prompt; import org.springframework.ai.prompt.messages.Message; import org.springframework.ai.prompt.messages.MessageType; import org.springframework.util.Assert; +import com.theokanning.openai.completion.chat.ChatCompletionChoice; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatCompletionResult; +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.service.OpenAiService; + import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -101,6 +105,21 @@ public AiResponse generate(Prompt prompt) { return getAiResponse(chatCompletionRequest); } + private AiResponse getAiResponse(ChatCompletionRequest chatCompletionRequest) { + logger.trace("ChatMessages: {}", chatCompletionRequest.getMessages()); + ChatCompletionResult chatCompletionResult = this.openAiService.createChatCompletion(chatCompletionRequest); + List chatCompletionChoices = chatCompletionResult.getChoices(); + logger.trace("ChatCompletionChoices: {}", chatCompletionChoices); + List generations = new ArrayList<>(); + for (ChatCompletionChoice chatCompletionChoice : chatCompletionChoices) { + ChatMessage chatMessage = chatCompletionChoice.getMessage(); + Generation generation = new Generation(chatMessage.getContent(), Map.of("role", chatMessage.getRole())) + .withChoiceMetadata(ChoiceMetadata.from(chatCompletionChoice.getFinishReason(), null)); + generations.add(generation); + } + return new AiResponse(generations, OpenAiGenerationMetadata.from(chatCompletionResult)); + } + private ChatCompletionRequest getChatCompletionRequest(String text) { List chatMessages = List.of(new ChatMessage("user", text)); @@ -116,27 +135,6 @@ private ChatCompletionRequest getChatCompletionRequest(String text) { return chatCompletionRequest; } - private AiResponse getAiResponse(ChatCompletionRequest chatCompletionRequest) { - - List generations = new ArrayList<>(); - logger.trace("ChatMessages: {}", chatCompletionRequest.getMessages()); - - List chatCompletionChoices = this.openAiService - .createChatCompletion(chatCompletionRequest) - .getChoices(); - logger.trace("ChatCompletionChoice: {}", chatCompletionChoices); - - for (ChatCompletionChoice chatCompletionChoice : chatCompletionChoices) { - ChatMessage chatMessage = chatCompletionChoice.getMessage(); - // TODO investigate mapping of additional metadata/runtime info to the general - // model. - Generation generation = new Generation(chatMessage.getContent(), Map.of("role", chatMessage.getRole())); - generations.add(generation); - } - - return new AiResponse(generations); - } - private String getResponse(ChatCompletionRequest chatCompletionRequest) { StringBuilder builder = new StringBuilder(); diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java new file mode 100644 index 00000000000..0b3222de82f --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiGenerationMetadata.java @@ -0,0 +1,93 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata; + +import com.theokanning.openai.completion.chat.ChatCompletionResult; + +import org.springframework.ai.metadata.GenerationMetadata; +import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.metadata.Usage; +import org.springframework.ai.openai.metadata.support.OpenAiHttpResponseHeadersInterceptor; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link GenerationMetadata} implementation for {@literal OpenAI}. + * + * @author John Blum + * @see org.springframework.ai.metadata.GenerationMetadata + * @see org.springframework.ai.metadata.RateLimit + * @see org.springframework.ai.metadata.Usage + * @since 0.7.0 + */ +public class OpenAiGenerationMetadata implements GenerationMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }"; + + public static OpenAiGenerationMetadata from(ChatCompletionResult result) { + Assert.notNull(result, "OpenAI ChatCompletionResult must not be null"); + OpenAiUsage usage = OpenAiUsage.from(result.getUsage()); + OpenAiGenerationMetadata generationMetadata = new OpenAiGenerationMetadata(result.getId(), usage); + OpenAiHttpResponseHeadersInterceptor.applyTo(generationMetadata); + return generationMetadata; + } + + private final String id; + + @Nullable + private RateLimit rateLimit; + + private final Usage usage; + + protected OpenAiGenerationMetadata(String id, OpenAiUsage usage) { + this(id, usage, null); + } + + protected OpenAiGenerationMetadata(String id, OpenAiUsage usage, @Nullable OpenAiRateLimit rateLimit) { + this.id = id; + this.usage = usage; + this.rateLimit = rateLimit; + } + + public String getId() { + return this.id; + } + + @Override + @Nullable + public RateLimit getRateLimit() { + RateLimit rateLimit = this.rateLimit; + return rateLimit != null ? rateLimit : RateLimit.NULL; + } + + @Override + public Usage getUsage() { + Usage usage = this.usage; + return usage != null ? usage : Usage.NULL; + } + + public OpenAiGenerationMetadata withRateLimit(RateLimit rateLimit) { + this.rateLimit = rateLimit; + return this; + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getUsage(), getRateLimit()); + } + +} diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java new file mode 100644 index 00000000000..4a0697cb397 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiRateLimit.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata; + +import java.time.Duration; + +import org.springframework.ai.metadata.RateLimit; + +/** + * {@link RateLimit} implementation for {@literal OpenAI}. + * + * @author John Blum + * @since 0.7.0 + * @see Rate + * limits in headers + */ +public class OpenAiRateLimit implements RateLimit { + + private static final String RATE_LIMIT_STRING = "{ @type: %1$s, requestsLimit: %2$s, requestsRemaining: %3$s, requestsReset: %4$s, tokensLimit: %5$s; tokensRemaining: %6$s; tokensReset: %7$s }"; + + private final Long requestsLimit; + + private final Long requestsRemaining; + + private final Long tokensLimit; + + private final Long tokensRemaining; + + private final Duration requestsReset; + + private final Duration tokensReset; + + public OpenAiRateLimit(Long requestsLimit, Long requestsRemaining, Duration requestsReset, Long tokensLimit, + Long tokensRemaining, Duration tokensReset) { + + this.requestsLimit = requestsLimit; + this.requestsRemaining = requestsRemaining; + this.requestsReset = requestsReset; + this.tokensLimit = tokensLimit; + this.tokensRemaining = tokensRemaining; + this.tokensReset = tokensReset; + } + + @Override + public Long getRequestsLimit() { + return this.requestsLimit; + } + + @Override + public Long getTokensLimit() { + return this.tokensLimit; + } + + @Override + public Long getRequestsRemaining() { + return this.requestsRemaining; + } + + @Override + public Long getTokensRemaining() { + return this.tokensRemaining; + } + + @Override + public Duration getRequestsReset() { + return this.requestsReset; + } + + @Override + public Duration getTokensReset() { + return this.tokensReset; + } + + @Override + public String toString() { + return RATE_LIMIT_STRING.formatted(getClass().getName(), getRequestsLimit(), getRequestsRemaining(), + getRequestsReset(), getTokensLimit(), getTokensRemaining(), getTokensReset()); + } + +} diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java new file mode 100644 index 00000000000..02b02de6e7d --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata; + +import org.springframework.ai.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link Usage} implementation for {@literal OpenAI}. + * + * @author John Blum + * @since 0.7.0 + * @see Completion + * Object + */ +public class OpenAiUsage implements Usage { + + public static OpenAiUsage from(com.theokanning.openai.Usage usage) { + return new OpenAiUsage(usage); + } + + private final com.theokanning.openai.Usage usage; + + protected OpenAiUsage(com.theokanning.openai.Usage usage) { + Assert.notNull(usage, "OpenAI Usage must not be null"); + this.usage = usage; + } + + protected com.theokanning.openai.Usage getUsage() { + return this.usage; + } + + @Override + public Long getPromptTokens() { + return getUsage().getPromptTokens(); + } + + @Override + public Long getGenerationTokens() { + return getUsage().getCompletionTokens(); + } + + @Override + public Long getTotalTokens() { + return getUsage().getTotalTokens(); + } + + @Override + public String toString() { + return getUsage().toString(); + } + +} diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java new file mode 100644 index 00000000000..488aaa70d97 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiApiResponseHeaders.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata.support; + +/** + * {@link Enum Enumeration} of {@literal OpenAI} API response headers. + * + * @author John Blum + * @since 0.7.0 + */ +public enum OpenAiApiResponseHeaders { + + REQUESTS_LIMIT_HEADER("x-ratelimit-limit-requests", "Total number of requests allowed within timeframe."), + REQUESTS_REMAINING_HEADER("x-ratelimit-remaining-requests", "Remaining number of requests available in timeframe."), + REQUESTS_RESET_HEADER("x-ratelimit-reset-requests", "Duration of time until the number of requests reset."), + TOKENS_RESET_HEADER("x-ratelimit-reset-tokens", "Total number of tokens allowed within timeframe."), + TOKENS_LIMIT_HEADER("x-ratelimit-limit-tokens", "Remaining number of tokens available in timeframe."), + TOKENS_REMAINING_HEADER("x-ratelimit-remaining-tokens", "Duration of time until the number of tokens reset."); + + private final String headerName; + + private final String description; + + OpenAiApiResponseHeaders(String headerName, String description) { + this.headerName = headerName; + this.description = description; + } + + public String getName() { + return this.headerName; + } + + public String getDescription() { + return this.description; + } + +} diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiHttpResponseHeadersInterceptor.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiHttpResponseHeadersInterceptor.java new file mode 100644 index 00000000000..473c4e19618 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/support/OpenAiHttpResponseHeadersInterceptor.java @@ -0,0 +1,249 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata.support; + +import static org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER; +import static org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders.REQUESTS_REMAINING_HEADER; +import static org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders.REQUESTS_RESET_HEADER; +import static org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders.TOKENS_LIMIT_HEADER; +import static org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER; +import static org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders.TOKENS_RESET_HEADER; + +import java.io.IOException; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.WeakHashMap; +import java.util.function.Predicate; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import io.restassured.path.json.JsonPath; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.openai.metadata.OpenAiGenerationMetadata; +import org.springframework.ai.openai.metadata.OpenAiRateLimit; +import org.springframework.http.HttpHeaders; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; + +/** + * OkHttp {@link Interceptor} implementation used capture the AI HTTP response headers + * from {@literal OpenAI} API. + * + * @author John Blum + * @see okhttp3.Interceptor + * @since 0.7.0 + */ +public class OpenAiHttpResponseHeadersInterceptor implements Interceptor { + + private static final Map cache = Collections.synchronizedMap(new WeakHashMap<>()); + + public static void applyTo(OpenAiGenerationMetadata metadata) { + + String id = metadata.getId(); + + synchronized (cache) { + metadata.withRateLimit(cache.get(id)); + cache.remove(id); + } + } + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + @Override + public Response intercept(Chain chain) throws IOException { + + Request request = chain.request(); + Response response = chain.proceed(request); + + cacheAiResponseHeaders(response); + + return response; + } + + protected Logger getLogger() { + return this.logger; + } + + private RateLimit cacheAiResponseHeaders(Response response) { + + String id = parseAiResponseId(response); + + OpenAiRateLimit rateLimit = StringUtils.hasText(id) ? cache.computeIfAbsent(id, key -> { + + Long requestsLimit = getHeaderAsLong(response, REQUESTS_LIMIT_HEADER.getName()); + Long requestsRemaining = getHeaderAsLong(response, REQUESTS_REMAINING_HEADER.getName()); + Long tokensLimit = getHeaderAsLong(response, TOKENS_LIMIT_HEADER.getName()); + Long tokensRemaining = getHeaderAsLong(response, TOKENS_REMAINING_HEADER.getName()); + + Duration requestsReset = getHeaderAsDuration(response, REQUESTS_RESET_HEADER.getName()); + Duration tokensReset = getHeaderAsDuration(response, TOKENS_RESET_HEADER.getName()); + + return new OpenAiRateLimit(requestsLimit, requestsRemaining, requestsReset, tokensLimit, tokensRemaining, + tokensReset); + }) : null; + + return rateLimit; + } + + private Duration getHeaderAsDuration(Response response, String headerName) { + String headerValue = response.header(headerName); + return DurationFormatter.TIME_UNIT.parse(headerValue); + } + + private Long getHeaderAsLong(Response response, String headerName) { + String headerValue = response.header(headerName); + return parseLong(headerName, headerValue); + } + + private String parseAiResponseId(Response response) { + + try { + long contentLength = resolveContentLength(response); + ResponseBody responseBody = response.peekBody(contentLength); + String bodyContent = responseBody.string(); + String id = JsonPath.with(bodyContent).getString("id"); + return id; + } + catch (Exception e) { + getLogger().warn("Unable to get AI response body as a String: {}", e.getMessage()); + return null; + } + } + + private Long parseLong(String headerName, String headerValue) { + + if (StringUtils.hasText(headerValue)) { + try { + return Long.parseLong(headerValue.trim()); + } + catch (NumberFormatException e) { + getLogger().warn("Value [{}] for HTTP header [{}] is not valid: {}", headerName, headerValue, + e.getMessage()); + } + } + + return null; + } + + private long resolveContentLength(Response response) { + return getHeaderAsLong(response, HttpHeaders.CONTENT_LENGTH); + } + + enum DurationFormatter { + + TIME_UNIT("\\d+[a-zA-Z]{1,2}"); + + private final Pattern pattern; + + DurationFormatter(String durationPattern) { + this.pattern = Pattern.compile(durationPattern); + } + + public Duration parse(String text) { + + Assert.hasText(text, "Text [%s] to parse as a Duration must not be null or empty".formatted(text)); + + Matcher matcher = this.pattern.matcher(text); + Duration total = Duration.ZERO; + + while (matcher.find()) { + String value = matcher.group(); + total = total.plus(Unit.parseUnit(value).toDuration(value)); + } + + return total; + } + + enum Unit { + + NANOSECONDS("ns", "nanoseconds", ChronoUnit.NANOS), MICROSECONDS("us", "microseconds", ChronoUnit.MICROS), + MILLISECONDS("ms", "milliseconds", ChronoUnit.MILLIS), SECONDS("s", "seconds", ChronoUnit.SECONDS), + MINUTES("m", "minutes", ChronoUnit.MINUTES), HOURS("h", "hours", ChronoUnit.HOURS), + DAYS("d", "days", ChronoUnit.DAYS); + + private final String name; + + private final String symbol; + + private final ChronoUnit unit; + + Unit(String symbol, String name, ChronoUnit unit) { + this.symbol = symbol; + this.name = name; + this.unit = unit; + } + + static Unit parseUnit(String value) { + String symbol = parseSymbol(value); + return Arrays.stream(values()) + .filter(unit -> unit.getSymbol().equalsIgnoreCase(symbol)) + .findFirst() + .orElseThrow(() -> new IllegalStateException( + "Value [%s] does not contain a valid time unit".formatted(value))); + } + + private static String parse(String value, Predicate predicate) { + Assert.hasText(value, "Value [%s] must not be null or empty".formatted(value)); + StringBuilder builder = new StringBuilder(); + for (char character : value.toCharArray()) { + if (predicate.test(character)) { + builder.append(character); + } + } + return builder.toString(); + } + + private static String parseSymbol(String value) { + return parse(value, Character::isLetter); + } + + private static Long parseTime(String value) { + return Long.parseLong(parse(value, Character::isDigit)); + } + + public String getName() { + return this.name; + } + + public String getSymbol() { + return this.symbol; + } + + public ChronoUnit getUnit() { + return this.unit; + } + + public Duration toDuration(String value) { + return Duration.of(parseTime(value), getUnit()); + } + + } + + } + +} diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiMockTestConfiguration.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiMockTestConfiguration.java new file mode 100644 index 00000000000..35f2cbafca9 --- /dev/null +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiMockTestConfiguration.java @@ -0,0 +1,311 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Optional; +import java.util.Queue; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedDeque; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.theokanning.openai.client.OpenAiApi; +import com.theokanning.openai.service.OpenAiService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.openai.client.OpenAiClient; +import org.springframework.ai.openai.metadata.support.OpenAiHttpResponseHeadersInterceptor; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.context.SmartLifecycle; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Profile; +import org.springframework.lang.Nullable; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.RequestBuilder; +import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import okio.Buffer; +import retrofit2.Retrofit; +import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; + +/** + * {@link SpringBootConfiguration} for {@literal OpenAI's} API using mock objects. + *

+ * This test configuration allows Spring AI framework developers to mock OpenAI's API with + * Spring {@link MockMvc} and a test provided Spring Web MVC + * {@link org.springframework.web.bind.annotation.RestController}. + *

+ * This test configuration makes use of the OkHttp3 {@link MockWebServer} and + * {@link Dispatcher} to integrate with Spring {@link MockMvc}. + * + * @author John Blum + * @see okhttp3.mockwebserver.Dispatcher + * @see okhttp3.mockwebserver.MockWebServer + * @see org.springframework.boot.SpringBootConfiguration + * @see org.springframework.test.web.servlet.MockMvc + * @since 0.7.0 + */ +@SpringBootConfiguration +@Profile("spring-ai-openai-mocks") +@SuppressWarnings("unused") +public class OpenAiMockTestConfiguration { + + private static final Charset FALLBACK_CHARSET = StandardCharsets.UTF_8; + + private static final String SPRING_AI_API_PATH = "/spring-ai/api"; + + @Bean + MockWebServerFactoryBean mockWebServer(MockMvc mockMvc) { + MockWebServerFactoryBean factoryBean = new MockWebServerFactoryBean(); + factoryBean.setDispatcher(new MockMvcDispatcher(mockMvc)); + return factoryBean; + } + + @Bean + OpenAiService theoOpenAiService(MockWebServer webServer) { + + String apiKey = UUID.randomUUID().toString(); + Duration timeout = Duration.ofSeconds(60); + + ObjectMapper objectMapper = OpenAiService.defaultObjectMapper(); + + OkHttpClient httpClient = new OkHttpClient.Builder(OpenAiService.defaultClient(apiKey, timeout)) + .addInterceptor(new OpenAiHttpResponseHeadersInterceptor()) + .build(); + + HttpUrl baseUrl = webServer.url(SPRING_AI_API_PATH.concat("/")); + + Retrofit retrofit = new Retrofit.Builder().baseUrl(baseUrl) + .addConverterFactory(JacksonConverterFactory.create(objectMapper)) + .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) + .client(httpClient) + .build(); + + OpenAiApi api = retrofit.create(OpenAiApi.class); + + return new OpenAiService(api); + } + + @Bean + OpenAiClient apiClient(OpenAiService openAiService) { + return new OpenAiClient(openAiService); + } + + static class MockMvcDispatcher extends Dispatcher { + + private final MockMvc mockMvc; + + MockMvcDispatcher(MockMvc mockMvc) { + Assert.notNull(mockMvc, "Spring MockMvc must not be null"); + this.mockMvc = mockMvc; + } + + protected MockMvc getMockMvc() { + return this.mockMvc; + } + + @Override + public MockResponse dispatch(RecordedRequest request) { + + try { + MvcResult result = getMockMvc().perform(requestBuilderFrom(request)) + .andExpect(status().isOk()) + .andReturn(); + + MockHttpServletResponse response = result.getResponse(); + + return mockResponseFrom(response); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private RequestBuilder requestBuilderFrom(RecordedRequest request) { + + String requestMethod = request.getMethod(); + String requestPath = resolveRequestPath(request); + + URI uri = URI.create(requestPath); + + Buffer requestBody = request.getBody(); + + String content = requestBody.readUtf8(); + + return MockMvcRequestBuilders.request(requestMethod, uri).content(content); + } + + private String resolveRequestPath(RecordedRequest request) { + + String requestPath = request.getPath(); + String pavedRequestPath = StringUtils.hasText(requestPath) ? requestPath : "/"; + + return pavedRequestPath.startsWith(SPRING_AI_API_PATH) ? pavedRequestPath + : SPRING_AI_API_PATH.concat(pavedRequestPath); + } + + private MockResponse mockResponseFrom(MockHttpServletResponse response) { + + MockResponse mockResponse = new MockResponse(); + + for (String headerName : response.getHeaderNames()) { + String headerValue = response.getHeader(headerName); + if (StringUtils.hasText(headerValue)) { + mockResponse.addHeader(headerName, headerValue); + } + } + + mockResponse.setResponseCode(response.getStatus()); + mockResponse.setBody(getBody(response)); + + return mockResponse; + } + + private String getBody(MockHttpServletResponse response) { + + Charset responseCharacterEncoding = Charset.forName(response.getCharacterEncoding()); + + try { + return response.getContentAsString(FALLBACK_CHARSET); + } + catch (UnsupportedEncodingException e) { + throw new RuntimeException("Failed to decode content using HttpServletResponse Charset [%s]" + .formatted(responseCharacterEncoding), e); + } + } + + } + + /** + * Spring {@link FactoryBean} used to construct, configure and properly initialize the + * {@link MockWebServer} inside the Spring container. + *

+ * Unfortunately, {@link MockWebServerFactoryBean} cannot implement the Spring + * {@link SmartLifecycle} interface as originally intended. The problem is, the + * {@link MockWebServer} class is poorly designed and does not adhere to the + * {@literal Open/Closed principle}: + *

    + *
  • The class does not provide a isRunning() lifecycle method, despite the start() + * and shutdown() methods
  • + *
  • The MockWebServer.started is a private state variable
  • + *
  • The overridden before() function is protected
  • + *
  • The class is final and cannot be extended
  • + *
  • Calling MockWebServer.url(:String) needed to construct Retrofit client in the + * theoOpenAiService bean necessarily starts the MockWebServer
  • + *
+ *

+ * TODO: Figure out a way to implement the Spring {@link SmartLifecycle} interface + * without scrambling bean dependencies, bean phases, and other bean lifecycle + * methods. + * + * @see org.springframework.beans.factory.FactoryBean + * @see org.springframework.beans.factory.InitializingBean + * @see okhttp3.mockwebserver.MockWebServer + */ + static class MockWebServerFactoryBean implements FactoryBean, InitializingBean, DisposableBean { + + private Dispatcher dispatcher; + + private final Logger logger = LoggerFactory.getLogger(getClass().getName()); + + private MockWebServer mockWebServer; + + private final Queue queuedResponses = new ConcurrentLinkedDeque<>(); + + public void setDispatcher(@Nullable Dispatcher dispatcher) { + this.dispatcher = dispatcher; + } + + protected Optional getDispatcher() { + return Optional.ofNullable(this.dispatcher); + } + + protected Logger getLogger() { + return this.logger; + } + + @Override + public MockWebServer getObject() { + return start(this.mockWebServer); + } + + @Override + public Class getObjectType() { + return MockWebServer.class; + } + + @Override + public void afterPropertiesSet() { + this.mockWebServer = new MockWebServer(); + this.queuedResponses.forEach(this.mockWebServer::enqueue); + getDispatcher().ifPresent(this.mockWebServer::setDispatcher); + } + + public MockWebServerFactoryBean enqueue(MockResponse response) { + Assert.notNull(response, "MockResponse must not be null"); + this.queuedResponses.add(response); + return this; + } + + @Override + public void destroy() { + + try { + this.mockWebServer.shutdown(); + } + catch (IOException e) { + getLogger().warn("MockWebServer was not shutdown correctly: {}", e.getMessage()); + getLogger().trace("MockWebServer shutdown failure", e); + } + } + + private MockWebServer start(MockWebServer webServer) { + + try { + webServer.start(); + return webServer; + } + catch (IOException e) { + throw new IllegalStateException("Failed to start MockWebServer", e); + } + } + + } + +} diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/OpenAiClientWithGenerationMetadataTests.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/OpenAiClientWithGenerationMetadataTests.java new file mode 100644 index 00000000000..9a5e1332ebd --- /dev/null +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/OpenAiClientWithGenerationMetadataTests.java @@ -0,0 +1,184 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.client; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.client.AiResponse; +import org.springframework.ai.metadata.ChoiceMetadata; +import org.springframework.ai.metadata.GenerationMetadata; +import org.springframework.ai.metadata.PromptMetadata; +import org.springframework.ai.metadata.RateLimit; +import org.springframework.ai.metadata.Usage; +import org.springframework.ai.openai.OpenAiMockTestConfiguration; +import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; +import org.springframework.ai.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.context.request.WebRequest; + +/** + * Tests using the {@link OpenAiClient} to send an {@literal OpenAI} API request (chat + * completion) to test the presence of {@link GenerationMetadata} in the + * {@link AiResponse}. + * + * @author John Blum + * @since 0.7.0 + */ +@SpringBootTest +@ContextConfiguration(classes = OpenAiClientWithGenerationMetadataTests.TestConfiguration.class) +@ActiveProfiles("spring-ai-openai-mocks") +@SuppressWarnings("unused") +class OpenAiClientWithGenerationMetadataTests { + + @Autowired + private OpenAiClient aiClient; + + @Test + void aiResponseContainsAiMetadata() { + + Prompt prompt = new Prompt("Reach for the sky."); + + AiResponse response = this.aiClient.generate(prompt); + + assertThat(response).isNotNull(); + + GenerationMetadata generationMetadata = response.getGenerationMetadata(); + + assertThat(generationMetadata).isNotNull(); + + Usage usage = generationMetadata.getUsage(); + + assertThat(usage).isNotNull(); + assertThat(usage.getPromptTokens()).isEqualTo(9L); + assertThat(usage.getGenerationTokens()).isEqualTo(12L); + assertThat(usage.getTotalTokens()).isEqualTo(21L); + + RateLimit rateLimit = generationMetadata.getRateLimit(); + + Duration expectedRequestsReset = Duration.ofDays(2L) + .plus(Duration.ofHours(16L)) + .plus(Duration.ofMinutes(15)) + .plus(Duration.ofSeconds(29L)); + + Duration expectedTokensReset = Duration.ofHours(27L) + .plus(Duration.ofSeconds(55L)) + .plus(Duration.ofMillis(451L)); + + assertThat(rateLimit).isNotNull(); + assertThat(rateLimit.getRequestsLimit()).isEqualTo(4000L); + assertThat(rateLimit.getRequestsRemaining()).isEqualTo(999); + assertThat(rateLimit.getRequestsReset()).isEqualTo(expectedRequestsReset); + assertThat(rateLimit.getTokensLimit()).isEqualTo(725_000L); + assertThat(rateLimit.getTokensRemaining()).isEqualTo(112_358L); + assertThat(rateLimit.getTokensReset()).isEqualTo(expectedTokensReset); + + PromptMetadata promptMetadata = response.getPromptMetadata(); + + assertThat(promptMetadata).isNotNull(); + assertThat(promptMetadata).isEmpty(); + + response.getGenerations().forEach(generation -> { + ChoiceMetadata choiceMetadata = generation.getChoiceMetadata(); + assertThat(choiceMetadata).isNotNull(); + assertThat(choiceMetadata.getFinishReason()).isEqualTo("stop"); + assertThat(choiceMetadata.getContentFilterMetadata()).isNull(); + }); + } + + @SpringBootConfiguration + @Import(OpenAiMockTestConfiguration.class) + static class TestConfiguration { + + @Bean + MockMvc mockMvc() { + return MockMvcBuilders.standaloneSetup(new SpringOpenAiChatCompletionsController()).build(); + } + + } + + @RestController + @RequestMapping("/spring-ai/api") + @SuppressWarnings("all") + static class SpringOpenAiChatCompletionsController { + + @PostMapping("/v1/chat/completions") + ResponseEntity chatCompletions(WebRequest request) { + + String json = getJson(); + + ResponseEntity response = ResponseEntity.status(HttpStatusCode.valueOf(200)) + .contentType(MediaType.APPLICATION_JSON) + .contentLength(json.getBytes(StandardCharsets.UTF_8).length) + .headers(httpHeaders -> { + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_REMAINING_HEADER.getName(), "999"); + httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_RESET_HEADER.getName(), "2d16h15m29s"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_LIMIT_HEADER.getName(), "725000"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); + httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); + }) + .body(getJson()); + + return response; + } + + private String getJson() { + return """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "I surrender!" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + } + + } + +} diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiHttpResponseHeadersInterceptorTests.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiHttpResponseHeadersInterceptorTests.java new file mode 100644 index 00000000000..3ef0df4ef16 --- /dev/null +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/support/OpenAiHttpResponseHeadersInterceptorTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.metadata.support; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.openai.metadata.support.OpenAiHttpResponseHeadersInterceptor.DurationFormatter; + +/** + * Unit Tests for {@link OpenAiHttpResponseHeadersInterceptor}. + * + * @author John Blum + * @since 0.7.0 + */ +public class OpenAiHttpResponseHeadersInterceptorTests { + + @Test + void parseTimeAsDurationWithDaysHoursMinutesSeconds() { + + Duration actual = DurationFormatter.TIME_UNIT.parse("6d18h22m45s"); + Duration expected = Duration.ofDays(6L) + .plus(Duration.ofHours(18L)) + .plus(Duration.ofMinutes(22)) + .plus(Duration.ofSeconds(45L)); + + assertThat(actual).isEqualTo(expected); + } + + @Test + void parseTimeAsDurationWithMinutesSecondsMicrosecondsAndMicroseconds() { + + Duration actual = DurationFormatter.TIME_UNIT.parse("42m18s451ms21541ns"); + Duration expected = Duration.ofMinutes(42L) + .plus(Duration.ofSeconds(18L)) + .plus(Duration.ofMillis(451)) + .plus(Duration.ofNanos(21541L)); + + assertThat(actual).isEqualTo(expected); + } + + @Test + void parseTimeAsDurationWithDaysMinutesAndMilliseconds() { + + Duration actual = DurationFormatter.TIME_UNIT.parse("2d15m981ms"); + Duration expected = Duration.ofDays(2L).plus(Duration.ofMinutes(15L)).plus(Duration.ofMillis(981L)); + + assertThat(actual).isEqualTo(expected); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 8f1190630a0..0da17ad05e2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -16,19 +16,18 @@ package org.springframework.ai.autoconfigure.openai; +import static org.springframework.ai.autoconfigure.openai.OpenAiProperties.CONFIG_PREFIX; + import java.time.Duration; import com.fasterxml.jackson.databind.ObjectMapper; import com.theokanning.openai.client.OpenAiApi; import com.theokanning.openai.service.OpenAiService; -import okhttp3.OkHttpClient; -import retrofit2.Retrofit; -import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; -import retrofit2.converter.jackson.JacksonConverterFactory; import org.springframework.ai.autoconfigure.NativeHints; import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.openai.client.OpenAiClient; +import org.springframework.ai.openai.metadata.support.OpenAiHttpResponseHeadersInterceptor; import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -38,7 +37,11 @@ import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.util.StringUtils; -import static org.springframework.ai.autoconfigure.openai.OpenAiProperties.CONFIG_PREFIX; +import okhttp3.OkHttpClient; +import retrofit2.Retrofit; +import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory; +import retrofit2.converter.jackson.JacksonConverterFactory; +import retrofit2.http.HEAD; @AutoConfiguration @ConditionalOnClass(OpenAiService.class) @@ -50,8 +53,8 @@ public class OpenAiAutoConfiguration { @ConditionalOnMissingBean public OpenAiClient openAiClient(OpenAiProperties openAiProperties) { - OpenAiService openAiService = theoOpenAiService(openAiProperties.getBaseUrl(), openAiProperties.getApiKey(), - openAiProperties.getDuration()); + OpenAiService openAiService = theoOpenAiService(openAiProperties, openAiProperties.getBaseUrl(), + openAiProperties.getApiKey(), openAiProperties.getDuration()); OpenAiClient openAiClient = new OpenAiClient(openAiService); openAiClient.setTemperature(openAiProperties.getTemperature()); @@ -64,13 +67,14 @@ public OpenAiClient openAiClient(OpenAiProperties openAiProperties) { @ConditionalOnMissingBean public EmbeddingClient openAiEmbeddingClient(OpenAiProperties openAiProperties) { - OpenAiService openAiService = theoOpenAiService(openAiProperties.getEmbedding().getBaseUrl(), + OpenAiService openAiService = theoOpenAiService(openAiProperties, openAiProperties.getEmbedding().getBaseUrl(), openAiProperties.getEmbedding().getApiKey(), openAiProperties.getDuration()); return new OpenAiEmbeddingClient(openAiService, openAiProperties.getEmbedding().getModel()); } - private OpenAiService theoOpenAiService(String baseUrl, String apiKey, Duration duration) { + private OpenAiService theoOpenAiService(OpenAiProperties properties, String baseUrl, String apiKey, + Duration duration) { if ("https://api.openai.com".equals(baseUrl) && !StringUtils.hasText(apiKey)) { throw new IllegalArgumentException( @@ -78,7 +82,14 @@ private OpenAiService theoOpenAiService(String baseUrl, String apiKey, Duration } ObjectMapper mapper = OpenAiService.defaultObjectMapper(); - OkHttpClient client = OpenAiService.defaultClient(apiKey, duration); + + OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder(OpenAiService.defaultClient(apiKey, duration)); + + if (properties.getMetadata().isRateLimitMetricsEnabled()) { + clientBuilder.addInterceptor(new OpenAiHttpResponseHeadersInterceptor()); + } + + OkHttpClient client = clientBuilder.build(); // Waiting for https://github.com/TheoKanning/openai-java/issues/249 to be // resolved. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiProperties.java index 7cee1c0114c..245c42a1288 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiProperties.java @@ -35,6 +35,8 @@ public class OpenAiProperties { private final Embedding embedding = new Embedding(this); + private final Metadata metadata = new Metadata(); + private String apiKey; private String model = "gpt-3.5-turbo"; @@ -85,6 +87,10 @@ public Embedding getEmbedding() { return this.embedding; } + public Metadata getMetadata() { + return this.metadata; + } + public static class Embedding { private final OpenAiProperties openAiProperties; @@ -130,4 +136,22 @@ public void setBaseUrl(String baseUrl) { } + public static class Metadata { + + private Boolean rateLimitMetricsEnabled; + + public boolean isRateLimitMetricsEnabled() { + return Boolean.TRUE.equals(getRateLimitMetricsEnabled()); + } + + public Boolean getRateLimitMetricsEnabled() { + return this.rateLimitMetricsEnabled; + } + + public void setRateLimitMetricsEnabled(Boolean rateLimitMetricsEnabled) { + this.rateLimitMetricsEnabled = rateLimitMetricsEnabled; + } + + } + }