Skip to content

Commit c5fdf7a

Browse files
SNOW-2268220: support ORDER BY ALL (#3912)
1 parent e52c58b commit c5fdf7a

File tree

10 files changed

+484
-75
lines changed

10 files changed

+484
-75
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#### Improvements
1313

14+
- Enhanced `DataFrame.sort()` to support `ORDER BY ALL` when no columns are specified.
1415
- Catalog API now uses SQL commands instead of SnowAPI calls, which makes the implementation more reliable.
1516

1617
#### Dependency Updates
@@ -103,6 +104,9 @@
103104
- `st_geometryfromwkt`
104105
- `try_to_geography`
105106
- `try_to_geometry`
107+
108+
#### Improvements
109+
106110
- Added a parameter to enable and disable automatic column name aliasing for `interval_day_time_from_parts` and `interval_year_month_from_parts` functions.
107111

108112
#### Bug Fixes

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@
114114
SnowflakeTable,
115115
SnowflakeValues,
116116
)
117-
from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder
117+
from snowflake.snowpark._internal.analyzer.sort_expression import (
118+
SortOrder,
119+
SortByAllOrder,
120+
)
118121
from snowflake.snowpark._internal.analyzer.table_function import (
119122
FlattenFunction,
120123
GeneratorTableFunction,
@@ -558,6 +561,13 @@ def analyze(
558561
expr.null_ordering.sql,
559562
)
560563

564+
if isinstance(expr, SortByAllOrder):
565+
return order_expression(
566+
"ALL",
567+
expr.direction.sql,
568+
expr.null_ordering.sql,
569+
)
570+
561571
if isinstance(expr, ScalarSubquery):
562572
self.subquery_plans.append(expr.plan)
563573
return subquery_expression(expr.plan.queries[-1].sql)

src/snowflake/snowpark/_internal/analyzer/sort_expression.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,23 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
5959

6060
def dependent_column_names_with_duplication(self) -> List[str]:
6161
return derive_dependent_columns_with_duplication(self.child)
62+
63+
64+
class SortByAllOrder(Expression):
    """Sort expression that carries only a direction and a null ordering.

    Unlike a per-column sort expression, no column expression is passed to the
    constructor. The analyzer renders this node with the literal ``ALL`` in
    place of a column, i.e. ``ORDER BY ALL``.

    Args:
        direction: Sort direction applied to the whole ordering.
        null_ordering: Explicit null ordering. When omitted (or otherwise
            falsy), ``direction.default_null_ordering`` is used instead.
    """

    def __init__(
        self,
        direction: SortDirection,
        null_ordering: Optional[NullOrdering] = None,
    ) -> None:
        super().__init__()
        # Bare annotation: this statement declares the attribute's type but
        # assigns nothing.
        # NOTE(review): presumably ``Expression.__init__`` initialises
        # ``child`` -- confirm, since both helpers below read ``self.child``.
        self.child: Expression
        self.direction = direction
        # Fall back to the direction's own default when no ordering is given.
        self.null_ordering = null_ordering or direction.default_null_ordering

    def dependent_column_names(self) -> Optional[AbstractSet[str]]:
        """Return the column names derived from ``self.child``."""
        return derive_dependent_columns(self.child)

    def dependent_column_names_with_duplication(self) -> List[str]:
        """Return the column names derived from ``self.child``, keeping duplicates."""
        return derive_dependent_columns_with_duplication(self.child)

src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
1818
DynamicTableCreateMode,
1919
)
20-
from snowflake.snowpark._internal.analyzer.sort_expression import SortOrder
20+
from snowflake.snowpark._internal.analyzer.sort_expression import (
21+
SortOrder,
22+
SortByAllOrder,
23+
)
2124

2225

2326
class UnaryNode(LogicalPlan):
@@ -90,7 +93,7 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
9093
class Sort(UnaryNode):
9194
def __init__(
9295
self,
93-
order: List[SortOrder],
96+
order: Union[List[SortOrder], List[SortByAllOrder]],
9497
child: LogicalPlan,
9598
is_order_by_append: bool = False,
9699
) -> None:

src/snowflake/snowpark/dataframe.py

