Skip to content

Commit faf6d41

Browse files
authored
fix: allow duplicate nodes in matchings (#4817)
1 parent 83e6a87 commit faf6d41

File tree

2 files changed

+56
-35
lines changed

2 files changed

+56
-35
lines changed

sqlglot/diff.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def diff(
6868
target: exp.Expression,
6969
matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None,
7070
delta_only: bool = False,
71-
copy: bool = True,
7271
**kwargs: t.Any,
7372
) -> t.List[Edit]:
7473
"""
@@ -97,9 +96,6 @@ def diff(
9796
Note: expression references in this list must refer to the same node objects that are
9897
referenced in the source / target trees.
9998
delta_only: excludes all `Keep` nodes from the diff.
100-
copy: whether to copy the input expressions.
101-
Note: if this is set to false, the caller must ensure that there are no shared references
102-
in the two trees, otherwise the diffing algorithm may produce unexpected behavior.
10399
kwargs: additional arguments to pass to the ChangeDistiller instance.
104100
105101
Returns:
@@ -108,44 +104,54 @@ def diff(
108104
expression tree into the target one.
109105
"""
110106
matchings = matchings or []
111-
matching_ids = {id(n) for pair in matchings for n in pair}
112107

113108
def compute_node_mappings(
114-
original: exp.Expression, copy: exp.Expression
109+
old_nodes: tuple[exp.Expression, ...], new_nodes: tuple[exp.Expression, ...]
115110
) -> t.Dict[int, exp.Expression]:
116111
node_mapping = {}
117-
for old_node, new_node in zip(
118-
reversed(tuple(original.walk())), reversed(tuple(copy.walk()))
119-
):
120-
# We cache the hash of each new node here to speed up equality comparisons. If the input
121-
# trees aren't copied, these hashes will be evicted before returning the edit script.
112+
for old_node, new_node in zip(reversed(old_nodes), reversed(new_nodes)):
122113
new_node._hash = hash(new_node)
123-
124-
old_node_id = id(old_node)
125-
if old_node_id in matching_ids:
126-
node_mapping[old_node_id] = new_node
114+
node_mapping[id(old_node)] = new_node
127115

128116
return node_mapping
129117

118+
# If the source and target share any node objects, there's an issue with the AST:
119+
# the algorithm won't work because the parent references and tree hierarchies will be inaccurate
120+
source_nodes = tuple(source.walk())
121+
target_nodes = tuple(target.walk())
122+
source_ids = {id(n) for n in source_nodes}
123+
target_ids = {id(n) for n in target_nodes}
124+
125+
copy = (
126+
len(source_nodes) != len(source_ids)
127+
or len(target_nodes) != len(target_ids)
128+
or source_ids & target_ids
129+
)
130+
130131
source_copy = source.copy() if copy else source
131132
target_copy = target.copy() if copy else target
132133

133-
node_mappings = {
134-
**compute_node_mappings(source, source_copy),
135-
**compute_node_mappings(target, target_copy),
136-
}
137-
matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings]
138-
139-
edit_script = ChangeDistiller(**kwargs).diff(
140-
source_copy,
141-
target_copy,
142-
matchings=matchings_copy,
143-
delta_only=delta_only,
144-
)
145-
146-
if not copy:
147-
for node in chain(source.walk(), target.walk()):
148-
node._hash = None
134+
try:
135+
# We cache the hash of each new node here to speed up equality comparisons. If the input
136+
# trees aren't copied, these hashes will be evicted before returning the edit script.
137+
if copy and matchings:
138+
source_mapping = compute_node_mappings(source_nodes, tuple(source_copy.walk()))
139+
target_mapping = compute_node_mappings(target_nodes, tuple(target_copy.walk()))
140+
matchings = [(source_mapping[id(s)], target_mapping[id(t)]) for s, t in matchings]
141+
else:
142+
for node in chain(reversed(source_nodes), reversed(target_nodes)):
143+
node._hash = hash(node)
144+
145+
edit_script = ChangeDistiller(**kwargs).diff(
146+
source_copy,
147+
target_copy,
148+
matchings=matchings,
149+
delta_only=delta_only,
150+
)
151+
finally:
152+
if not copy:
153+
for node in chain(source_nodes, target_nodes):
154+
node._hash = None
149155

150156
return edit_script
151157

@@ -186,8 +192,6 @@ def diff(
186192
) -> t.List[Edit]:
187193
matchings = matchings or []
188194
pre_matched_nodes = {id(s): id(t) for s, t in matchings}
189-
if len({n for pair in pre_matched_nodes.items() for n in pair}) != 2 * len(matchings):
190-
raise ValueError("Each node can be referenced at most once in the list of matchings")
191195

192196
self._source = source
193197
self._target = target

tests/test_diff.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,27 @@ def test_pre_matchings(self):
240240
],
241241
)
242242

243-
with self.assertRaises(ValueError):
243+
self._validate_delta_only(
244244
diff_delta_only(
245245
expr_src, expr_tgt, matchings=[(expr_src, expr_tgt), (expr_src, expr_tgt)]
246-
)
246+
),
247+
[
248+
Insert(expression=exp.Literal.number(2)),
249+
Insert(expression=exp.Literal.number(3)),
250+
Insert(expression=exp.Literal.number(4)),
251+
],
252+
)
253+
254+
expr_tgt.selects[0].replace(expr_src.selects[0])
255+
256+
self._validate_delta_only(
257+
diff_delta_only(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
258+
[
259+
Insert(expression=exp.Literal.number(2)),
260+
Insert(expression=exp.Literal.number(3)),
261+
Insert(expression=exp.Literal.number(4)),
262+
],
263+
)
247264

248265
def test_identifier(self):
249266
expr_src = parse_one("SELECT a FROM tbl")

0 commit comments

Comments
 (0)