1616
1717import json
1818from os import environ
19- from typing import Any , Callable , Dict , Union
19+ from typing import Any , Callable , Dict , Iterator , Sequence , Union
2020
2121from botocore .eventstream import EventStream , EventStreamError
2222from wrapt import ObjectProxy
3434_StreamErrorCallableT = Callable [[Exception ], None ]
3535
3636
37+ def _decode_tool_use (tool_use ):
38+ # input get sent encoded in json
39+ if "input" in tool_use :
40+ try :
41+ tool_use ["input" ] = json .loads (tool_use ["input" ])
42+ except json .JSONDecodeError :
43+ pass
44+ return tool_use
45+
46+
3747# pylint: disable=abstract-method
3848class ConverseStreamWrapper (ObjectProxy ):
3949 """Wrapper for botocore.eventstream.EventStream"""
@@ -52,7 +62,7 @@ def __init__(
5262 # {"usage": {"inputTokens": 0, "outputTokens": 0}, "stopReason": "finish", "output": {"message": {"role": "", "content": [{"text": ""}]}
5363 self ._response = {}
5464 self ._message = None
55- self ._content_buf = ""
65+ self ._content_block = {}
5666 self ._record_message = False
5767
5868 def __iter__ (self ):
@@ -65,23 +75,40 @@ def __iter__(self):
6575 raise
6676
6777 def _process_event (self , event ):
78+ # pylint: disable=too-many-branches
6879 if "messageStart" in event :
6980 # {'messageStart': {'role': 'assistant'}}
7081 if event ["messageStart" ].get ("role" ) == "assistant" :
7182 self ._record_message = True
7283 self ._message = {"role" : "assistant" , "content" : []}
7384 return
7485
86+ if "contentBlockStart" in event :
87+ # {'contentBlockStart': {'start': {'toolUse': {'toolUseId': 'id', 'name': 'func_name'}}, 'contentBlockIndex': 1}}
88+ start = event ["contentBlockStart" ].get ("start" , {})
89+ if "toolUse" in start :
90+ tool_use = _decode_tool_use (start ["toolUse" ])
91+ self ._content_block = {"toolUse" : tool_use }
92+ return
93+
7594 if "contentBlockDelta" in event :
7695 # {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
96+ # {'contentBlockDelta': {'delta': {'toolUse': {'input': '{"location":"Seattle"}'}}, 'contentBlockIndex': 1}}
7797 if self ._record_message :
78- self ._content_buf += (
79- event ["contentBlockDelta" ].get ("delta" , {}).get ("text" , "" )
80- )
98+ delta = event ["contentBlockDelta" ].get ("delta" , {})
99+ if "text" in delta :
100+ self ._content_block .setdefault ("text" , "" )
101+ self ._content_block ["text" ] += delta ["text" ]
102+ elif "toolUse" in delta :
103+ tool_use = _decode_tool_use (delta ["toolUse" ])
104+ self ._content_block ["toolUse" ].update (tool_use )
81105 return
82106
83107 if "contentBlockStop" in event :
84108 # {'contentBlockStop': {'contentBlockIndex': 0}}
109+ if self ._record_message :
110+ self ._message ["content" ].append (self ._content_block )
111+ self ._content_block = {}
85112 return
86113
87114 if "messageStop" in event :
@@ -90,8 +117,6 @@ def _process_event(self, event):
90117 self ._response ["stopReason" ] = stop_reason
91118
92119 if self ._record_message :
93- self ._message ["content" ].append ({"text" : self ._content_buf })
94- self ._content_buf = ""
95120 self ._response ["output" ] = {"message" : self ._message }
96121 self ._record_message = False
97122 self ._message = None
@@ -134,7 +159,8 @@ def __init__(
134159 # {"usage": {"inputTokens": 0, "outputTokens": 0}, "stopReason": "finish", "output": {"message": {"role": "", "content": [{"text": ""}]}
135160 self ._response = {}
136161 self ._message = None
137- self ._content_buf = ""
162+ self ._content_block = {}
163+ self ._tool_json_input_buf = ""
138164 self ._record_message = False
139165
140166 def __iter__ (self ):
@@ -189,6 +215,8 @@ def _process_amazon_titan_chunk(self, chunk):
189215 self ._stream_done_callback (self ._response )
190216
191217 def _process_amazon_nova_chunk (self , chunk ):
218+ # pylint: disable=too-many-branches
219+ # TODO: handle tool calls!
192220 if "messageStart" in chunk :
193221 # {'messageStart': {'role': 'assistant'}}
194222 if chunk ["messageStart" ].get ("role" ) == "assistant" :
@@ -199,9 +227,10 @@ def _process_amazon_nova_chunk(self, chunk):
199227 if "contentBlockDelta" in chunk :
200228 # {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
201229 if self ._record_message :
202- self ._content_buf += (
203- chunk ["contentBlockDelta" ].get ("delta" , {}).get ("text" , "" )
204- )
230+ delta = chunk ["contentBlockDelta" ].get ("delta" , {})
231+ if "text" in delta :
232+ self ._content_block .setdefault ("text" , "" )
233+ self ._content_block ["text" ] += delta ["text" ]
205234 return
206235
207236 if "contentBlockStop" in chunk :
@@ -214,8 +243,8 @@ def _process_amazon_nova_chunk(self, chunk):
214243 self ._response ["stopReason" ] = stop_reason
215244
216245 if self ._record_message :
217- self ._message ["content" ].append ({ "text" : self ._content_buf } )
218- self ._content_buf = ""
246+ self ._message ["content" ].append (self ._content_block )
247+ self ._content_block = {}
219248 self ._response ["output" ] = {"message" : self ._message }
220249 self ._record_message = False
221250 self ._message = None
@@ -235,7 +264,7 @@ def _process_amazon_nova_chunk(self, chunk):
235264 return
236265
237266 def _process_anthropic_claude_chunk (self , chunk ):
238- # pylint: disable=too-many-return-statements
267+ # pylint: disable=too-many-return-statements,too-many-branches
239268 if not (message_type := chunk .get ("type" )):
240269 return
241270
@@ -252,18 +281,35 @@ def _process_anthropic_claude_chunk(self, chunk):
252281
253282 if message_type == "content_block_start" :
254283 # {'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}
284+ # {'type': 'content_block_start', 'index': 1, 'content_block': {'type': 'tool_use', 'id': 'id', 'name': 'func_name', 'input': {}}}
285+ if self ._record_message :
286+ block = chunk .get ("content_block" , {})
287+ if block .get ("type" ) == "text" :
288+ self ._content_block = block
289+ elif block .get ("type" ) == "tool_use" :
290+ self ._content_block = block
255291 return
256292
257293 if message_type == "content_block_delta" :
258294 # {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Here'}}
295+ # {'type': 'content_block_delta', 'index': 1, 'delta': {'type': 'input_json_delta', 'partial_json': ''}}
259296 if self ._record_message :
260- self ._content_buf += chunk .get ("delta" , {}).get ("text" , "" )
297+ delta = chunk .get ("delta" , {})
298+ if delta .get ("type" ) == "text_delta" :
299+ self ._content_block ["text" ] += delta .get ("text" , "" )
300+ elif delta .get ("type" ) == "input_json_delta" :
301+ self ._tool_json_input_buf += delta .get ("partial_json" , "" )
261302 return
262303
263304 if message_type == "content_block_stop" :
264305 # {'type': 'content_block_stop', 'index': 0}
265- self ._message ["content" ].append ({"text" : self ._content_buf })
266- self ._content_buf = ""
306+ if self ._tool_json_input_buf :
307+ self ._content_block ["input" ] = self ._tool_json_input_buf
308+ self ._message ["content" ].append (
309+ _decode_tool_use (self ._content_block )
310+ )
311+ self ._content_block = {}
312+ self ._tool_json_input_buf = ""
267313 return
268314
269315 if message_type == "message_delta" :
@@ -297,16 +343,102 @@ def genai_capture_message_content() -> bool:
297343 return capture_content .lower () == "true"
298344
299345
300- def message_to_event (message : dict [str , Any ], capture_content : bool ) -> Event :
346+ def extract_tool_calls (
347+ message : dict [str , Any ], capture_content : bool
348+ ) -> Sequence [Dict [str , Any ]] | None :
349+ content = message .get ("content" )
350+ if not content :
351+ return None
352+
353+ tool_uses = [item ["toolUse" ] for item in content if "toolUse" in item ]
354+ if not tool_uses :
355+ tool_uses = [
356+ item for item in content if item .get ("type" ) == "tool_use"
357+ ]
358+ tool_id_key = "id"
359+ else :
360+ tool_id_key = "toolUseId"
361+
362+ if not tool_uses :
363+ return None
364+
365+ tool_calls = []
366+ for tool_use in tool_uses :
367+ tool_call = {"type" : "function" }
368+ if call_id := tool_use .get (tool_id_key ):
369+ tool_call ["id" ] = call_id
370+
371+ if function_name := tool_use .get ("name" ):
372+ tool_call ["function" ] = {"name" : function_name }
373+
374+ if (function_input := tool_use .get ("input" )) and capture_content :
375+ tool_call .setdefault ("function" , {})
376+ tool_call ["function" ]["arguments" ] = function_input
377+
378+ tool_calls .append (tool_call )
379+ return tool_calls
380+
381+
def extract_tool_results(
    message: dict[str, Any], capture_content: bool
) -> Iterator[Dict[str, Any]]:
    """Yield one event body per tool result found in a message.

    Two content layouts are recognised:

    * Converse: ``{"toolResult": {"toolUseId": ..., "content": ...}}``
    * InvokeModel anthropic.claude:
      ``{"type": "tool_result", "tool_use_id": ..., "content": ...}``

    Args:
        message: message dict; only its ``content`` key is read.
        capture_content: when false, the tool result content is omitted.

    Yields:
        A dict with the tool call ``id`` (when present) and, if
        ``capture_content`` is true and the result has content, ``content``.
    """
    content = message.get("content")
    # A plain-string content holds text only; iterating it would yield
    # characters, which have no ``.get`` and cannot describe a tool result.
    if not content or isinstance(content, str):
        return

    # Converse format
    tool_results = [
        item["toolResult"] for item in content if "toolResult" in item
    ]
    tool_id_key = "toolUseId"
    if not tool_results:
        # InvokeModel anthropic.claude format
        tool_results = [
            item for item in content if item.get("type") == "tool_result"
        ]
        tool_id_key = "tool_use_id"

    # A message carrying tool results produces one body for each
    # tool result part of its content.
    for tool_result in tool_results:
        body = {}
        if tool_id := tool_result.get(tool_id_key):
            body["id"] = tool_id
        tool_content = tool_result.get("content")
        if capture_content and tool_content:
            body["content"] = tool_content

        yield body
416+
417+
418+ def message_to_event (
419+ message : dict [str , Any ], capture_content : bool
420+ ) -> Iterator [Event ]:
301421 attributes = {GEN_AI_SYSTEM : GenAiSystemValues .AWS_BEDROCK .value }
302422 role = message .get ("role" )
303423 content = message .get ("content" )
304424
305425 body = {}
306426 if capture_content and content :
307427 body ["content" ] = content
308-
309- return Event (
428+ if role == "assistant" :
429+ # the assistant message contains both tool calls and model thinking content
430+ if tool_calls := extract_tool_calls (message , capture_content ):
431+ body ["tool_calls" ] = tool_calls
432+ elif role == "user" :
433+ # in case of tool calls we send one tool event for tool call and one for the user event
434+ for tool_body in extract_tool_results (message , capture_content ):
435+ yield Event (
436+ name = "gen_ai.tool.message" ,
437+ attributes = attributes ,
438+ body = tool_body ,
439+ )
440+
441+ yield Event (
310442 name = f"gen_ai.{ role } .message" ,
311443 attributes = attributes ,
312444 body = body if body else None ,
@@ -331,8 +463,12 @@ def from_converse(
331463 else :
332464 # amazon.titan does not serialize the role
333465 message = {}
334- if capture_content :
466+
467+ if tool_calls := extract_tool_calls (orig_message , capture_content ):
468+ message ["tool_calls" ] = tool_calls
469+ elif capture_content :
335470 message ["content" ] = orig_message ["content" ]
471+
336472 return cls (message , response ["stopReason" ], index = 0 )
337473
338474 @classmethod
@@ -350,14 +486,11 @@ def from_invoke_amazon_titan(
350486 def from_invoke_anthropic_claude (
351487 cls , response : dict [str , Any ], capture_content : bool
352488 ) -> _Choice :
353- if capture_content :
354- message = {
355- "content" : response ["content" ],
356- "role" : response ["role" ],
357- }
358- else :
359- message = {"role" : response ["role" ]}
360-
489+ message = {"role" : response ["role" ]}
490+ if tool_calls := extract_tool_calls (response , capture_content ):
491+ message ["tool_calls" ] = tool_calls
492+ elif capture_content :
493+ message ["content" ] = response ["content" ]
361494 return cls (message , response ["stop_reason" ], index = 0 )
362495
363496 def _to_body_dict (self ) -> dict [str , Any ]:
0 commit comments