Lines changed: 111 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
Ascending,
8585
Descending,
8686
SortOrder,
87+
SortByAllOrder,
8788
)
8889
from snowflake.snowpark._internal.analyzer.table_function import (
8990
FlattenFunction,
@@ -2103,6 +2104,8 @@ def sort(
21032104
) -> "DataFrame":
21042105
"""Sorts a DataFrame by the specified expressions (similar to ORDER BY in SQL).
21052106
2107+
When called with no column arguments, sorts by all columns (ORDER BY ALL).
2108+
21062109
Examples::
21072110
21082111
>>> from snowflake.snowpark.functions import col
@@ -2139,22 +2142,58 @@ def sort(
21392142
-------------
21402143
<BLANKLINE>
21412144
2145+
>>> # Sort by all columns (ORDER BY ALL) - no columns specified
2146+
>>> df.sort().show()
2147+
-------------
2148+
|"A" |"B" |
2149+
-------------
2150+
|1 |2 |
2151+
|1 |4 |
2152+
|3 |4 |
2153+
-------------
2154+
<BLANKLINE>
2155+
2156+
>>> df.sort(ascending=False).show()
2157+
-------------
2158+
|"A" |"B" |
2159+
-------------
2160+
|3 |4 |
2161+
|1 |4 |
2162+
|1 |2 |
2163+
-------------
2164+
<BLANKLINE>
2165+
21422166
Args:
2143-
*cols: A column name as :class:`str` or :class:`Column`, or a list of
2144-
columns to sort by.
2145-
ascending: A :class:`bool` or a list of :class:`bool` for sorting the
2146-
DataFrame, where ``True`` sorts a column in ascending order and ``False``
2147-
sorts a column in descending order . If you specify a list of multiple
2148-
sort orders, the length of the list must equal the number of columns.
2167+
*cols: Column names as :class:`str`, :class:`Column` objects, or a list of
2168+
columns to sort by. If no columns are provided, the DataFrame is sorted
2169+
by all columns in the order they appear (equivalent to ``ORDER BY ALL`` in SQL).
2170+
ascending: Sort order specification.
2171+
2172+
- When sorting **specific columns**: A :class:`bool`, :class:`int`, or list of
2173+
:class:`bool`/:class:`int` values. ``True`` (or 1) for ascending, ``False``
2174+
(or 0) for descending. If a list is provided, its length must match the number
2175+
of columns.
2176+
- When sorting **all columns** (no columns specified): Must be a single
2177+
:class:`bool` or :class:`int`, not a list. Applies the same sort order to
2178+
all columns.
2179+
- Defaults to ``True`` (ascending) when not specified.
2180+
2181+
Note:
2182+
The aliases ``order_by()`` and ``orderBy()`` have the same behavior.
21492183
"""
2150-
if not cols:
2151-
raise ValueError("sort() needs at least one sort expression.")
2152-
# This code performs additional type checks, run first.
2153-
exprs = self._convert_cols_to_exprs("sort()", *cols)
2154-
if not exprs:
2155-
raise ValueError("sort() needs at least one sort expression.")
21562184

2157-
# AST.
2185+
is_order_by_all = not cols
2186+
if (
2187+
is_order_by_all
2188+
and ascending is not None
2189+
and not isinstance(ascending, (bool, int))
2190+
):
2191+
raise TypeError(
2192+
"When no columns are specified (ORDER BY ALL), "
2193+
"ascending must be bool or int, not a list. "
2194+
"To sort specific columns with different orders, specify the columns."
2195+
)
2196+
21582197
stmt = None
21592198
if _emit_ast:
21602199
stmt = self._session._ast_batch.bind()
@@ -2167,59 +2206,68 @@ def sort(
21672206
ast.cols.variadic = is_variadic
21682207
self._set_ast_ref(ast.df)
21692208

2170-
orders = []
2171-
# `ascending` is represented by Expr in the AST.
2172-
# Therefore, construct the required bool, int, or list and copy from that.
2173-
asc_expr_ast = None
2174-
if _emit_ast:
2209+
# Populate ascending as Expr
21752210
asc_expr_ast = proto.Expr()
2176-
if ascending is not None:
2177-
if isinstance(ascending, (list, tuple)):
2178-
orders = [Ascending() if asc else Descending() for asc in ascending]
2179-
if _emit_ast:
2180-
# Here asc_expr_ast is a list of bools and ints.
2181-
for asc in ascending:
2182-
asc_ast = proto.Expr()
2183-
if isinstance(asc, bool):
2184-
asc_ast.bool_val.v = asc
2185-
else:
2186-
asc_ast.int64_val.v = asc
2187-
asc_expr_ast.list_val.vs.append(asc_ast)
2188-
elif isinstance(ascending, (bool, int)):
2189-
orders = [Ascending() if ascending else Descending()]
2190-
if _emit_ast:
2191-
# Here asc_expr_ast is either a bool or an int.
2192-
if isinstance(ascending, bool):
2193-
asc_expr_ast.bool_val.v = ascending
2211+
asc_value = True if ascending is None else ascending
2212+
if isinstance(asc_value, (list, tuple)):
2213+
for asc in asc_value:
2214+
asc_ast = proto.Expr()
2215+
if isinstance(asc, bool):
2216+
asc_ast.bool_val.v = asc
21942217
else:
2195-
asc_expr_ast.int64_val.v = ascending
2196-
else:
2197-
raise TypeError(
2198-
"ascending can only be boolean or list,"
2199-
" but got {}".format(str(type(ascending)))
2200-
)
2201-
if _emit_ast:
2202-
ast.ascending.CopyFrom(asc_expr_ast)
2203-
if len(exprs) != len(orders):
2204-
raise ValueError(
2205-
"The length of col ({}) should be same with"
2206-
" the length of ascending ({}).".format(len(exprs), len(orders))
2207-
)
2218+
asc_ast.int64_val.v = asc
2219+
asc_expr_ast.list_val.vs.append(asc_ast)
2220+
elif isinstance(asc_value, (bool, int)):
2221+
if isinstance(asc_value, bool):
2222+
asc_expr_ast.bool_val.v = asc_value
2223+
else:
2224+
asc_expr_ast.int64_val.v = asc_value
2225+
ast.ascending.CopyFrom(asc_expr_ast)
2226+
2227+
# Build sort expressions
2228+
if is_order_by_all:
2229+
asc_value = True if ascending is None else ascending
2230+
order = Ascending() if bool(asc_value) else Descending()
2231+
sort_exprs = [SortByAllOrder(order)]
2232+
else:
2233+
exprs = self._convert_cols_to_exprs("sort()", *cols)
2234+
if not exprs:
2235+
raise ValueError("sort() needs at least one sort expression.")
2236+
2237+
orders = []
2238+
if ascending is not None:
2239+
if isinstance(ascending, (list, tuple)):
2240+
orders = [Ascending() if asc else Descending() for asc in ascending]
2241+
elif isinstance(ascending, (bool, int)):
2242+
orders = [Ascending() if ascending else Descending()]
2243+
else:
2244+
raise TypeError(
2245+
"ascending can only be boolean or list,"
2246+
" but got {}".format(str(type(ascending)))
2247+
)
22082248

2209-
sort_exprs = []
2210-
for idx in range(len(exprs)):
2211-
# orders will overwrite current orders in expression (but will not overwrite null ordering)
2212-
# if no order is provided, use ascending order
2213-
if isinstance(exprs[idx], SortOrder):
2214-
sort_exprs.append(
2215-
SortOrder(exprs[idx].child, orders[idx], exprs[idx].null_ordering)
2216-
if orders
2217-
else exprs[idx]
2218-
)
2219-
else:
2220-
sort_exprs.append(
2221-
SortOrder(exprs[idx], orders[idx] if orders else Ascending())
2222-
)
2249+
if len(exprs) != len(orders):
2250+
raise ValueError(
2251+
"The length of col ({}) should be same with"
2252+
" the length of ascending ({}).".format(len(exprs), len(orders))
2253+
)
2254+
2255+
sort_exprs = []
2256+
for idx in range(len(exprs)):
2257+
# orders will overwrite current orders in expression (but will not overwrite null ordering)
2258+
# if no order is provided, use ascending order
2259+
if isinstance(exprs[idx], SortOrder):
2260+
sort_exprs.append(
2261+
SortOrder(
2262+
exprs[idx].child, orders[idx], exprs[idx].null_ordering
2263+
)
2264+
if orders
2265+
else exprs[idx]
2266+
)
2267+
else:
2268+
sort_exprs.append(
2269+
SortOrder(exprs[idx], orders[idx] if orders else Ascending())
2270+
)
22232271

22242272
# In snowpark_connect_compatible mode, we need to handle
22252273
# the sorting for dataframe after aggregation without nesting

0 commit comments

Comments
 (0)