Skip to content

Commit 302acd4

Browse files
committed
feat: Add Amazon Bedrock adjustments after review
1 parent 731064c commit 302acd4

File tree

2 files changed

+266
-142
lines changed

2 files changed

+266
-142
lines changed

model-providers/bedrock/runtime/src/main/java/io/quarkiverse/langchain4j/bedrock/runtime/BedrockConverseStreamingChatModel.java

Lines changed: 127 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import java.util.ArrayList;
44
import java.util.List;
5+
import java.util.function.Consumer;
6+
import java.util.function.Function;
57

68
import dev.langchain4j.data.message.AiMessage;
79
import dev.langchain4j.data.message.ChatMessage;
@@ -24,7 +26,9 @@
2426
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
2527
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
2628
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
29+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
2730
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
31+
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
2832
import software.amazon.awssdk.services.bedrockruntime.model.Message;
2933
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
3034
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
@@ -75,52 +79,79 @@ public void chat(final ChatRequest chatRequest, final StreamingChatResponseHandl
7579

7680
var responseHandler = ConverseStreamResponseHandler.builder()
7781
.subscriber(ConverseStreamResponseHandler.Visitor.builder()
78-
.onMessageStop(context::setStopReason)
79-
.onMetadata(context::updateTokenUsage)
80-
.onContentBlockDelta(context::handleChunk)
82+
.onMessageStop(context.setStopReason())
83+
.onMetadata(context.updateTokenUsage())
84+
.onContentBlockDelta(context.handleChunk())
8185
.build())
82-
.onComplete(context::handleCompletion)
83-
.onError(handler::onError)
86+
.onComplete(context.handleCompletion())
87+
.onError(new Consumer<Throwable>() {
88+
@Override
89+
public void accept(final Throwable throwable) {
90+
handler.onError(throwable);
91+
}
92+
})
8493
.build();
8594

86-
client.converseStream(request -> request
87-
.modelId(modelId)
88-
.messages(toBedrockMessages(chatRequest))
89-
.inferenceConfig(c -> {
90-
c.maxTokens(config.maxTokens());
91-
config.temperature().ifPresent(d -> c.temperature((float) d));
92-
config.topP().ifPresent(d -> c.topP((float) d));
93-
}), responseHandler);
95+
client.converseStream(new Consumer<ConverseStreamRequest.Builder>() {
96+
@Override
97+
public void accept(final ConverseStreamRequest.Builder request) {
98+
request
99+
.modelId(modelId)
100+
.messages(toBedrockMessages(chatRequest))
101+
.inferenceConfig(createInferenceConfig());
102+
}
103+
}, responseHandler);
94104
}
95105

