1616
1717import json
1818from os import environ
19- from typing import Any , Callable , Dict , Union
19+ from typing import Any , Callable , Dict , Iterator , Sequence , Union
2020
2121from botocore .eventstream import EventStream , EventStreamError
2222from wrapt import ObjectProxy
3434_StreamErrorCallableT = Callable [[Exception ], None ]
3535
3636
37+ def _decode_tool_use (tool_use ):
38+ # input get sent encoded in json
39+ if "input" in tool_use :
40+ try :
41+ tool_use ["input" ] = json .loads (tool_use ["input" ])
42+ except json .JSONDecodeError :
43+ pass
44+ return tool_use
45+
46+
3747# pylint: disable=abstract-method
3848class ConverseStreamWrapper (ObjectProxy ):
3949 """Wrapper for botocore.eventstream.EventStream"""
@@ -52,7 +62,7 @@ def __init__(
5262 # {"usage": {"inputTokens": 0, "outputTokens": 0}, "stopReason": "finish", "output": {"message": {"role": "", "content": [{"text": ""}]}
5363 self ._response = {}
5464 self ._message = None
55- self ._content_buf = ""
65+ self ._content_block = {}
5666 self ._record_message = False
5767
5868 def __iter__ (self ):
@@ -72,16 +82,32 @@ def _process_event(self, event):
7282 self ._message = {"role" : "assistant" , "content" : []}
7383 return
7484
85+ if "contentBlockStart" in event :
86+ # {'contentBlockStart': {'start': {'toolUse': {'toolUseId': 'id', 'name': 'func_name'}}, 'contentBlockIndex': 1}}
87+ start = event ["contentBlockStart" ].get ("start" , {})
88+ if "toolUse" in start :
89+ tool_use = _decode_tool_use (start ["toolUse" ])
90+ self ._content_block = {"toolUse" : tool_use }
91+ return
92+
7593 if "contentBlockDelta" in event :
7694 # {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
95+ # {'contentBlockDelta': {'delta': {'toolUse': {'input': '{"location":"Seattle"}'}}, 'contentBlockIndex': 1}}
7796 if self ._record_message :
78- self ._content_buf += (
79- event ["contentBlockDelta" ].get ("delta" , {}).get ("text" , "" )
80- )
97+ delta = event ["contentBlockDelta" ].get ("delta" , {})
98+ if "text" in delta :
99+ self ._content_block .setdefault ("text" , "" )
100+ self ._content_block ["text" ] += delta ["text" ]
101+ elif "toolUse" in delta :
102+ tool_use = _decode_tool_use (delta ["toolUse" ])
103+ self ._content_block ["toolUse" ].update (tool_use )
81104 return
82105
83106 if "contentBlockStop" in event :
84107 # {'contentBlockStop': {'contentBlockIndex': 0}}
108+ if self ._record_message :
109+ self ._message ["content" ].append (self ._content_block )
110+ self ._content_block = {}
85111 return
86112
87113 if "messageStop" in event :
@@ -90,8 +116,6 @@ def _process_event(self, event):
90116 self ._response ["stopReason" ] = stop_reason
91117
92118 if self ._record_message :
93- self ._message ["content" ].append ({"text" : self ._content_buf })
94- self ._content_buf = ""
95119 self ._response ["output" ] = {"message" : self ._message }
96120 self ._record_message = False
97121 self ._message = None
@@ -134,7 +158,8 @@ def __init__(
134158 # {"usage": {"inputTokens": 0, "outputTokens": 0}, "stopReason": "finish", "output": {"message": {"role": "", "content": [{"text": ""}]}
135159 self ._response = {}
136160 self ._message = None
137- self ._content_buf = ""
161+ self ._content_block = {}
162+ self ._tool_json_input_buf = ""
138163 self ._record_message = False
139164
140165 def __iter__ (self ):
@@ -189,6 +214,7 @@ def _process_amazon_titan_chunk(self, chunk):
189214 self ._stream_done_callback (self ._response )
190215
191216 def _process_amazon_nova_chunk (self , chunk ):
217+ # TODO: handle tool calls!
192218 if "messageStart" in chunk :
193219 # {'messageStart': {'role': 'assistant'}}
194220 if chunk ["messageStart" ].get ("role" ) == "assistant" :
@@ -199,9 +225,10 @@ def _process_amazon_nova_chunk(self, chunk):
199225 if "contentBlockDelta" in chunk :
200226 # {'contentBlockDelta': {'delta': {'text': "Hello"}, 'contentBlockIndex': 0}}
201227 if self ._record_message :
202- self ._content_buf += (
203- chunk ["contentBlockDelta" ].get ("delta" , {}).get ("text" , "" )
204- )
228+ delta = chunk ["contentBlockDelta" ].get ("delta" , {})
229+ if "text" in delta :
230+ self ._content_block .setdefault ("text" , "" )
231+ self ._content_block ["text" ] += delta ["text" ]
205232 return
206233
207234 if "contentBlockStop" in chunk :
@@ -214,8 +241,8 @@ def _process_amazon_nova_chunk(self, chunk):
214241 self ._response ["stopReason" ] = stop_reason
215242
216243 if self ._record_message :
217- self ._message ["content" ].append ({ "text" : self ._content_buf } )
218- self ._content_buf = ""
244+ self ._message ["content" ].append (self ._content_block )
245+ self ._content_block = {}
219246 self ._response ["output" ] = {"message" : self ._message }
220247 self ._record_message = False
221248 self ._message = None
@@ -252,18 +279,35 @@ def _process_anthropic_claude_chunk(self, chunk):
252279
253280 if message_type == "content_block_start" :
254281 # {'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}
282+ # {'type': 'content_block_start', 'index': 1, 'content_block': {'type': 'tool_use', 'id': 'id', 'name': 'func_name', 'input': {}}}
283+ if self ._record_message :
284+ block = chunk .get ("content_block" , {})
285+ if block .get ("type" ) == "text" :
286+ self ._content_block = block
287+ elif block .get ("type" ) == "tool_use" :
288+ self ._content_block = block
255289 return
256290
257291 if message_type == "content_block_delta" :
258292 # {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Here'}}
293+ # {'type': 'content_block_delta', 'index': 1, 'delta': {'type': 'input_json_delta', 'partial_json': ''}}
259294 if self ._record_message :
260- self ._content_buf += chunk .get ("delta" , {}).get ("text" , "" )
295+ delta = chunk .get ("delta" , {})
296+ if delta .get ("type" ) == "text_delta" :
297+ self ._content_block ["text" ] += delta .get ("text" , "" )
298+ elif delta .get ("type" ) == "input_json_delta" :
299+ self ._tool_json_input_buf += delta .get ("partial_json" , "" )
261300 return
262301
263302 if message_type == "content_block_stop" :
264303 # {'type': 'content_block_stop', 'index': 0}
265- self ._message ["content" ].append ({"text" : self ._content_buf })
266- self ._content_buf = ""
304+ if self ._tool_json_input_buf :
305+ self ._content_block ["input" ] = self ._tool_json_input_buf
306+ self ._message ["content" ].append (
307+ _decode_tool_use (self ._content_block )
308+ )
309+ self ._content_block = {}
310+ self ._tool_json_input_buf = ""
267311 return
268312
269313 if message_type == "message_delta" :
@@ -297,16 +341,102 @@ def genai_capture_message_content() -> bool:
297341 return capture_content .lower () == "true"
298342
299343
300- def message_to_event (message : dict [str , Any ], capture_content : bool ) -> Event :
344+ def extract_tool_calls (
345+ message : dict [str , Any ], capture_content : bool
346+ ) -> Sequence [Dict [str , Any ]]:
347+ content = message .get ("content" )
348+ if not content :
349+ return
350+
351+ tool_uses = [item ["toolUse" ] for item in content if "toolUse" in item ]
352+ if not tool_uses :
353+ tool_uses = [
354+ item for item in content if item .get ("type" ) == "tool_use"
355+ ]
356+ tool_id_key = "id"
357+ else :
358+ tool_id_key = "toolUseId"
359+
360+ if not tool_uses :
361+ return
362+
363+ tool_calls = []
364+ for tool_use in tool_uses :
365+ tool_call = {"type" : "function" }
366+ if call_id := tool_use .get (tool_id_key ):
367+ tool_call ["id" ] = call_id
368+
369+ if function_name := tool_use .get ("name" ):
370+ tool_call ["function" ] = {"name" : function_name }
371+
372+ if (function_input := tool_use .get ("input" )) and capture_content :
373+ tool_call .setdefault ("function" , {})
374+ tool_call ["function" ]["arguments" ] = function_input
375+
376+ tool_calls .append (tool_call )
377+ return tool_calls
378+
379+
380+ def extract_tool_results (
381+ message : dict [str , Any ], capture_content : bool
382+ ) -> Iterator [Dict [str , Any ]]:
383+ content = message .get ("content" )
384+ if not content :
385+ return
386+
387+ # Converse format
388+ tool_results = [
389+ item ["toolResult" ] for item in content if "toolResult" in item
390+ ]
391+ # InvokeModel anthropic.claude format
392+ if not tool_results :
393+ tool_results = [
394+ item for item in content if item .get ("type" ) == "tool_result"
395+ ]
396+ tool_id_key = "tool_use_id"
397+ else :
398+ tool_id_key = "toolUseId"
399+
400+ if not tool_results :
401+ return
402+
403+ # if we have a user message with toolResult keys we need to send
404+ # one tool event for each part of the content
405+ for tool_result in tool_results :
406+ body = {}
407+ if tool_id := tool_result .get (tool_id_key ):
408+ body ["id" ] = tool_id
409+ tool_content = tool_result .get ("content" )
410+ if capture_content and tool_content :
411+ body ["content" ] = tool_content
412+
413+ yield body
414+
415+
416+ def message_to_event (
417+ message : dict [str , Any ], capture_content : bool
418+ ) -> Iterator [Event ]:
301419 attributes = {GEN_AI_SYSTEM : GenAiSystemValues .AWS_BEDROCK .value }
302420 role = message .get ("role" )
303421 content = message .get ("content" )
304422
305423 body = {}
306424 if capture_content and content :
307425 body ["content" ] = content
308-
309- return Event (
426+ if role == "assistant" :
427+ # the assistant message contains both tool calls and model thinking content
428+ if tool_calls := extract_tool_calls (message , capture_content ):
429+ body ["tool_calls" ] = tool_calls
430+ elif role == "user" :
431+ # in case of tool calls we send one tool event for tool call and one for the user event
432+ for tool_body in extract_tool_results (message , capture_content ):
433+ yield Event (
434+ name = "gen_ai.tool.message" ,
435+ attributes = attributes ,
436+ body = tool_body ,
437+ )
438+
439+ yield Event (
310440 name = f"gen_ai.{ role } .message" ,
311441 attributes = attributes ,
312442 body = body if body else None ,
@@ -331,8 +461,12 @@ def from_converse(
331461 else :
332462 # amazon.titan does not serialize the role
333463 message = {}
334- if capture_content :
464+
465+ if tool_calls := extract_tool_calls (orig_message , capture_content ):
466+ message ["tool_calls" ] = tool_calls
467+ elif capture_content :
335468 message ["content" ] = orig_message ["content" ]
469+
336470 return cls (message , response ["stopReason" ], index = 0 )
337471
338472 @classmethod
@@ -350,14 +484,11 @@ def from_invoke_amazon_titan(
350484 def from_invoke_anthropic_claude (
351485 cls , response : dict [str , Any ], capture_content : bool
352486 ) -> _Choice :
353- if capture_content :
354- message = {
355- "content" : response ["content" ],
356- "role" : response ["role" ],
357- }
358- else :
359- message = {"role" : response ["role" ]}
360-
487+ message = {"role" : response ["role" ]}
488+ if tool_calls := extract_tool_calls (response , capture_content ):
489+ message ["tool_calls" ] = tool_calls
490+ elif capture_content :
491+ message ["content" ] = response ["content" ]
361492 return cls (message , response ["stop_reason" ], index = 0 )
362493
363494 def _to_body_dict (self ) -> dict [str , Any ]:
0 commit comments