22
33import java .util .ArrayList ;
44import java .util .List ;
5+ import java .util .function .Consumer ;
6+ import java .util .function .Function ;
57
68import dev .langchain4j .data .message .AiMessage ;
79import dev .langchain4j .data .message .ChatMessage ;
2426import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDeltaEvent ;
2527import software .amazon .awssdk .services .bedrockruntime .model .ConversationRole ;
2628import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamMetadataEvent ;
29+ import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamRequest ;
2730import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamResponseHandler ;
31+ import software .amazon .awssdk .services .bedrockruntime .model .InferenceConfiguration ;
2832import software .amazon .awssdk .services .bedrockruntime .model .Message ;
2933import software .amazon .awssdk .services .bedrockruntime .model .MessageStopEvent ;
3034import software .amazon .awssdk .services .bedrockruntime .model .StopReason ;
@@ -75,52 +79,79 @@ public void chat(final ChatRequest chatRequest, final StreamingChatResponseHandl
7579
7680 var responseHandler = ConverseStreamResponseHandler .builder ()
7781 .subscriber (ConverseStreamResponseHandler .Visitor .builder ()
78- .onMessageStop (context :: setStopReason )
79- .onMetadata (context :: updateTokenUsage )
80- .onContentBlockDelta (context :: handleChunk )
82+ .onMessageStop (context . setStopReason () )
83+ .onMetadata (context . updateTokenUsage () )
84+ .onContentBlockDelta (context . handleChunk () )
8185 .build ())
82- .onComplete (context ::handleCompletion )
83- .onError (handler ::onError )
86+ .onComplete (context .handleCompletion ())
87+ .onError (new Consumer <Throwable >() {
88+ @ Override
89+ public void accept (final Throwable throwable ) {
90+ handler .onError (throwable );
91+ }
92+ })
8493 .build ();
8594
86- client .converseStream (request -> request
87- .modelId (modelId )
88- .messages (toBedrockMessages (chatRequest ))
89- .inferenceConfig (c -> {
90- c .maxTokens (config .maxTokens ());
91- config .temperature ().ifPresent (d -> c .temperature ((float ) d ));
92- config .topP ().ifPresent (d -> c .topP ((float ) d ));
93- }), responseHandler );
95+ client .converseStream (new Consumer <ConverseStreamRequest .Builder >() {
96+ @ Override
97+ public void accept (final ConverseStreamRequest .Builder request ) {
98+ request
99+ .modelId (modelId )
100+ .messages (toBedrockMessages (chatRequest ))
101+ .inferenceConfig (createInferenceConfig ());
102+ }
103+ }, responseHandler );
94104 }
95105
96- private List <Message > toBedrockMessages (final ChatRequest chatRequest ) {
97- return chatRequest .messages ().stream ().map (this ::messageTransformer ).toList ();
106+ private Consumer <InferenceConfiguration .Builder > createInferenceConfig () {
107+ return new Consumer <InferenceConfiguration .Builder >() {
108+ @ Override
109+ public void accept (final InferenceConfiguration .Builder builder ) {
110+
111+ builder .maxTokens (config .maxTokens ());
112+ if (config .temperature ().isPresent ()) {
113+ builder .temperature ((float ) config .temperature ().getAsDouble ());
114+ }
115+
116+ if (config .topP ().isPresent ()) {
117+ builder .topP ((float ) config .topP ().getAsDouble ());
118+ }
119+ }
120+ };
98121 }
99122
100- private Message messageTransformer (ChatMessage chatMessage ) {
101-
102- String msg ;
103- ConversationRole role ;
104- if (chatMessage instanceof SystemMessage sm ) {
105- msg = sm .text ();
106- role = ConversationRole .ASSISTANT ;
107- } else if (chatMessage instanceof UserMessage um ) {
108- msg = um .singleText ();
109- role = ConversationRole .USER ;
110- } else if (chatMessage instanceof AiMessage aim ) {
111- msg = aim .text ();
112- role = ConversationRole .USER ;
113- } else if (chatMessage instanceof ToolExecutionResultMessage term ) {
114- msg = term .text ();
115- role = ConversationRole .ASSISTANT ;
116- } else if (chatMessage instanceof CustomMessage cm ) {
117- msg = cm .text ();
118- role = ConversationRole .USER ;
119- } else {
120- throw new IllegalArgumentException (chatMessage == null ? "null" : chatMessage .getClass ().getName ());
121- }
123+ private List <Message > toBedrockMessages (final ChatRequest chatRequest ) {
124+ return chatRequest .messages ().stream ().map (this .messageTransformer ()).toList ();
125+ }
122126
123- return Message .builder ().content (ContentBlock .fromText (msg )).role (role ).build ();
127+ private Function <ChatMessage , Message > messageTransformer () {
128+ return new Function <ChatMessage , Message >() {
129+ @ Override
130+ public Message apply (final ChatMessage chatMessage ) {
131+ String msg ;
132+ ConversationRole role ;
133+ if (chatMessage instanceof SystemMessage sm ) {
134+ msg = sm .text ();
135+ role = ConversationRole .ASSISTANT ;
136+ } else if (chatMessage instanceof UserMessage um ) {
137+ msg = um .singleText ();
138+ role = ConversationRole .USER ;
139+ } else if (chatMessage instanceof AiMessage aim ) {
140+ msg = aim .text ();
141+ role = ConversationRole .USER ;
142+ } else if (chatMessage instanceof ToolExecutionResultMessage term ) {
143+ msg = term .text ();
144+ role = ConversationRole .ASSISTANT ;
145+ } else if (chatMessage instanceof CustomMessage cm ) {
146+ msg = cm .text ();
147+ role = ConversationRole .USER ;
148+ } else {
149+ throw new IllegalArgumentException (chatMessage == null ? "null" : chatMessage .getClass ().getName ());
150+ }
151+
152+ return Message .builder ().content (ContentBlock .fromText (msg )).role (role ).build ();
153+ }
154+ };
124155 }
125156
126157 private class StreamContext {
@@ -133,42 +164,73 @@ public StreamContext(final StreamingChatResponseHandler handler) {
133164 this .handler = handler ;
134165 }
135166
136- public void setStopReason (MessageStopEvent messageStopEvent ) {
137- stopReason = mapFinishReason (messageStopEvent .stopReason ());
167+ public Consumer <MessageStopEvent > setStopReason () {
168+ return new Consumer <MessageStopEvent >() {
169+ @ Override
170+ public void accept (final MessageStopEvent messageStopEvent ) {
171+ stopReason = mapFinishReason (messageStopEvent .stopReason ());
172+ }
173+ };
138174 }
139175
140- public void updateTokenUsage (ConverseStreamMetadataEvent metadataEvent ) {
141- final var usage = metadataEvent .usage ();
142- tokenUsage = tokenUsage .add (new TokenUsage (usage .inputTokens (), usage .outputTokens (), usage .totalTokens ()));
176+ public Consumer <ConverseStreamMetadataEvent > updateTokenUsage () {
177+ return new Consumer <ConverseStreamMetadataEvent >() {
178+ @ Override
179+ public void accept (final ConverseStreamMetadataEvent metadataEvent ) {
180+ final var usage = metadataEvent .usage ();
181+ tokenUsage = tokenUsage .add (
182+ new TokenUsage (usage .inputTokens (), usage .outputTokens (), usage .totalTokens ()));
183+ }
184+ };
143185 }
144186
145- public void handleChunk (ContentBlockDeltaEvent chunk ) {
146- var responseText = chunk .delta ().text ();
147- finalCompletion .append (responseText );
148- handler .onPartialResponse (responseText );
187+ public Consumer <ContentBlockDeltaEvent > handleChunk () {
188+ return new Consumer <ContentBlockDeltaEvent >() {
189+ @ Override
190+ public void accept (final ContentBlockDeltaEvent chunk ) {
191+ var responseText = chunk .delta ().text ();
192+ finalCompletion .append (responseText );
193+ handler .onPartialResponse (responseText );
194+ }
195+ };
149196 }
150197
151- public void handleCompletion () {
152- final var metadata = ChatResponseMetadata .builder ().modelName (modelId ).tokenUsage (tokenUsage )
153- .finishReason (stopReason )
154- .build ();
155-
156- var response = ChatResponse .builder ()
157- .aiMessage (new AiMessage (finalCompletion .toString ()))
158- .metadata (metadata )
159- .build ();
160-
161- handler .onCompleteResponse (response );
198+ public Runnable handleCompletion () {
199+ return new Runnable () {
200+ @ Override
201+ public void run () {
202+ final var metadata = ChatResponseMetadata .builder ().modelName (modelId ).tokenUsage (tokenUsage )
203+ .finishReason (stopReason )
204+ .build ();
205+
206+ var response = ChatResponse .builder ()
207+ .aiMessage (new AiMessage (finalCompletion .toString ()))
208+ .metadata (metadata )
209+ .build ();
210+
211+ handler .onCompleteResponse (response );
212+ }
213+ };
162214 }
163215
164216 private FinishReason mapFinishReason (final StopReason stopReason ) {
165- return switch (stopReason ) {
166- case END_TURN , STOP_SEQUENCE , GUARDRAIL_INTERVENED -> FinishReason .STOP ;
167- case TOOL_USE -> FinishReason .TOOL_EXECUTION ;
168- case MAX_TOKENS -> FinishReason .LENGTH ;
169- case CONTENT_FILTERED -> FinishReason .CONTENT_FILTER ;
170- case UNKNOWN_TO_SDK_VERSION -> FinishReason .OTHER ;
171- };
217+ if (stopReason == null ) {
218+ return FinishReason .OTHER ;
219+ }
220+ switch (stopReason ) {
221+ case END_TURN :
222+ case STOP_SEQUENCE :
223+ case GUARDRAIL_INTERVENED :
224+ return FinishReason .STOP ;
225+ case TOOL_USE :
226+ return FinishReason .TOOL_EXECUTION ;
227+ case MAX_TOKENS :
228+ return FinishReason .LENGTH ;
229+ case CONTENT_FILTERED :
230+ return FinishReason .CONTENT_FILTER ;
231+ default :
232+ return FinishReason .OTHER ;
233+ }
172234 }
173235 }
174236}
0 commit comments