From fb272400e74a6da2a4e597f160edbf5d3123aaae Mon Sep 17 00:00:00 2001 From: Alexandros Pappas Date: Thu, 13 Mar 2025 15:27:28 +0100 Subject: [PATCH] feat: OpenAI Web Search Annotations This PR adds support for retrieving web search annotations from the OpenAI API, as described in their [web search documentation](https://platform.openai.com/docs/guides/web-search). This allows us to access citation URLs and their context within generated responses when using models like `gpt-4o-search-preview`. **Changes:** * Added `annotations` (with `Annotation` and `UrlCitation` records) to `ChatCompletionMessage` in `OpenAiApi.java`. * Updated `OpenAiChatModel` to populate the `annotations` field (via metadata) for both regular and streaming responses. * Added integration tests (`webSearchAnnotationsTest`, `streamWebSearchAnnotationsTest`) to `OpenAiChatModelIT.java`. * Added `GPT_4_O_SEARCH_PREVIEW` and `GPT_4_O_MINI_SEARCH_PREVIEW` to `OpenAiApi.ChatModel`. * Added `WebSearchOptions` and related records to `OpenAiApi`. * Minor updates to `ChatCompletionRequest` and its `Builder`. Resolves #2449 Signed-off-by: Alexandros Pappas --- .../ai/openai/OpenAiChatModel.java | 11 +- .../ai/openai/OpenAiChatOptions.java | 24 +++- .../ai/openai/api/OpenAiApi.java | 133 +++++++++++++++--- .../OpenAiStreamFunctionCallingHelper.java | 5 +- .../ai/openai/api/OpenAiApiIT.java | 2 +- .../api/tool/OpenAiApiToolFunctionCallIT.java | 3 +- .../ai/openai/chat/OpenAiChatModelIT.java | 66 +++++++++ 7 files changed, 217 insertions(+), 27 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index d6ea170e979..078ad718a8f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -218,7 +218,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons "role", choice.message().role() != null ? choice.message().role().name() : "", "index", choice.index(), "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "", - "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : ""); + "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", + "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of()); return buildGeneration(choice, metadata, request); }).toList(); // @formatter:on @@ -316,8 +317,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha "role", roleMap.getOrDefault(id, ""), "index", choice.index(), "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "", - "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : ""); - + "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", + "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of()); return buildGeneration(choice, metadata, request); }).toList(); // @formatter:on @@ -580,7 +581,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { } return List.of(new ChatCompletionMessage(assistantMessage.getText(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput)); + ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; @@ -590,7 +591,7 @@ else if (message.getMessageType() == MessageType.TOOL) { return toolMessage.getResponses() .stream() .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), - tr.id(), null, null, null)) + tr.id(), null, null, null, null)) .toList(); } else { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 3025d5f9056..76d35a8e768 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -37,6 +37,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.WebSearchOptions; import org.springframework.ai.openai.api.ResponseFormat; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -194,6 +195,11 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { */ private @JsonProperty("reasoning_effort") String reasoningEffort; + /** + * This tool searches the web for relevant results to use in a response. + */ + private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions; + /** * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests. */ @@ -593,6 +599,14 @@ public void setReasoningEffort(String reasoningEffort) { this.reasoningEffort = reasoningEffort; } + public WebSearchOptions getWebSearchOptions() { + return this.webSearchOptions; + } + + public void setWebSearchOptions(WebSearchOptions webSearchOptions) { + this.webSearchOptions = webSearchOptions; + } + @Override public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); @@ -605,7 +619,7 @@ public int hashCode() { this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders, this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio, - this.store, this.metadata, this.reasoningEffort); + this.store, this.metadata, this.reasoningEffort, this.webSearchOptions); } @Override @@ -637,7 +651,8 @@ public boolean equals(Object o) { && Objects.equals(this.outputModalities, other.outputModalities) && Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store) && Objects.equals(this.metadata, other.metadata) - && Objects.equals(this.reasoningEffort, other.reasoningEffort); + && Objects.equals(this.reasoningEffort, other.reasoningEffort) + && Objects.equals(this.webSearchOptions, other.webSearchOptions); } @Override @@ -848,6 +863,11 @@ public Builder reasoningEffort(String reasoningEffort) { return this; } + public Builder webSearchOptions(WebSearchOptions webSearchOptions) { + this.options.webSearchOptions = webSearchOptions; + return this; + } + public OpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 433f59d9697..b3a3e15eb41 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -475,7 +475,21 @@ public enum ChatModel implements ChatModelDescription { * Context window: 4,096 tokens. Max output tokens: 4,096 tokens. Knowledge * cutoff: September, 2021. */ - GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct"); + GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct"), + + /** + * GPT-4o Search Preview is a specialized model for web search in Chat + * Completions. It is trained to understand and execute web search queries. See + * the web search guide for more information. + */ + GPT_4_O_SEARCH_PREVIEW("gpt-4o-search-preview"), + + /** + * GPT-4o mini Search Preview is a specialized model for web search in Chat + * Completions. It is trained to understand and execute web search queries. See + * the web search guide for more information. + */ + GPT_4_O_MINI_SEARCH_PREVIEW("gpt-4o-mini-search-preview"); public final String value; @@ -835,6 +849,10 @@ public enum OutputModality { * @param parallelToolCalls If set to true, the model will call all functions in the * tools list in parallel. Otherwise, the model will call the functions in the tools * list in the order they are provided. + * @param reasoningEffort Constrains effort on reasoning for reasoning models. + * Currently supported values are low, medium, and high. Reducing reasoning effort can + * result in faster responses and fewer tokens used on reasoning in a response. + * @param webSearchOptions Options for web search. */ @JsonInclude(Include.NON_NULL) public record ChatCompletionRequest(// @formatter:off @@ -864,7 +882,8 @@ public record ChatCompletionRequest(// @formatter:off @JsonProperty("tool_choice") Object toolChoice, @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls, @JsonProperty("user") String user, - @JsonProperty("reasoning_effort") String reasoningEffort) { + @JsonProperty("reasoning_effort") String reasoningEffort, + @JsonProperty("web_search_options") WebSearchOptions webSearchOptions) { /** * Shortcut constructor for a chat completion request with the given messages, model and temperature. @@ -876,7 +895,7 @@ public record ChatCompletionRequest(// @formatter:off public ChatCompletionRequest(List messages, String model, Double temperature) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, temperature, null, - null, null, null, null, null); + null, null, null, null, null, null); } /** @@ -890,7 +909,7 @@ public ChatCompletionRequest(List messages, String model, this(messages, model, null, null, null, null, null, null, null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null, null, null, null, stream, null, null, null, - null, null, null, null, null); + null, null, null, null, null, null); } /** @@ -905,7 +924,7 @@ public ChatCompletionRequest(List messages, String model, public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, stream, null, temperature, null, - null, null, null, null, null); + null, null, null, null, null, null); } /** @@ -921,7 +940,7 @@ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, 0.8, null, - tools, toolChoice, null, null, null); + tools, toolChoice, null, null, null, null); } /** @@ -934,7 +953,7 @@ public ChatCompletionRequest(List messages, String model, public ChatCompletionRequest(List messages, Boolean stream) { this(messages, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, stream, null, null, null, - null, null, null, null, null); + null, null, null, null, null, null); } /** @@ -947,7 +966,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) { return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP, - this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort); + this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions); } /** @@ -1029,6 +1048,61 @@ public record StreamOptions( public static StreamOptions INCLUDE_USAGE = new StreamOptions(true); } + + /** + * This tool searches the web for relevant results to use in a response. + * + * @param searchContextSize + * @param userLocation + */ + @JsonInclude(Include.NON_NULL) + public record WebSearchOptions(@JsonProperty("search_context_size") SearchContextSize searchContextSize, + @JsonProperty("user_location") UserLocation userLocation) { + + /** + * High level guidance for the amount of context window space to use for the + * search. One of low, medium, or high. medium is the default. + */ + public enum SearchContextSize { + + /** + * Low context size. + */ + @JsonProperty("low") + LOW, + + /** + * Medium context size. This is the default. + */ + @JsonProperty("medium") + MEDIUM, + + /** + * High context size. + */ + @JsonProperty("high") + HIGH + + } + + /** + * Approximate location parameters for the search. + * + * @param type The type of location approximation. Always "approximate". + * @param approximate The approximate location details. + */ + @JsonInclude(Include.NON_NULL) + public record UserLocation(@JsonProperty("type") String type, + @JsonProperty("approximate") Approximate approximate) { + + @JsonInclude(Include.NON_NULL) + public record Approximate(@JsonProperty("city") String city, @JsonProperty("country") String country, + @JsonProperty("region") String region, @JsonProperty("timezone") String timezone) { + } + } + + } + } // @formatter:on /** @@ -1047,19 +1121,22 @@ public record StreamOptions( * Applicable only for {@link Role#ASSISTANT} role and null otherwise. * @param refusal The refusal message by the assistant. Applicable only for * {@link Role#ASSISTANT} role and null otherwise. - * @param audioOutput Audio response from the model. >>>>>>> bdb66e577 (OpenAI - - * Support audio input modality) + * @param audioOutput Audio response from the model. + * @param annotations Annotations for the message, when applicable, as when using the + * web search tool. */ - @JsonInclude(Include.NON_NULL) - public record ChatCompletionMessage(// @formatter:off + @JsonInclude(JsonInclude.Include.NON_NULL) + public record ChatCompletionMessage( + // @formatter:off @JsonProperty("content") Object rawContent, @JsonProperty("role") Role role, @JsonProperty("name") String name, @JsonProperty("tool_call_id") String toolCallId, - @JsonProperty("tool_calls") - @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List toolCalls, + @JsonProperty("tool_calls") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List toolCalls, @JsonProperty("refusal") String refusal, - @JsonProperty("audio") AudioOutput audioOutput) { // @formatter:on + @JsonProperty("audio") AudioOutput audioOutput, + @JsonProperty("annotations") List annotations + ) { // @formatter:on /** * Create a chat completion message with the given content and role. All other @@ -1068,8 +1145,7 @@ public record ChatCompletionMessage(// @formatter:off * @param role The role of the author of this message. */ public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null, null, null); - + this(content, role, null, null, null, null, null, null); } /** @@ -1246,6 +1322,29 @@ public record AudioOutput(// @formatter:off @JsonProperty("transcript") String transcript ) { // @formatter:on } + + /** + * Represents an annotation within a message, specifically for URL citations. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record Annotation(@JsonProperty("type") String type, + @JsonProperty("url_citation") UrlCitation urlCitation) { + /** + * A URL citation when using web search. + * + * @param endIndex The index of the last character of the URL citation in the + * message. + * @param startIndex The index of the first character of the URL citation in + * the message. + * @param title The title of the web resource. + * @param url The URL of the web resource. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public record UrlCitation(@JsonProperty("end_index") Integer endIndex, + @JsonProperty("start_index") Integer startIndex, @JsonProperty("title") String title, + @JsonProperty("url") String url) { + } + } } /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java index 17be23c4b41..e159d0362c9 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java @@ -40,6 +40,7 @@ * * @author Christian Tzolov * @author Thomas Vitale + * @author Alexandros Pappas * @since 0.8.1 */ public class OpenAiStreamFunctionCallingHelper { @@ -98,6 +99,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti String refusal = (current.refusal() != null ? current.refusal() : previous.refusal()); ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? current.audioOutput() : previous.audioOutput()); + List annotations = (current.annotations() != null ? current.annotations() + : previous.annotations()); List toolCalls = new ArrayList<>(); ToolCall lastPreviousTooCall = null; @@ -127,7 +130,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti toolCalls.add(lastPreviousTooCall); } } - return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput); + return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations); } private ToolCall merge(ToolCall previous, ToolCall current) { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index f843c0c7338..36c1eb84a7b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -75,7 +75,7 @@ void validateReasoningTokens() { "If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER); ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null, - null, null, null, "low"); + null, null, null, "low", null); ResponseEntity response = this.openAiApi.chatCompletionEntity(request); assertThat(response).isNotNull(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index 482fc115792..e655de46421 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -44,6 +44,7 @@ * * @author Christian Tzolov * @author Thomas Vitale + * @author Alexandros Pappas */ @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") public class OpenAiApiToolFunctionCallIT { @@ -129,7 +130,7 @@ public void toolFunctionCall() { // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, - functionName, toolCall.id(), null, null, null)); + functionName, toolCall.id(), null, null, null, null)); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java index b8d970d8eb0..3964c244856 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelIT.java @@ -633,6 +633,72 @@ void validateStoreAndMetadata() { assertThat(response).isNotNull(); } + @Test + void webSearchAnnotationsTest() { + UserMessage userMessage = new UserMessage("What is the latest news on the Mars rover?"); + + var promptOptions = OpenAiChatOptions.builder() + .model(OpenAiApi.ChatModel.GPT_4_O_SEARCH_PREVIEW.getValue()) + .webSearchOptions(new OpenAiApi.ChatCompletionRequest.WebSearchOptions( + OpenAiApi.ChatCompletionRequest.WebSearchOptions.SearchContextSize.MEDIUM, + new OpenAiApi.ChatCompletionRequest.WebSearchOptions.UserLocation("approximate", + new OpenAiApi.ChatCompletionRequest.WebSearchOptions.UserLocation.Approximate( + "San Francisco", "US", "California", "America/Los_Angeles")))) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getText()).isNotEmpty(); + + Object annotationsRaw = response.getResult().getOutput().getMetadata().get("annotations"); + assertThat(annotationsRaw).isNotNull().isInstanceOf(List.class); + + List annotations = (List) annotationsRaw; + assertThat(annotations).isNotEmpty(); + assertThat(annotations.get(0).type()).isEqualTo("url_citation"); + assertThat(annotations.get(0).urlCitation()).isNotNull(); + assertThat(annotations.get(0).urlCitation().url()).isNotEmpty(); + } + + @Test + void streamWebSearchAnnotationsTest() { + UserMessage userMessage = new UserMessage("What is the weather in San Francisco?"); + + var promptOptions = OpenAiChatOptions.builder() + .model(OpenAiApi.ChatModel.GPT_4_O_SEARCH_PREVIEW.getValue()) + .build(); + + Flux responseFlux = this.streamingChatModel + .stream(new Prompt(List.of(userMessage), promptOptions)); + + // Collect all streamed ChatResponses into a list. + List responses = responseFlux.collectList().block(); + assert responses != null; + assertThat(responses).isNotEmpty(); + ChatResponse lastResponse = responses.get(responses.size() - 1); + logger.info("Last Response: {}", lastResponse); + + Object annotationsRaw = lastResponse.getResult().getOutput().getMetadata().get("annotations"); + assertThat(annotationsRaw).isNotNull().isInstanceOf(List.class); + + List annotations = (List) annotationsRaw; + assertThat(annotations).isNotEmpty(); + assertThat(annotations.get(0).type()).isEqualTo("url_citation"); + assertThat(annotations.get(0).urlCitation()).isNotNull(); + assertThat(annotations.get(0).urlCitation().url()).isNotEmpty(); + + // For debugging, log fullContent + String fullContent = responses.stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("Full Content: {}", fullContent); + } + record ActorsFilmsRecord(String actor, List movies) { }