@@ -530,47 +530,63 @@ def compare_results(
530530 return results
531531
532532
533- def merge_overlapping_debug_handles (intermediate_outputs : Dict [DebugHandle , Any ]):
533+ def merge_overlapping_debug_handles (
534+ intermediate_outputs : Dict [DebugHandle , Any ]
535+ ) -> Dict [DebugHandle , Any ]:
534536 """
535- Merge overlapping debug handles int a single key
537+ Merges overlapping debug handles into a single key in the dict.
538+ For each debug handle, this function checks for overlaps with existing keys in the merged dict.
539+ If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements.
540+ The value associated with the merged key is determined by the debug handle with the highest last element.
536541 """
542+
537543 if len (intermediate_outputs ) == 0 :
538- return
539- # Extract and normalize into (start, end, val)
540- intervals = [(min (key ), max (key ), val ) for key , val in intermediate_outputs .items ()]
541- intervals .sort (key = lambda x : x [0 ])
542-
543- # Merge overlapping debug_hanldes, picking the last value
544- merged_intermediate_outputs = []
545- cur_start , cur_end , cur_val = intervals [0 ]
546- for start , end , val in intervals [1 :]:
547- if start <= cur_end : # Overlaps
548- if end > cur_end : # Extend if this one goes further
549- cur_end , cur_val = end , val
544+ return {}
550545
551- else :
552- merged_intermediate_outputs .append ((cur_start , cur_end , cur_val ))
553- cur_start , cur_end , cur_val = start , end , val
554- merged_intermediate_outputs .append ((cur_start , cur_end , cur_val ))
546+ merged : Dict [DebugHandle , Any ] = {}
547+
548+ for debug_handle , value in intermediate_outputs .items ():
549+ debug_handle_set = set (debug_handle )
550+ curr_debug_handle , last_value = debug_handle , value
551+
552+ # collect any existing keys that overlap with the current key
553+ to_remove = []
554+ for existing_debug_handle , existing_value in merged .items ():
555+ if debug_handle_set .intersection (set (existing_debug_handle )):
556+ # abosrb their ints
557+ debug_handle_set |= set (existing_debug_handle )
558+ if existing_debug_handle [- 1 ] > curr_debug_handle [- 1 ]:
559+ curr_debug_handle , last_value = (
560+ existing_debug_handle ,
561+ existing_value ,
562+ )
563+ to_remove .append (existing_debug_handle )
555564
556- # Clear original one and populate with merged keys (value will point to the same object)
557- intermediate_outputs .clear ()
558- for start , end , val in merged_intermediate_outputs :
559- intermediate_outputs [tuple (range (start , end + 1 ))] = val
565+ # remove all the keys that overlap with the current key
566+ for debug_handle in to_remove :
567+ merged .pop (debug_handle )
568+
569+ # add the current key to the merged one
570+ new_debug_handle = tuple (sorted (debug_handle_set ))
571+ merged [new_debug_handle ] = last_value
572+
573+ # Sort the merged debug handles in ascending order based on their last element
574+ # TODO: Consider adding more logic to align the order with the execution order
575+ return dict (sorted (merged .items (), key = lambda item : item [0 ][- 1 ]))
560576
561577
562578def _debug_handles_have_overlap (
563- aot_debug_hanlde : DebugHandle , runtime_debug_handle : DebugHandle
579+ debug_handle : DebugHandle , target_debug_handle : DebugHandle
564580) -> bool :
565581 """
566- Check if the AOT debug handle and the runtime debug handle have any overlap.
582+ Check if the debug handle and the target runtime debug handle have any overlap.
567583 """
568- aot_set = set (aot_debug_hanlde )
569- runtime_set = set (runtime_debug_handle )
584+ aot_set = set (debug_handle )
585+ runtime_set = set (target_debug_handle )
570586 return len (aot_set .intersection (runtime_set )) > 0
571587
572588
573- def _combine_debug_hanldes (debug_handles : List [DebugHandle ]) -> DebugHandle :
589+ def _combine_debug_handles (debug_handles : List [DebugHandle ]) -> DebugHandle :
574590 """Combine multiple debug handles into one debug handle"""
575591 combined_debug_handles_set = set ()
576592 for debug_handle in debug_handles :
@@ -584,7 +600,7 @@ def _combine_overlapped_intermediate_outputs(
584600 """Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
585601 debug_handles = [debug_handle for debug_handle , _ in nodes ]
586602 outputs = [output for _ , output in nodes ]
587- combined_debug_handle = _combine_debug_hanldes (debug_handles )
603+ combined_debug_handle = _combine_debug_handles (debug_handles )
588604 output = outputs [- 1 ] # Pick the last one
589605 return combined_debug_handle , output
590606
@@ -673,8 +689,10 @@ def map_runtime_aot_intermediate_outputs(
673689 from runtime intermediate output to AOT intermediate output
674690 """
675691 # Merge overlapping debug handles
676- merge_overlapping_debug_handles (aot_intermediate_outputs )
677- merge_overlapping_debug_handles (runtime_intermediate_outputs )
692+ aot_intermediate_outputs = merge_overlapping_debug_handles (aot_intermediate_outputs )
693+ runtime_intermediate_outputs = merge_overlapping_debug_handles (
694+ runtime_intermediate_outputs
695+ )
678696
679697 # Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
680698 nodes , edges = _create_debug_handle_overlap_graph (
0 commit comments