Skip to content

Commit c594b63

Browse files
authored
fix: Add MAX_BY & MIN_BY to FUNCTION_PARSER (#5021)
* fix: Add MAX_BY & MIN_BY to FUNCTION_PARSER * PR Feedback 1
1 parent 49dd0f3 commit c594b63

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

sqlglot/parser.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,12 @@ def _parse_partitioned_by_bucket_or_truncate(self) -> exp.Expression:
11841184
KEY_VALUE_DEFINITIONS = (exp.Alias, exp.EQ, exp.PropertyEQ, exp.Slice)
11851185

11861186
FUNCTION_PARSERS = {
1187+
**{
1188+
name: lambda self: self._parse_max_min_by(exp.ArgMax) for name in exp.ArgMax.sql_names()
1189+
},
1190+
**{
1191+
name: lambda self: self._parse_max_min_by(exp.ArgMin) for name in exp.ArgMin.sql_names()
1192+
},
11871193
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
11881194
"CEIL": lambda self: self._parse_ceil_floor(exp.Ceil),
11891195
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
@@ -8224,3 +8230,16 @@ def _parse_format_name(self) -> exp.Property:
82248230
this=exp.var("FORMAT_NAME"),
82258231
value=self._parse_string() or self._parse_table_parts(),
82268232
)
8233+
8234+
def _parse_max_min_by(self, expr_type: t.Type[exp.AggFunc]) -> exp.AggFunc:
8235+
args: t.List[exp.Expression] = []
8236+
8237+
if self._match(TokenType.DISTINCT):
8238+
args.append(self.expression(exp.Distinct, expressions=[self._parse_assignment()]))
8239+
self._match(TokenType.COMMA)
8240+
8241+
args.extend(self._parse_csv(self._parse_assignment))
8242+
8243+
return self.expression(
8244+
expr_type, this=seq_get(args, 0), expression=seq_get(args, 1), count=seq_get(args, 2)
8245+
)

tests/dialects/test_snowflake.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,3 +2578,11 @@ def test_parameter(self):
25782578
self.assertEqual(expr.find(exp.Placeholder), exp.Placeholder(this="1"))
25792579
self.validate_identity("SELECT :1, :2")
25802580
self.validate_identity("SELECT :1 + :2")
2581+
2582+
def test_max_by_min_by(self):
2583+
max_by = self.validate_identity("MAX_BY(DISTINCT selected_col, filtered_col)")
2584+
min_by = self.validate_identity("MIN_BY(DISTINCT selected_col, filtered_col)")
2585+
2586+
for node in (max_by, min_by):
2587+
self.assertEqual(len(node.this.expressions), 1)
2588+
self.assertIsInstance(node.expression, exp.Column)

tests/fixtures/identity.sql

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,3 +889,5 @@ CAST(x AS INT128)
889889
CAST(x AS UINT128)
890890
CAST(x AS UINT256)
891891
SELECT export
892+
SELECT ARG_MAX(DISTINCT selected_col, filtered_col) FROM table
893+
SELECT ARG_MIN(DISTINCT selected_col, filtered_col) FROM table

0 commit comments

Comments
 (0)