Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
"role", choice.message().role() != null ? choice.message().role().name() : "",
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
return buildGeneration(choice, metadata, request);
}).toList();
// @formatter:on
Expand Down Expand Up @@ -316,8 +317,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
"role", roleMap.getOrDefault(id, ""),
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");

"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
return buildGeneration(choice, metadata, request);
}).toList();
// @formatter:on
Expand Down Expand Up @@ -580,7 +581,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {

}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand All @@ -590,7 +591,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
tr.id(), null, null, null))
tr.id(), null, null, null, null))
.toList();
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.WebSearchOptions;
import org.springframework.ai.openai.api.ResponseFormat;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -194,6 +195,11 @@ public class OpenAiChatOptions implements ToolCallingChatOptions {
*/
private @JsonProperty("reasoning_effort") String reasoningEffort;

/**
* This tool searches the web for relevant results to use in a response.
*/
private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions;

/**
* Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests.
*/
Expand Down Expand Up @@ -593,6 +599,14 @@ public void setReasoningEffort(String reasoningEffort) {
this.reasoningEffort = reasoningEffort;
}

/**
* Returns the web search options held by these chat options.
* @return the current value of the {@code webSearchOptions} field; {@code null} when
* nothing has been assigned to it
*/
public WebSearchOptions getWebSearchOptions() {
return this.webSearchOptions;
}

/**
* Replaces the web search options held by these chat options.
* @param webSearchOptions the new value for the {@code webSearchOptions} field; stored
* as given, without validation or copying
*/
public void setWebSearchOptions(WebSearchOptions webSearchOptions) {
this.webSearchOptions = webSearchOptions;
}

@Override
public OpenAiChatOptions copy() {
return OpenAiChatOptions.fromOptions(this);
Expand All @@ -605,7 +619,7 @@ public int hashCode() {
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders,
this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio,
this.store, this.metadata, this.reasoningEffort);
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions);
}

@Override
Expand Down Expand Up @@ -637,7 +651,8 @@ public boolean equals(Object o) {
&& Objects.equals(this.outputModalities, other.outputModalities)
&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
&& Objects.equals(this.metadata, other.metadata)
&& Objects.equals(this.reasoningEffort, other.reasoningEffort);
&& Objects.equals(this.reasoningEffort, other.reasoningEffort)
&& Objects.equals(this.webSearchOptions, other.webSearchOptions);
}

@Override
Expand Down Expand Up @@ -848,6 +863,11 @@ public Builder reasoningEffort(String reasoningEffort) {
return this;
}

/**
* Sets the web search options on the options instance being built.
* @param webSearchOptions the value assigned to the {@code webSearchOptions} field of
* the underlying options object; stored as given
* @return this builder, to allow call chaining
*/
public Builder webSearchOptions(WebSearchOptions webSearchOptions) {
this.options.webSearchOptions = webSearchOptions;
return this;
}

