Skip to content

Commit df33fc9

Browse files
SNOW-807303 Use SELECT * EXCLUDE for DataFrame.drop() (#3316)
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
1 parent 0fc87b3 commit df33fc9

File tree

10 files changed

+432
-94
lines changed

10 files changed

+432
-94
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
- Invoking snowflake system procedures does not invoke an additional `describe procedure` call to check the return type of the procedure.
1010
- Added support for `Session.create_dataframe()` with the stage URL and FILE data type.
11+
- Improved query generation for `DataFrame.drop` to use `SELECT * EXCLUDE ()` to exclude the dropped columns. To enable this feature, set `session.conf.set("use_simplified_query_generation", True)`.
1112

1213
#### Bug Fixes
1314

src/snowflake/snowpark/_internal/analyzer/metadata_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def infer_metadata(
171171
]
172172
# When source_plan doesn't have a projection, it's a simple `SELECT * from ...`,
173173
# which means source_plan has the same metadata as its child plan, we can use it directly
174-
if source_plan.projection is None:
174+
if not source_plan.has_projection:
175175
# We can only retrieve the cached metadata when there is an underlying SnowflakePlan
176176
# or it's a SelectableEntity
177177
if source_plan.from_._snowflake_plan is not None:
@@ -231,7 +231,7 @@ def cache_metadata_if_selectable(
231231
# we should cache it on the child plan too.
232232
# This is necessary because SelectStatement.select() will need the column states of the child plan
233233
# (check the implementation of derive_column_states_from_subquery()).
234-
if source_plan.projection is None:
234+
if not source_plan.has_projection:
235235
if source_plan.from_._snowflake_plan is not None:
236236
source_plan.from_._snowflake_plan._metadata = metadata
237237
elif isinstance(source_plan.from_, SelectableEntity):

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,7 @@ def __init__(
718718
analyzer: "Analyzer",
719719
schema_query: Optional[str] = None,
720720
distinct: bool = False,
721+
exclude_cols: Optional[Set[str]] = None,
721722
) -> None:
722723
super().__init__(analyzer)
723724
self.projection: Optional[List[Expression]] = projection
@@ -731,6 +732,8 @@ def __init__(
731732
self._sql_query = None
732733
self._schema_query = schema_query
733734
self.distinct_: bool = distinct
735+
# An optional set to store the columns that should be excluded from the projection
736+
self.exclude_cols: Optional[Set[str]] = exclude_cols
734737
self._projection_in_str = None
735738
self._query_params = None
736739
self.expr_to_alias.update(self.from_.expr_to_alias)
@@ -772,6 +775,7 @@ def __copy__(self):
772775
analyzer=self.analyzer,
773776
schema_query=self.schema_query,
774777
distinct=self.distinct_,
778+
exclude_cols=self.exclude_cols,
775779
)
776780
# The following values will change if they're None in the newly copied one so reset their values here
777781
# to avoid problems.
@@ -802,6 +806,7 @@ def __deepcopy__(self, memodict={}) -> "SelectStatement": # noqa: B006
802806
# directly copy the current schema fields
803807
schema_query=self._schema_query,
804808
distinct=self.distinct_,
809+
exclude_cols=self.exclude_cols,
805810
)
806811

807812
_deepcopy_selectable_fields(from_selectable=self, to_selectable=copied)
@@ -820,7 +825,7 @@ def __deepcopy__(self, memodict={}) -> "SelectStatement": # noqa: B006
820825
@property
821826
def column_states(self) -> ColumnStateDict:
822827
if self._column_states is None:
823-
if not self.projection and not self.has_clause:
828+
if not self.has_projection and not self.has_clause:
824829
self.column_states = self.from_.column_states
825830
else:
826831
super().column_states # will assign value to self._column_states
@@ -850,24 +855,44 @@ def has_clause(self) -> bool:
850855
self.has_clause_using_columns or self.limit_ is not None or self.distinct_
851856
)
852857

