@@ -68,7 +68,6 @@ def diff(
6868 target : exp .Expression ,
6969 matchings : t .List [t .Tuple [exp .Expression , exp .Expression ]] | None = None ,
7070 delta_only : bool = False ,
71- copy : bool = True ,
7271 ** kwargs : t .Any ,
7372) -> t .List [Edit ]:
7473 """
@@ -97,9 +96,6 @@ def diff(
9796 Note: expression references in this list must refer to the same node objects that are
9897 referenced in the source / target trees.
9998 delta_only: excludes all `Keep` nodes from the diff.
100- copy: whether to copy the input expressions.
101- Note: if this is set to false, the caller must ensure that there are no shared references
102- in the two trees, otherwise the diffing algorithm may produce unexpected behavior.
10399 kwargs: additional arguments to pass to the ChangeDistiller instance.
104100
105101 Returns:
@@ -108,44 +104,54 @@ def diff(
108104 expression tree into the target one.
109105 """
110106 matchings = matchings or []
111- matching_ids = {id (n ) for pair in matchings for n in pair }
112107
113108 def compute_node_mappings (
114- original : exp .Expression , copy : exp .Expression
109+ old_nodes : tuple [ exp .Expression , ...], new_nodes : tuple [ exp .Expression , ...]
115110 ) -> t .Dict [int , exp .Expression ]:
116111 node_mapping = {}
117- for old_node , new_node in zip (
118- reversed (tuple (original .walk ())), reversed (tuple (copy .walk ()))
119- ):
120- # We cache the hash of each new node here to speed up equality comparisons. If the input
121- # trees aren't copied, these hashes will be evicted before returning the edit script.
112+ for old_node , new_node in zip (reversed (old_nodes ), reversed (new_nodes )):
122113 new_node ._hash = hash (new_node )
123-
124- old_node_id = id (old_node )
125- if old_node_id in matching_ids :
126- node_mapping [old_node_id ] = new_node
114+ node_mapping [id (old_node )] = new_node
127115
128116 return node_mapping
129117
118+ # If the source and target share any node objects, or either tree contains the same object more
119+ # than once, the parent / hierarchy links will be inaccurate and the algorithm won't work, so copy.
120+ source_nodes = tuple (source .walk ())
121+ target_nodes = tuple (target .walk ())
122+ source_ids = {id (n ) for n in source_nodes }
123+ target_ids = {id (n ) for n in target_nodes }
124+
125+ copy = (
126+ len (source_nodes ) != len (source_ids )
127+ or len (target_nodes ) != len (target_ids )
128+ or source_ids & target_ids
129+ )
130+
130131 source_copy = source .copy () if copy else source
131132 target_copy = target .copy () if copy else target
132133
133- node_mappings = {
134- ** compute_node_mappings (source , source_copy ),
135- ** compute_node_mappings (target , target_copy ),
136- }
137- matchings_copy = [(node_mappings [id (s )], node_mappings [id (t )]) for s , t in matchings ]
138-
139- edit_script = ChangeDistiller (** kwargs ).diff (
140- source_copy ,
141- target_copy ,
142- matchings = matchings_copy ,
143- delta_only = delta_only ,
144- )
145-
146- if not copy :
147- for node in chain (source .walk (), target .walk ()):
148- node ._hash = None
134+ try :
135+ # Cache each node's hash to speed up equality comparisons. If the input trees aren't copied,
136+ # the hashes are evicted in `finally`. NOTE(review): with copy and no matchings, the originals are hashed and never evicted - confirm.
137+ if copy and matchings :
138+ source_mapping = compute_node_mappings (source_nodes , tuple (source_copy .walk ()))
139+ target_mapping = compute_node_mappings (target_nodes , tuple (target_copy .walk ()))
140+ matchings = [(source_mapping [id (s )], target_mapping [id (t )]) for s , t in matchings ]
141+ else :
142+ for node in chain (reversed (source_nodes ), reversed (target_nodes )):
143+ node ._hash = hash (node )
144+
145+ edit_script = ChangeDistiller (** kwargs ).diff (
146+ source_copy ,
147+ target_copy ,
148+ matchings = matchings ,
149+ delta_only = delta_only ,
150+ )
151+ finally :
152+ if not copy :
153+ for node in chain (source_nodes , target_nodes ):
154+ node ._hash = None
149155
150156 return edit_script
151157
@@ -186,8 +192,6 @@ def diff(
186192 ) -> t .List [Edit ]:
187193 matchings = matchings or []
188194 pre_matched_nodes = {id (s ): id (t ) for s , t in matchings }
189- if len ({n for pair in pre_matched_nodes .items () for n in pair }) != 2 * len (matchings ):
190- raise ValueError ("Each node can be referenced at most once in the list of matchings" )
191195
192196 self ._source = source
193197 self ._target = target
0 commit comments