/**
* Finishes the build and hands back the configured options.
* @return the builder's own {@code options} instance (no copy is made here, so later
* calls on this builder act on the same object)
*/
public OpenAiChatOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,21 @@ public enum ChatModel implements ChatModelDescription {
* Context window: 4,096 tokens. Max output tokens: 4,096 tokens. Knowledge
* cutoff: September, 2021.
*/
GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct");
GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct"),

/**
* <b>GPT-4o Search Preview</b> is a specialized model for web search in Chat
* Completions. It is trained to understand and execute web search queries. See
* the web search guide for more information.
*/
GPT_4_O_SEARCH_PREVIEW("gpt-4o-search-preview"),

/**
* <b>GPT-4o mini Search Preview</b> is a specialized model for web search in Chat
* Completions. It is trained to understand and execute web search queries. See
* the web search guide for more information.
*/
GPT_4_O_MINI_SEARCH_PREVIEW("gpt-4o-mini-search-preview");

public final String value;

Expand Down Expand Up @@ -835,6 +849,10 @@ public enum OutputModality {
* @param parallelToolCalls If set to true, the model will call all functions in the
* tools list in parallel. Otherwise, the model will call the functions in the tools
* list in the order they are provided.
* @param reasoningEffort Constrains effort on reasoning for reasoning models.
* Currently supported values are low, medium, and high. Reducing reasoning effort can
* result in faster responses and fewer tokens used on reasoning in a response.
* @param webSearchOptions Options for web search.
*/
@JsonInclude(Include.NON_NULL)
public record ChatCompletionRequest(// @formatter:off
Expand Down Expand Up @@ -864,7 +882,8 @@ public record ChatCompletionRequest(// @formatter:off
@JsonProperty("tool_choice") Object toolChoice,
@JsonProperty("parallel_tool_calls") Boolean parallelToolCalls,
@JsonProperty("user") String user,
@JsonProperty("reasoning_effort") String reasoningEffort) {
@JsonProperty("reasoning_effort") String reasoningEffort,
@JsonProperty("web_search_options") WebSearchOptions webSearchOptions) {

/**
* Shortcut constructor for a chat completion request with the given messages, model and temperature.
Expand All @@ -876,7 +895,7 @@ public record ChatCompletionRequest(// @formatter:off
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
null, null, null, false, null, temperature, null,
null, null, null, null, null);
null, null, null, null, null, null);
}

/**
Expand All @@ -890,7 +909,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
this(messages, model, null, null, null, null, null, null,
null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
null, null, null, stream, null, null, null,
null, null, null, null, null);
null, null, null, null, null, null);
}

/**
Expand All @@ -905,7 +924,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
this(messages, model, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, temperature, null,
null, null, null, null, null);
null, null, null, null, null, null);
}

/**
Expand All @@ -921,7 +940,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
List<FunctionTool> tools, Object toolChoice) {
this(messages, model, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, false, null, 0.8, null,
tools, toolChoice, null, null, null);
tools, toolChoice, null, null, null, null);
}

/**
Expand All @@ -934,7 +953,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
this(messages, null, null, null, null, null, null, null, null, null, null,
null, null, null, null, null, null, null, stream, null, null, null,
null, null, null, null, null);
null, null, null, null, null, null);
}

/**
Expand All @@ -947,7 +966,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort);
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions);
}

/**
Expand Down Expand Up @@ -1029,6 +1048,61 @@ public record StreamOptions(

public static StreamOptions INCLUDE_USAGE = new StreamOptions(true);
}

/**
* This tool searches the web for relevant results to use in a response.
*
* Serialized with {@code Include.NON_NULL}, so components left {@code null} are
* omitted from the JSON payload.
*
* @param searchContextSize High level guidance for the amount of context window
* space to use for the search (JSON property {@code search_context_size}).
* @param userLocation Approximate location parameters for the search (JSON property
* {@code user_location}).
*/
@JsonInclude(Include.NON_NULL)
public record WebSearchOptions(@JsonProperty("search_context_size") SearchContextSize searchContextSize,
@JsonProperty("user_location") UserLocation userLocation) {

/**
* High level guidance for the amount of context window space to use for the
* search. One of low, medium, or high. medium is the default.
*/
public enum SearchContextSize {

/**
* Low context size.
*/
@JsonProperty("low")
LOW,

/**
* Medium context size. This is the default.
*/
@JsonProperty("medium")
MEDIUM,

/**
* High context size.
*/
@JsonProperty("high")
HIGH

}

/**
* Approximate location parameters for the search.
*
* @param type The type of location approximation. Always "approximate".
* @param approximate The approximate location details.
*/
@JsonInclude(Include.NON_NULL)
public record UserLocation(@JsonProperty("type") String type,
@JsonProperty("approximate") Approximate approximate) {

/**
* The approximate location details carried by a {@link UserLocation}. Components
* left {@code null} are omitted from the JSON payload.
*
* NOTE(review): the expected value formats (for example country codes or time
* zone identifiers) are not visible here - confirm against the OpenAI API
* reference.
*
* @param city Value of the {@code city} JSON property.
* @param country Value of the {@code country} JSON property.
* @param region Value of the {@code region} JSON property.
* @param timezone Value of the {@code timezone} JSON property.
*/
@JsonInclude(Include.NON_NULL)
public record Approximate(@JsonProperty("city") String city, @JsonProperty("country") String country,
@JsonProperty("region") String region, @JsonProperty("timezone") String timezone) {
}
}

}

} // @formatter:on

/**
Expand All @@ -1047,19 +1121,22 @@ public record StreamOptions(
* Applicable only for {@link Role#ASSISTANT} role and null otherwise.
* @param refusal The refusal message by the assistant. Applicable only for
* {@link Role#ASSISTANT} role and null otherwise.
* @param audioOutput Audio response from the model. >>>>>>> bdb66e577 (OpenAI -
* Support audio input modality)
* @param audioOutput Audio response from the model.
* @param annotations Annotations for the message, when applicable, as when using the
* web search tool.
*/
@JsonInclude(Include.NON_NULL)
public record ChatCompletionMessage(// @formatter:off
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ChatCompletionMessage(
// @formatter:off
@JsonProperty("content") Object rawContent,
@JsonProperty("role") Role role,
@JsonProperty("name") String name,
@JsonProperty("tool_call_id") String toolCallId,
@JsonProperty("tool_calls")
@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
@JsonProperty("tool_calls") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
@JsonProperty("refusal") String refusal,
@JsonProperty("audio") AudioOutput audioOutput) { // @formatter:on
@JsonProperty("audio") AudioOutput audioOutput,
@JsonProperty("annotations") List<Annotation> annotations
) { // @formatter:on

/**
* Create a chat completion message with the given content and role. All other
Expand All @@ -1068,8 +1145,7 @@ public record ChatCompletionMessage(// @formatter:off
* @param role The role of the author of this message.
*/
public ChatCompletionMessage(Object content, Role role) {
this(content, role, null, null, null, null, null);

this(content, role, null, null, null, null, null, null);
}

/**
Expand Down Expand Up @@ -1246,6 +1322,29 @@ public record AudioOutput(// @formatter:off
@JsonProperty("transcript") String transcript
) { // @formatter:on
}

/**
* Represents an annotation within a message, specifically for URL citations.
*
* @param type Value of the {@code type} JSON property. NOTE(review): presumably
* identifies the annotation kind (URL citation) - confirm against the OpenAI API
* reference.
* @param urlCitation The URL citation details (JSON property {@code url_citation}).
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public record Annotation(@JsonProperty("type") String type,
@JsonProperty("url_citation") UrlCitation urlCitation) {
/**
* A URL citation when using web search.
*
* @param endIndex The index of the last character of the URL citation in the
* message.
* @param startIndex The index of the first character of the URL citation in
* the message.
* @param title The title of the web resource.
* @param url The URL of the web resource.
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public record UrlCitation(@JsonProperty("end_index") Integer endIndex,
@JsonProperty("start_index") Integer startIndex, @JsonProperty("title") String title,
@JsonProperty("url") String url) {
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
* @since 0.8.1
*/
public class OpenAiStreamFunctionCallingHelper {
Expand Down Expand Up @@ -98,6 +99,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());
ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? current.audioOutput()
: previous.audioOutput());
List<ChatCompletionMessage.Annotation> annotations = (current.annotations() != null ? current.annotations()
: previous.annotations());

List<ToolCall> toolCalls = new ArrayList<>();
ToolCall lastPreviousTooCall = null;
Expand Down Expand Up @@ -127,7 +130,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
toolCalls.add(lastPreviousTooCall);
}
}
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput);
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations);
}

private ToolCall merge(ToolCall previous, ToolCall current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void validateReasoningTokens() {
"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null,
null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null,
null, null, null, "low");
null, null, null, "low", null);
ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

assertThat(response).isNotNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
*/
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiApiToolFunctionCallIT {
Expand Down Expand Up @@ -129,7 +130,7 @@ public void toolFunctionCall() {

// extend conversation with function response.
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
functionName, toolCall.id(), null, null, null));
functionName, toolCall.id(), null, null, null, null));
}
}

Expand Down
Loading