diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java index 776cba66d58..37b82a0fcee 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java @@ -53,6 +53,9 @@ public class BedrockChatOptions implements ToolCallingChatOptions { @JsonProperty("presencePenalty") private Double presencePenalty; + @JsonIgnore + private Map requestParameters = new HashMap<>(); + @JsonProperty("stopSequences") private List stopSequences; @@ -87,6 +90,7 @@ public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) { .frequencyPenalty(fromOptions.getFrequencyPenalty()) .maxTokens(fromOptions.getMaxTokens()) .presencePenalty(fromOptions.getPresencePenalty()) + .requestParameters(new HashMap<>(fromOptions.getRequestParameters())) .stopSequences( fromOptions.getStopSequences() != null ? new ArrayList<>(fromOptions.getStopSequences()) : null) .temperature(fromOptions.getTemperature()) @@ -126,6 +130,12 @@ public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } + public Map getRequestParameters() { return this.requestParameters; } + + public void setRequestParameters(Map requestParameters) { + this.requestParameters = requestParameters; + } + @Override public Double getPresencePenalty() { return this.presencePenalty; @@ -241,6 +251,7 @@ public boolean equals(Object o) { return Objects.equals(this.model, that.model) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) && Objects.equals(this.maxTokens, that.maxTokens) && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.requestParameters, that.requestParameters) && Objects.equals(this.stopSequences, that.stopSequences) && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topK, that.topK) && Objects.equals(this.topP, that.topP) && Objects.equals(this.toolCallbacks, that.toolCallbacks) @@ -250,8 +261,9 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty, this.stopSequences, - this.temperature, this.topK, this.topP, this.toolCallbacks, this.toolNames, this.toolContext, + return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty, + this.requestParameters, this.stopSequences, this.temperature, this.topK, this.topP, + this.toolCallbacks, this.toolNames, this.toolContext, this.internalToolExecutionEnabled); } @@ -279,6 +291,11 @@ public Builder presencePenalty(Double presencePenalty) { return this; } + public Builder requestParameters(Map requestParameters) { + this.options.requestParameters = requestParameters; + return this; + } + public Builder stopSequences(List stopSequences) { this.options.stopSequences = stopSequences; return this; diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 071e77a78cb..a089676dda3 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -425,6 +425,8 @@ else if (message.getMessageType() == MessageType.TOOL) { Document additionalModelRequestFields = ConverseApiUtils .getChatOptionsAdditionalModelRequestFields(this.defaultOptions, prompt.getOptions()); + Map requestMetadata = ConverseApiUtils.getRequestMetadata(prompt.getUserMessage().getMetadata()); + return ConverseRequest.builder() .modelId(updatedRuntimeOptions.getModel()) .inferenceConfig(inferenceConfiguration) @@ -432,6 +434,7 @@ else if (message.getMessageType() == MessageType.TOOL) { .system(systemMessages) .additionalModelRequestFields(additionalModelRequestFields) .toolConfig(toolConfiguration) + .requestMetadata(requestMetadata) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index d58fdbad8cf..dd49bb5e006 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -384,6 +384,26 @@ else if (value instanceof Map mapValue) { } } + @SuppressWarnings("unchecked") + public static Map getRequestMetadata(Map metadata) { + + if (metadata.isEmpty()) { + return Map.of(); + } + + Map result = new HashMap<>(); + for (Map.Entry entry : metadata.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + + if (key != null && value != null) { + result.put(key, value.toString()); + } + } + + return result; + } + private static Document convertMapToDocument(Map value) { Map attr = value.entrySet() .stream() diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java index aed48c1a3b5..e34ba9a84a0 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java @@ -37,6 +37,7 @@ void testBuilderWithAllFields() { .frequencyPenalty(0.0) .maxTokens(100) .presencePenalty(0.0) + .requestParameters(Map.of("requestId", "1234")) .stopSequences(List.of("stop1", "stop2")) .temperature(0.7) .topP(0.8) @@ -44,9 +45,9 @@ void testBuilderWithAllFields() { .build(); assertThat(options) - .extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "stopSequences", "temperature", - "topP", "topK") - .containsExactly("test-model", 0.0, 100, 0.0, List.of("stop1", "stop2"), 0.7, 0.8, 50); + .extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "requestParameters", + "stopSequences", "temperature", "topP", "topK") + .containsExactly("test-model", 0.0, 100, 0.0, Map.of("requestId", "1234"), List.of("stop1", "stop2"), 0.7, 0.8, 50); } @Test diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 6980b6b2859..520337e0fd8 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -70,7 +70,9 @@ void call() { .system(s -> s.text(this.systemTextResource) .param("name", "Bob") .param("voice", "pirate")) - .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") + .user(u -> u.text("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") + .param("requestId", "1234") + ) .call() .chatResponse(); // @formatter:on