96-
private List<Message> toBedrockMessages(final ChatRequest chatRequest) {
97-
return chatRequest.messages().stream().map(this::messageTransformer).toList();
106+
private Consumer<InferenceConfiguration.Builder> createInferenceConfig() {
107+
return new Consumer<InferenceConfiguration.Builder>() {
108+
@Override
109+
public void accept(final InferenceConfiguration.Builder builder) {
110+
111+
builder.maxTokens(config.maxTokens());
112+
if (config.temperature().isPresent()) {
113+
builder.temperature((float) config.temperature().getAsDouble());
114+
}
115+
116+
if (config.topP().isPresent()) {
117+
builder.topP((float) config.topP().getAsDouble());
118+
}
119+
}
120+
};
98121
}
99122

100-
private Message messageTransformer(ChatMessage chatMessage) {
101-
102-
String msg;
103-
ConversationRole role;
104-
if (chatMessage instanceof SystemMessage sm) {
105-
msg = sm.text();
106-
role = ConversationRole.ASSISTANT;
107-
} else if (chatMessage instanceof UserMessage um) {
108-
msg = um.singleText();
109-
role = ConversationRole.USER;
110-
} else if (chatMessage instanceof AiMessage aim) {
111-
msg = aim.text();
112-
role = ConversationRole.USER;
113-
} else if (chatMessage instanceof ToolExecutionResultMessage term) {
114-
msg = term.text();
115-
role = ConversationRole.ASSISTANT;
116-
} else if (chatMessage instanceof CustomMessage cm) {
117-
msg = cm.text();
118-
role = ConversationRole.USER;
119-
} else {
120-
throw new IllegalArgumentException(chatMessage == null ? "null" : chatMessage.getClass().getName());
121-
}
123+
private List<Message> toBedrockMessages(final ChatRequest chatRequest) {
124+
return chatRequest.messages().stream().map(this.messageTransformer()).toList();
125+
}
122126

123-
return Message.builder().content(ContentBlock.fromText(msg)).role(role).build();
127+
private Function<ChatMessage, Message> messageTransformer() {
128+
return new Function<ChatMessage, Message>() {
129+
@Override
130+
public Message apply(final ChatMessage chatMessage) {
131+
String msg;
132+
ConversationRole role;
133+
if (chatMessage instanceof SystemMessage sm) {
134+
msg = sm.text();
135+
role = ConversationRole.ASSISTANT;
136+
} else if (chatMessage instanceof UserMessage um) {
137+
msg = um.singleText();
138+
role = ConversationRole.USER;
139+
} else if (chatMessage instanceof AiMessage aim) {
140+
msg = aim.text();
141+
role = ConversationRole.USER;
142+
} else if (chatMessage instanceof ToolExecutionResultMessage term) {
143+
msg = term.text();
144+
role = ConversationRole.ASSISTANT;
145+
} else if (chatMessage instanceof CustomMessage cm) {
146+
msg = cm.text();
147+
role = ConversationRole.USER;
148+
} else {
149+
throw new IllegalArgumentException(chatMessage == null ? "null" : chatMessage.getClass().getName());
150+
}
151+
152+
return Message.builder().content(ContentBlock.fromText(msg)).role(role).build();
153+
}
154+
};
124155
}
125156

126157
private class StreamContext {
@@ -133,42 +164,73 @@ public StreamContext(final StreamingChatResponseHandler handler) {
133164
this.handler = handler;
134165
}
135166

136-
public void setStopReason(MessageStopEvent messageStopEvent) {
137-
stopReason = mapFinishReason(messageStopEvent.stopReason());
167+
public Consumer<MessageStopEvent> setStopReason() {
168+
return new Consumer<MessageStopEvent>() {
169+
@Override
170+
public void accept(final MessageStopEvent messageStopEvent) {
171+
stopReason = mapFinishReason(messageStopEvent.stopReason());
172+
}
173+
};
138174
}
139175

140-
public void updateTokenUsage(ConverseStreamMetadataEvent metadataEvent) {
141-
final var usage = metadataEvent.usage();
142-
tokenUsage = tokenUsage.add(new TokenUsage(usage.inputTokens(), usage.outputTokens(), usage.totalTokens()));
176+
public Consumer<ConverseStreamMetadataEvent> updateTokenUsage() {
177+
return new Consumer<ConverseStreamMetadataEvent>() {
178+
@Override
179+
public void accept(final ConverseStreamMetadataEvent metadataEvent) {
180+
final var usage = metadataEvent.usage();
181+
tokenUsage = tokenUsage.add(
182+
new TokenUsage(usage.inputTokens(), usage.outputTokens(), usage.totalTokens()));
183+
}
184+
};
143185
}
144186

145-
public void handleChunk(ContentBlockDeltaEvent chunk) {
146-
var responseText = chunk.delta().text();
147-
finalCompletion.append(responseText);
148-
handler.onPartialResponse(responseText);
187+
public Consumer<ContentBlockDeltaEvent> handleChunk() {
188+
return new Consumer<ContentBlockDeltaEvent>() {
189+
@Override
190+
public void accept(final ContentBlockDeltaEvent chunk) {
191+
var responseText = chunk.delta().text();
192+
finalCompletion.append(responseText);
193+
handler.onPartialResponse(responseText);
194+
}
195+
};
149196
}
150197

151-
public void handleCompletion() {
152-
final var metadata = ChatResponseMetadata.builder().modelName(modelId).tokenUsage(tokenUsage)
153-
.finishReason(stopReason)
154-
.build();
155-
156-
var response = ChatResponse.builder()
157-
.aiMessage(new AiMessage(finalCompletion.toString()))
158-
.metadata(metadata)
159-
.build();
160-
161-
handler.onCompleteResponse(response);
198+
public Runnable handleCompletion() {
199+
return new Runnable() {
200+
@Override
201+
public void run() {
202+
final var metadata = ChatResponseMetadata.builder().modelName(modelId).tokenUsage(tokenUsage)
203+
.finishReason(stopReason)
204+
.build();
205+
206+
var response = ChatResponse.builder()
207+
.aiMessage(new AiMessage(finalCompletion.toString()))
208+
.metadata(metadata)
209+
.build();
210+
211+
handler.onCompleteResponse(response);
212+
}
213+
};
162214
}
163215

164216
private FinishReason mapFinishReason(final StopReason stopReason) {
165-
return switch (stopReason) {
166-
case END_TURN, STOP_SEQUENCE, GUARDRAIL_INTERVENED -> FinishReason.STOP;
167-
case TOOL_USE -> FinishReason.TOOL_EXECUTION;
168-
case MAX_TOKENS -> FinishReason.LENGTH;
169-
case CONTENT_FILTERED -> FinishReason.CONTENT_FILTER;
170-
case UNKNOWN_TO_SDK_VERSION -> FinishReason.OTHER;
171-
};
217+
if (stopReason == null) {
218+
return FinishReason.OTHER;
219+
}
220+
switch (stopReason) {
221+
case END_TURN:
222+
case STOP_SEQUENCE:
223+
case GUARDRAIL_INTERVENED:
224+
return FinishReason.STOP;
225+
case TOOL_USE:
226+
return FinishReason.TOOL_EXECUTION;
227+
case MAX_TOKENS:
228+
return FinishReason.LENGTH;
229+
case CONTENT_FILTERED:
230+
return FinishReason.CONTENT_FILTER;
231+
default:
232+
return FinishReason.OTHER;
233+
}
172234
}
173235
}
174236
}

0 commit comments

Comments
 (0)