Skip to content

Commit 051c6f0

Browse files
authored
Refactor!!: bundle multiple WHEN [NOT] MATCHED into a exp.WhenSequence (#4495)
* Refactor!!: bundle multiple WHEN [NOT] MATCHED into a exp.WhenSequence * Rename WhenSequence to Whens
1 parent 43975e4 commit 051c6f0

File tree

5 files changed

+29
-16
lines changed

5 files changed

+29
-16
lines changed

sqlglot/dialects/dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1547,7 +1547,7 @@ def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
15471547
if alias:
15481548
targets.add(normalize(alias.this))
15491549

1550-
for when in expression.expressions:
1550+
for when in expression.args["whens"].expressions:
15511551
# only remove the target names from the THEN clause
15521552
# theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
15531553
# ref: https://github.com/TobikoData/sqlmesh/issues/2934

sqlglot/expressions.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6682,16 +6682,22 @@ class Merge(DML):
66826682
"this": True,
66836683
"using": True,
66846684
"on": True,
6685-
"expressions": True,
6685+
"whens": True,
66866686
"with": False,
66876687
"returning": False,
66886688
}
66896689

66906690

6691-
class When(Func):
6691+
class When(Expression):
66926692
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
66936693

66946694

6695+
class Whens(Expression):
6696+
"""Wraps around one or more WHEN [NOT] MATCHED [...] clauses."""
6697+
6698+
arg_types = {"expressions": True}
6699+
6700+
66956701
# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html
66966702
# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16
66976703
class NextValueFor(Func):
@@ -7349,14 +7355,17 @@ def merge(
73497355
Returns:
73507356
Merge: The syntax tree for the MERGE statement.
73517357
"""
7358+
expressions = []
7359+
for when_expr in when_exprs:
7360+
expressions.extend(
7361+
maybe_parse(when_expr, dialect=dialect, copy=copy, into=Whens, **opts).expressions
7362+
)
7363+
73527364
merge = Merge(
73537365
this=maybe_parse(into, dialect=dialect, copy=copy, **opts),
73547366
using=maybe_parse(using, dialect=dialect, copy=copy, **opts),
73557367
on=maybe_parse(on, dialect=dialect, copy=copy, **opts),
7356-
expressions=[
7357-
maybe_parse(when_expr, dialect=dialect, copy=copy, into=When, **opts)
7358-
for when_expr in when_exprs
7359-
],
7368+
whens=Whens(expressions=expressions),
73607369
)
73617370
if returning:
73627371
merge = merge.returning(returning, dialect=dialect, copy=False, **opts)

sqlglot/generator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3696,6 +3696,9 @@ def when_sql(self, expression: exp.When) -> str:
36963696
then = self.sql(then_expression)
36973697
return f"WHEN {matched}{source}{condition} THEN {then}"
36983698

3699+
def whens_sql(self, expression: exp.Whens) -> str:
3700+
return self.expressions(expression, sep=" ", indent=False)
3701+
36993702
def merge_sql(self, expression: exp.Merge) -> str:
37003703
table = expression.this
37013704
table_alias = ""
@@ -3708,16 +3711,17 @@ def merge_sql(self, expression: exp.Merge) -> str:
37083711
this = self.sql(table)
37093712
using = f"USING {self.sql(expression, 'using')}"
37103713
on = f"ON {self.sql(expression, 'on')}"
3711-
expressions = self.expressions(expression, sep=" ", indent=False)
3714+
whens = self.sql(expression, "whens")
3715+
37123716
returning = self.sql(expression, "returning")
37133717
if returning:
3714-
expressions = f"{expressions}{returning}"
3718+
whens = f"{whens}{returning}"
37153719

37163720
sep = self.sep()
37173721

37183722
return self.prepend_ctes(
37193723
expression,
3720-
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{expressions}",
3724+
f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{whens}",
37213725
)
37223726

37233727
@unsupported_args("format")

sqlglot/parser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ class Parser(metaclass=_Parser):
778778
exp.Table: lambda self: self._parse_table_parts(),
779779
exp.TableAlias: lambda self: self._parse_table_alias(),
780780
exp.Tuple: lambda self: self._parse_value(),
781-
exp.When: lambda self: seq_get(self._parse_when_matched(), 0),
781+
exp.Whens: lambda self: self._parse_when_matched(),
782782
exp.Where: lambda self: self._parse_where(),
783783
exp.Window: lambda self: self._parse_named_window(),
784784
exp.With: lambda self: self._parse_with(),
@@ -7010,11 +7010,11 @@ def _parse_merge(self) -> exp.Merge:
70107010
this=target,
70117011
using=using,
70127012
on=on,
7013-
expressions=self._parse_when_matched(),
7013+
whens=self._parse_when_matched(),
70147014
returning=self._parse_returning(),
70157015
)
70167016

7017-
def _parse_when_matched(self) -> t.List[exp.When]:
7017+
def _parse_when_matched(self) -> exp.Whens:
70187018
whens = []
70197019

70207020
while self._match(TokenType.WHEN):
@@ -7063,7 +7063,7 @@ def _parse_when_matched(self) -> t.List[exp.When]:
70637063
then=then,
70647064
)
70657065
)
7066-
return whens
7066+
return self.expression(exp.Whens, expressions=whens)
70677067

70687068
def _parse_show(self) -> t.Optional[exp.Expression]:
70697069
parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)

tests/test_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def test_parse_into(self):
2727
self.assertIsInstance(
2828
parse_one(
2929
"WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)",
30-
into=exp.When,
30+
into=exp.Whens,
3131
),
32-
exp.When,
32+
exp.Whens,
3333
)
3434

3535
with self.assertRaises(ParseError) as ctx:

0 commit comments

Comments
 (0)