Skip to content

Commit 29bd6d0

Browse files
Stuart LoxtonStuartLox
authored andcommitted
[Enhancement] - Add requestMetadata to converse request
Signed-off-by: Stuart Loxton <[email protected]> Signed-off-by: Stuart Loxton <[email protected]>
1 parent c7f7b68 commit 29bd6d0

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ public class BedrockChatOptions implements ToolCallingChatOptions {
5353
@JsonProperty("presencePenalty")
5454
private Double presencePenalty;
5555

56+
@JsonIgnore
57+
private Map<String, String> requestParameters = new HashMap<>();
58+
5659
@JsonProperty("stopSequences")
5760
private List<String> stopSequences;
5861

@@ -87,6 +90,7 @@ public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) {
8790
.frequencyPenalty(fromOptions.getFrequencyPenalty())
8891
.maxTokens(fromOptions.getMaxTokens())
8992
.presencePenalty(fromOptions.getPresencePenalty())
93+
.requestParameters(new HashMap<>(fromOptions.getRequestParameters()))
9094
.stopSequences(
9195
fromOptions.getStopSequences() != null ? new ArrayList<>(fromOptions.getStopSequences()) : null)
9296
.temperature(fromOptions.getTemperature())
@@ -126,6 +130,12 @@ public void setMaxTokens(Integer maxTokens) {
126130
this.maxTokens = maxTokens;
127131
}
128132

133+
public Map<String, String> getRequestParameters() { return this.requestParameters; }
134+
135+
public void setRequestParameters(Map<String, String> requestParameters) {
136+
this.requestParameters = requestParameters;
137+
}
138+
129139
@Override
130140
public Double getPresencePenalty() {
131141
return this.presencePenalty;
@@ -279,6 +289,11 @@ public Builder presencePenalty(Double presencePenalty) {
279289
return this;
280290
}
281291

292+
public Builder requestParameters(Map<String, String> requestParameters) {
293+
this.options.requestParameters = requestParameters;
294+
return this;
295+
}
296+
282297
public Builder stopSequences(List<String> stopSequences) {
283298
this.options.stopSequences = stopSequences;
284299
return this;

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,7 @@
2222
import java.net.URL;
2323
import java.net.URLConnection;
2424
import java.time.Duration;
25-
import java.util.ArrayList;
26-
import java.util.Base64;
27-
import java.util.List;
28-
import java.util.Map;
29-
import java.util.Set;
25+
import java.util.*;
3026

3127
import io.micrometer.observation.Observation;
3228
import io.micrometer.observation.ObservationRegistry;
@@ -425,13 +421,16 @@ else if (message.getMessageType() == MessageType.TOOL) {
425421
Document additionalModelRequestFields = ConverseApiUtils
426422
.getChatOptionsAdditionalModelRequestFields(this.defaultOptions, prompt.getOptions());
427423

424+
HashMap<String, String> requestMetadata = new HashMap<>();
425+
428426
return ConverseRequest.builder()
429427
.modelId(updatedRuntimeOptions.getModel())
430428
.inferenceConfig(inferenceConfiguration)
431429
.messages(instructionMessages)
432430
.system(systemMessages)
433431
.additionalModelRequestFields(additionalModelRequestFields)
434432
.toolConfig(toolConfiguration)
433+
.requestMetadata(requestMetadata)
435434
.build();
436435
}
437436

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,17 @@ void testBuilderWithAllFields() {
3737
.frequencyPenalty(0.0)
3838
.maxTokens(100)
3939
.presencePenalty(0.0)
40+
.requestParameters(Map.of("requestId", "1234"))
4041
.stopSequences(List.of("stop1", "stop2"))
4142
.temperature(0.7)
4243
.topP(0.8)
4344
.topK(50)
4445
.build();
4546

4647
assertThat(options)
47-
.extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "stopSequences", "temperature",
48-
"topP", "topK")
49-
.containsExactly("test-model", 0.0, 100, 0.0, List.of("stop1", "stop2"), 0.7, 0.8, 50);
48+
.extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "requestParameters",
49+
"stopSequences", "temperature", "topP", "topK")
50+
.containsExactly("test-model", 0.0, 100, 0.0, Map.of("requestId", "1234"), List.of("stop1", "stop2"), 0.7, 0.8, 50);
5051
}
5152

5253
@Test

0 commit comments

Comments
 (0)