-
Notifications
You must be signed in to change notification settings - Fork 2k
[GH-3723] Vertex AI Gemini logprobs support #3724
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,6 +64,20 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { | |
| */ | ||
| private @JsonProperty("temperature") Double temperature; | ||
|
|
||
| /** | ||
| * Optional. Enable returning the log probabilities of the top candidate tokens at each generation step. | ||
| * The model's chosen token might not be the same as the top candidate token at each step. | ||
| * Specify the number of candidates to return by using an integer value in the range of 1-20. | ||
| * Should not be set unless responseLogprobs is set to true. | ||
| */ | ||
| private @JsonProperty("logprobs") Integer logprobs; | ||
|
|
||
| /** | ||
| * Optional. If true, returns the log probabilities of the tokens that were chosen by the model at each step. | ||
| * By default, this parameter is set to false. | ||
| */ | ||
| private @JsonProperty("responseLogprobs") boolean responseLogprobs; | ||
|
|
||
| /** | ||
| * Optional. If specified, nucleus sampling will be used. | ||
| */ | ||
|
|
@@ -162,6 +176,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr | |
| options.setSafetySettings(fromOptions.getSafetySettings()); | ||
| options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); | ||
| options.setToolContext(fromOptions.getToolContext()); | ||
| options.setLogprobs(fromOptions.getLogprobs()); | ||
| options.setResponseLogprobs(fromOptions.getResponseLogprobs()); | ||
| return options; | ||
| } | ||
|
|
||
|
|
@@ -183,6 +199,10 @@ public void setTemperature(Double temperature) { | |
| this.temperature = temperature; | ||
| } | ||
|
|
||
| public void setResponseLogprobs(boolean responseLogprobs) { | ||
| this.responseLogprobs = responseLogprobs; | ||
| } | ||
|
|
||
| @Override | ||
| public Double getTopP() { | ||
| return this.topP; | ||
|
|
@@ -326,6 +346,18 @@ public void setToolContext(Map<String, Object> toolContext) { | |
| this.toolContext = toolContext; | ||
| } | ||
|
|
||
| public Integer getLogprobs() { | ||
| return logprobs; | ||
| } | ||
|
|
||
| public void setLogprobs(Integer logprobs) { | ||
| this.logprobs = logprobs; | ||
| } | ||
|
|
||
| public boolean getResponseLogprobs() { | ||
| return responseLogprobs; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object o) { | ||
| if (this == o) { | ||
|
|
@@ -346,15 +378,17 @@ public boolean equals(Object o) { | |
| && Objects.equals(this.toolNames, that.toolNames) | ||
| && Objects.equals(this.safetySettings, that.safetySettings) | ||
| && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) | ||
| && Objects.equals(this.toolContext, that.toolContext); | ||
| && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.logprobs, that.logprobs) | ||
| && Objects.equals(this.responseLogprobs, that.responseLogprobs); | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, | ||
| this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType, | ||
| this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings, | ||
| this.internalToolExecutionEnabled, this.toolContext); | ||
| this.internalToolExecutionEnabled, this.toolContext, this.toolContext, this.logprobs, | ||
|
||
| this.responseLogprobs); | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -365,7 +399,8 @@ public String toString() { | |
| + this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\'' | ||
| + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks | ||
| + ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval | ||
| + ", safetySettings=" + this.safetySettings + '}'; | ||
| + ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs + ", responseLogprobs=" | ||
| + this.responseLogprobs + '}'; | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -488,6 +523,16 @@ public Builder toolContext(Map<String, Object> toolContext) { | |
| return this; | ||
| } | ||
|
|
||
| public Builder logprobs(Integer logprobs) { | ||
| this.options.setLogprobs(logprobs); | ||
| return this; | ||
| } | ||
|
|
||
| public Builder responseLogprobs(Boolean responseLogprobs) { | ||
| this.options.setResponseLogprobs(responseLogprobs); | ||
| return this; | ||
| } | ||
|
|
||
| public VertexAiGeminiChatOptions build() { | ||
| return this.options; | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| package org.springframework.ai.vertexai.gemini.api; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a license header. |
||
|
|
||
| import java.util.List; | ||
|
|
||
| public class VertexAiGeminiApi { | ||
|
|
||
| public record LogProbs(Double avgLogprobs, List<TopContent> topCandidates, | ||
| List<LogProbs.Content> chosenCandidates) { | ||
| public record Content(String token, Float logprob, Integer id) { | ||
| } | ||
|
|
||
| public record TopContent(List<Content> candidates) { | ||
| } | ||
| } | ||
|
|
||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please move this line between lines 87 and 88.