858+
@property
859+
def has_projection(self) -> bool:
860+
"""Boolean that indicates if the SelectStatement has the following forms of projection:
861+
- select columns
862+
- exclude columns
863+
"""
864+
865+
return (
866+
self.projection is not None and len(self.projection) > 0
867+
) or self.exclude_cols is not None
868+
853869
@property
854870
def projection_in_str(self) -> str:
855871
if not self._projection_in_str:
856-
self._projection_in_str = (
857-
analyzer_utils.COMMA.join(
872+
if self.projection:
873+
assert (
874+
self.exclude_cols is None
875+
), "We should not have reached this state. There is likely a bug in flattening logic."
876+
self._projection_in_str = analyzer_utils.COMMA.join(
858877
self.analyzer.analyze(x, self.df_aliased_col_name_to_real_col_name)
859878
for x in self.projection
860879
)
861-
if self.projection
862-
else analyzer_utils.STAR
863-
)
880+
else:
881+
self._projection_in_str = analyzer_utils.STAR
882+
if self.exclude_cols is not None:
883+
# we sort the exclude_cols to make sure the projection_in_str is deterministic
884+
# this is done to remove test flakiness
885+
self._projection_in_str = (
886+
f"{analyzer_utils.STAR}{analyzer_utils.EXCLUDE}"
887+
f"({analyzer_utils.COMMA.join(sorted(self.exclude_cols))})"
888+
)
864889
return self._projection_in_str
865890

866891
@property
867892
def sql_query(self) -> str:
868893
if self._sql_query:
869894
return self._sql_query
870-
if not self.has_clause and not self.projection:
895+
if not self.has_clause and not self.has_projection:
871896
self._sql_query = self.from_.sql_query
872897
return self._sql_query
873898
from_clause = self.from_.sql_in_subquery
@@ -920,7 +945,7 @@ def attributes(self, value: Optional[List[Attribute]]):
920945
def schema_query(self) -> str:
921946
if self._schema_query:
922947
return self._schema_query
923-
if not self.projection:
948+
if not self.has_projection:
924949
self._schema_query = self.from_.schema_query
925950
return self._schema_query
926951
self._schema_query = f"{analyzer_utils.SELECT}{self.projection_in_str}{analyzer_utils.FROM}({self.from_.schema_query})"
@@ -1180,6 +1205,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
11801205
elif self.distinct_:
11811206
# .distinct().select() != .select().distinct() therefore we cannot flatten
11821207
can_be_flattened = False
1208+
elif self.exclude_cols is not None:
1209+
# exclude syntax only supports: SELECT * EXCLUDE(col1, col2) FROM TABLE
1210+
can_be_flattened = False
11831211
else:
11841212
can_be_flattened = can_select_statement_be_flattened(
11851213
self.column_states, new_column_states
@@ -1303,7 +1331,7 @@ def distinct(self) -> "SelectStatement":
13031331
and (not self.offset)
13041332
# .order_by(col1).select(col2).distinct() cannot be flattened because
13051333
# SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL
1306-
and (not (self.order_by and self.projection))
1334+
and (not (self.order_by and self.has_projection))
13071335
and not has_data_generator_exp(self.projection)
13081336
)
13091337
if can_be_flattened:
@@ -1325,6 +1353,42 @@ def distinct(self) -> "SelectStatement":
13251353
new.attributes = self.attributes
13261354
return new
13271355

1356+
def exclude(
1357+
self, exclude_cols: List[str], keep_cols: List[str]
1358+
) -> "SelectStatement":
1359+
"""List of quoted column names to be dropped from the current select
1360+
statement.
1361+
"""
1362+
# .select().drop() cannot be flattened; exclude syntax is SELECT * EXCLUDE ...
1363+
1364+
# .order_by().drop() can be flattened
1365+
# .filter().drop() can be flattened
1366+
# .limit().drop() can be flattened
1367+
# .distinct().drop() can be flattened
1368+
can_be_flattened = not self.flatten_disabled and not self.projection
1369+
if can_be_flattened:
1370+
new = copy(self)
1371+
new.from_ = self.from_.to_subqueryable()
1372+
new.pre_actions = new.from_.pre_actions
1373+
new.post_actions = new.from_.post_actions
1374+
new._merge_projection_complexity_with_subquery = False
1375+
else:
1376+
new = SelectStatement(
1377+
from_=self.to_subqueryable(),
1378+
analyzer=self.analyzer,
1379+
)
1380+
1381+
new.exclude_cols = new.exclude_cols or set()
1382+
new.exclude_cols.update(exclude_cols)
1383+
1384+
# Use keep_cols and select logic to derive updated column_states for new
1385+
new_column_states = derive_column_states_from_subquery(
1386+
[Attribute(col, DataType()) for col in keep_cols], self
1387+
)
1388+
assert new_column_states is not None
1389+
new.column_states = new_column_states
1390+
return new
1391+
13281392
def set_operator(
13291393
self,
13301394
*selectables: Union[
@@ -1336,7 +1400,7 @@ def set_operator(
13361400
if (
13371401
isinstance(self.from_, SetStatement)
13381402
and not self.has_clause
1339-
and not self.projection
1403+
and not self.has_projection
13401404
):
13411405
last_operator = self.from_.set_operands[-1].operator
13421406
if operator == last_operator:

src/snowflake/snowpark/_internal/telemetry.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ class TelemetryField(Enum):
123123
API_CALLS_TO_ADJUST = {
124124
"to_df": 1,
125125
"select_expr": 1,
126-
"drop": 1,
127126
"agg": 2,
128127
"with_column": 1,
129128
"with_columns": 1,

src/snowflake/snowpark/dataframe.py

Lines changed: 67 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,7 +1736,6 @@ def select_expr(
17361736

17371737
selectExpr = select_expr
17381738

1739-
@df_api_usage
17401739
@publicapi
17411740
def drop(
17421741
self, *cols: Union[ColumnOrName, Iterable[ColumnOrName]], _emit_ast: bool = True
@@ -1782,48 +1781,83 @@ def drop(
17821781
build_expr_from_snowpark_column_or_col_name(ast.cols.args.add(), c)
17831782
ast.cols.variadic = is_variadic
17841783

1785-
names = []
1786-
for c in exprs:
1787-
if isinstance(c, str):
1788-
names.append(c)
1789-
elif isinstance(c, Column) and isinstance(c._expression, Attribute):
1790-
from snowflake.snowpark.mock._connection import MockServerConnection
1784+
with ResourceUsageCollector() as resource_usage_collector:
1785+
names = []
1786+
for c in exprs:
1787+
if isinstance(c, str):
1788+
names.append(c)
1789+
elif isinstance(c, Column) and isinstance(c._expression, Attribute):
1790+
from snowflake.snowpark.mock._connection import MockServerConnection
17911791

1792-
if isinstance(self._session._conn, MockServerConnection):
1793-
self.schema # to execute the plan and populate expr_to_alias
1792+
if isinstance(self._session._conn, MockServerConnection):
1793+
self.schema # to execute the plan and populate expr_to_alias
17941794

1795-
names.append(
1796-
self._plan.expr_to_alias.get(
1797-
c._expression.expr_id, c._expression.name
1795+
names.append(
1796+
self._plan.expr_to_alias.get(
1797+
c._expression.expr_id, c._expression.name
1798+
)
17981799
)
1799-
)
1800-
elif (
1801-
isinstance(c, Column)
1802-
and isinstance(c._expression, UnresolvedAttribute)
1803-
and c._expression.df_alias
1800+
elif (
1801+
isinstance(c, Column)
1802+
and isinstance(c._expression, UnresolvedAttribute)
1803+
and c._expression.df_alias
1804+
):
1805+
names.append(
1806+
self._plan.df_aliased_col_name_to_real_col_name.get(
1807+
c._expression.name, c._expression.name
1808+
)
1809+
)
1810+
elif isinstance(c, Column) and isinstance(
1811+
c._expression, NamedExpression
1812+
):
1813+
names.append(c._expression.name)
1814+
else:
1815+
raise SnowparkClientExceptionMessages.DF_CANNOT_DROP_COLUMN_NAME(
1816+
str(c)
1817+
)
1818+
1819+
normalized_names = {quote_name(n) for n in names}
1820+
existing_names = [attr.name for attr in self._output]
1821+
keep_col_names = [c for c in existing_names if c not in normalized_names]
1822+
if not keep_col_names:
1823+
raise SnowparkClientExceptionMessages.DF_CANNOT_DROP_ALL_COLUMNS()
1824+
1825+
if self._select_statement and self._session.conf.get(
1826+
"use_simplified_query_generation"
18041827
):
1805-
names.append(
1806-
self._plan.df_aliased_col_name_to_real_col_name.get(
1807-
c._expression.name, c._expression.name
1828+
# Only drop the columns that exist in the DataFrame.
1829+
drop_normalized_names = [
1830+
name for name in normalized_names if name in existing_names
1831+
]
1832+
if not drop_normalized_names:
1833+
df = self._with_plan(self._select_statement)
1834+
else:
1835+
df = self._with_plan(
1836+
self._select_statement.exclude(
1837+
drop_normalized_names, keep_col_names
1838+
)
18081839
)
1809-
)
1810-
elif isinstance(c, Column) and isinstance(c._expression, NamedExpression):
1811-
names.append(c._expression.name)
18121840
else:
1813-
raise SnowparkClientExceptionMessages.DF_CANNOT_DROP_COLUMN_NAME(str(c))
1841+
df = self.select(list(keep_col_names), _emit_ast=False)
18141842

1815-
normalized_names = {quote_name(n) for n in names}
1816-
existing_names = [attr.name for attr in self._output]
1817-
keep_col_names = [c for c in existing_names if c not in normalized_names]
1818-
if not keep_col_names:
1819-
raise SnowparkClientExceptionMessages.DF_CANNOT_DROP_ALL_COLUMNS()
1843+
if self._session.conf.get("use_simplified_query_generation"):
1844+
add_api_call(
1845+
df,
1846+
"DataFrame.drop[exclude]",
1847+
resource_usage_collector.get_resource_usage(),
1848+
)
18201849
else:
1821-
df = self.select(list(keep_col_names), _emit_ast=False)
1850+
adjust_api_subcalls(
1851+
df,
1852+
"DataFrame.drop[select]",
1853+
len_subcalls=1,
1854+
resource_usage=resource_usage_collector.get_resource_usage(),
1855+
)
18221856

1823-
if _emit_ast:
1824-
df._ast_id = stmt.uid
1857+
if _emit_ast:
1858+
df._ast_id = stmt.uid
18251859

1826-
return df
1860+
return df
18271861

18281862
@df_api_usage
18291863
@publicapi

src/snowflake/snowpark/mock/_nop_analyzer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def sort(self, cols: List[Expression]) -> "MockSelectStatement":
5959
def distinct(self) -> "MockSelectStatement":
6060
return self._make_nop_select_statement_copy(super().distinct())
6161

62+
def exclude(self, exclude_cols, keep_cols) -> "MockSelectStatement":
63+
return super().exclude(exclude_cols, keep_cols)
64+
6265
def set_operator(
6366
self,
6467
*selectables: Union[

src/snowflake/snowpark/mock/_plan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,7 @@ def execute_mock_plan(
10521052
limit_: Optional[int] = source_plan.limit_
10531053
offset: Optional[int] = source_plan.offset
10541054
distinct_: bool = source_plan.distinct_
1055+
exclude_cols: List[str] = source_plan.exclude_cols
10551056

10561057
from_df = execute_mock_plan(from_, expr_to_alias)
10571058

@@ -1137,6 +1138,9 @@ def execute_mock_plan(
11371138
if distinct_:
11381139
result_df = result_df.drop_duplicates()
11391140

1141+
if exclude_cols:
1142+
result_df = result_df.drop(columns=exclude_cols)
1143+
11401144
return result_df
11411145
if isinstance(source_plan, MockSetStatement):
11421146
first_operand = source_plan.set_operands[0]

0 commit comments

Comments
 (0)