Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

- Invoking snowflake system procedures does not invoke an additional `describe procedure` call to check the return type of the procedure.
- Added support for `Session.create_dataframe()` with the stage URL and FILE data type.
- Improved query generation for `DataFrame.drop` to use `SELECT * EXCLUDE ()` to exclude the dropped columns. To enable this feature, set `session.conf.set("use_simplified_query_generation", True)`.

#### Bug Fixes

Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def infer_metadata(
]
# When source_plan doesn't have a projection, it's a simple `SELECT * from ...`,
# which means source_plan has the same metadata as its child plan, we can use it directly
if source_plan.projection is None:
if not source_plan.has_projection:
# We can only retrieve the cached metadata when there is an underlying SnowflakePlan
# or it's a SelectableEntity
if source_plan.from_._snowflake_plan is not None:
Expand Down Expand Up @@ -231,7 +231,7 @@ def cache_metadata_if_selectable(
# we should cache it on the child plan too.
# This is necessary SelectStatement.select() will need the column states of the child plan
# (check the implementation of derive_column_states_from_subquery().
if source_plan.projection is None:
if not source_plan.has_projection:
if source_plan.from_._snowflake_plan is not None:
source_plan.from_._snowflake_plan._metadata = metadata
elif isinstance(source_plan.from_, SelectableEntity):
Expand Down
76 changes: 66 additions & 10 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ def __init__(
analyzer: "Analyzer",
schema_query: Optional[str] = None,
distinct: bool = False,
exclude_cols: Optional[Set[str]] = None,
) -> None:
super().__init__(analyzer)
self.projection: Optional[List[Expression]] = projection
Expand All @@ -731,6 +732,8 @@ def __init__(
self._sql_query = None
self._schema_query = schema_query
self.distinct_: bool = distinct
# An optional set to store the columns that should be excluded from the projection
self.exclude_cols: Optional[Set[str]] = exclude_cols
self._projection_in_str = None
self._query_params = None
self.expr_to_alias.update(self.from_.expr_to_alias)
Expand Down Expand Up @@ -772,6 +775,7 @@ def __copy__(self):
analyzer=self.analyzer,
schema_query=self.schema_query,
distinct=self.distinct_,
exclude_cols=self.exclude_cols,
)
# The following values will change if they're None in the newly copied one so reset their values here
# to avoid problems.
Expand Down Expand Up @@ -802,6 +806,7 @@ def __deepcopy__(self, memodict={}) -> "SelectStatement": # noqa: B006
# directly copy the current schema fields
schema_query=self._schema_query,
distinct=self.distinct_,
exclude_cols=self.exclude_cols,
)

_deepcopy_selectable_fields(from_selectable=self, to_selectable=copied)
Expand All @@ -820,7 +825,7 @@ def __deepcopy__(self, memodict={}) -> "SelectStatement": # noqa: B006
@property
def column_states(self) -> ColumnStateDict:
if self._column_states is None:
if not self.projection and not self.has_clause:
if not self.has_projection and not self.has_clause:
self.column_states = self.from_.column_states
else:
super().column_states # will assign value to self._column_states
Expand Down Expand Up @@ -850,24 +855,36 @@ def has_clause(self) -> bool:
self.has_clause_using_columns or self.limit_ is not None or self.distinct_
)

@property
def has_projection(self) -> bool:
    """True when this SelectStatement carries any form of projection.

    Two projection forms are recognized:
    - an explicit column list (``self.projection``), or
    - an EXCLUDE column set (``self.exclude_cols``).
    """
    has_selected_cols = self.projection is not None
    has_excluded_cols = self.exclude_cols is not None
    return has_selected_cols or has_excluded_cols

@property
def projection_in_str(self) -> str:
if not self._projection_in_str:
self._projection_in_str = (
analyzer_utils.COMMA.join(
if self.projection:
assert (
self.exclude_cols is None
), "We should not have reached this state. There is likely a bug in flattening logic."
self._projection_in_str = analyzer_utils.COMMA.join(
self.analyzer.analyze(x, self.df_aliased_col_name_to_real_col_name)
for x in self.projection
)
if self.projection
else analyzer_utils.STAR
)
else:
self._projection_in_str = analyzer_utils.STAR
if self.exclude_cols is not None:
self._projection_in_str = f"{analyzer_utils.STAR}{analyzer_utils.EXCLUDE}({analyzer_utils.COMMA.join(sorted(self.exclude_cols))})"
return self._projection_in_str

@property
def sql_query(self) -> str:
if self._sql_query:
return self._sql_query
if not self.has_clause and not self.projection:
if not self.has_clause and not self.has_projection:
self._sql_query = self.from_.sql_query
return self._sql_query
from_clause = self.from_.sql_in_subquery
Expand Down Expand Up @@ -920,7 +937,7 @@ def attributes(self, value: Optional[List[Attribute]]):
def schema_query(self) -> str:
if self._schema_query:
return self._schema_query
if not self.projection:
if not self.has_projection:
self._schema_query = self.from_.schema_query
return self._schema_query
self._schema_query = f"{analyzer_utils.SELECT}{self.projection_in_str}{analyzer_utils.FROM}({self.from_.schema_query})"
Expand Down Expand Up @@ -1180,6 +1197,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
elif self.distinct_:
# .distinct().select() != .select().distinct() therefore we cannot flatten
can_be_flattened = False
elif self.exclude_cols is not None:
# exclude syntax only support: SELECT * EXCLUDE(col1, col2) FROM TABLE
can_be_flattened = False
else:
can_be_flattened = can_select_statement_be_flattened(
self.column_states, new_column_states
Expand Down Expand Up @@ -1303,7 +1323,7 @@ def distinct(self) -> "SelectStatement":
and (not self.offset)
# .order_by(col1).select(col2).distinct() cannot be flattened because
# SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL
and (not (self.order_by and self.projection))
and (not (self.order_by and self.has_projection))
and not has_data_generator_exp(self.projection)
)
if can_be_flattened:
Expand All @@ -1325,6 +1345,42 @@ def distinct(self) -> "SelectStatement":
new.attributes = self.attributes
return new

def exclude(
    self, exclude_cols: List[str], keep_cols: List[str]
) -> "SelectStatement":
    """Return a new SelectStatement that drops columns using
    ``SELECT * EXCLUDE (...)`` syntax.

    Args:
        exclude_cols: quoted column names to be dropped from the current
            select statement.
        keep_cols: quoted column names that remain after the drop; used only
            to derive the column states of the resulting statement.
    """
    # .select().drop(); cannot be flattened; exclude syntax is select * exclude ...

    # .order_by().drop() can be flattened
    # .filter().drop() can be flattened
    # .limit().drop() can be flattened
    # .distinct().drop() can be flattened
    can_be_flattened = not self.flatten_disabled and not self.projection
    if can_be_flattened:
        # Shallow-copy this statement and re-point it at a subquery-able child
        # so the EXCLUDE set is merged into the current SELECT layer.
        new = copy(self)
        new.from_ = self.from_.to_subqueryable()
        new.pre_actions = new.from_.pre_actions
        new.post_actions = new.from_.post_actions
        new._merge_projection_complexity_with_subquery = False
    else:
        # An explicit projection exists (or flattening is disabled): wrap the
        # current statement in a fresh SELECT so EXCLUDE applies on top of it.
        new = SelectStatement(
            from_=self.to_subqueryable(),
            analyzer=self.analyzer,
        )

    # Accumulate onto any pre-existing exclude set (possible on the flattened
    # path, where `new` is a copy of a statement that already excludes columns).
    new.exclude_cols = new.exclude_cols or set()
    new.exclude_cols.update(exclude_cols)

    # Use keep_cols and select logic to derive updated column_states for new
    new_column_states = derive_column_states_from_subquery(
        [Attribute(col, DataType()) for col in keep_cols], self
    )
    assert new_column_states is not None
    new.column_states = new_column_states
    return new

def set_operator(
self,
*selectables: Union[
Expand All @@ -1336,7 +1392,7 @@ def set_operator(
if (
isinstance(self.from_, SetStatement)
and not self.has_clause
and not self.projection
and not self.has_projection
):
last_operator = self.from_.set_operands[-1].operator
if operator == last_operator:
Expand Down
1 change: 0 additions & 1 deletion src/snowflake/snowpark/_internal/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ class TelemetryField(Enum):
API_CALLS_TO_ADJUST = {
"to_df": 1,
"select_expr": 1,
"drop": 1,
"agg": 2,
"distinct": 2,
"with_column": 1,
Expand Down
100 changes: 67 additions & 33 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,7 +1736,6 @@ def select_expr(

selectExpr = select_expr

@df_api_usage
@publicapi
def drop(
self, *cols: Union[ColumnOrName, Iterable[ColumnOrName]], _emit_ast: bool = True
Expand Down Expand Up @@ -1782,48 +1781,83 @@ def drop(
build_expr_from_snowpark_column_or_col_name(ast.cols.args.add(), c)
ast.cols.variadic = is_variadic

names = []
for c in exprs:
if isinstance(c, str):
names.append(c)
elif isinstance(c, Column) and isinstance(c._expression, Attribute):
from snowflake.snowpark.mock._connection import MockServerConnection
with ResourceUsageCollector() as resource_usage_collector:
names = []
for c in exprs:
if isinstance(c, str):
names.append(c)
elif isinstance(c, Column) and isinstance(c._expression, Attribute):
from snowflake.snowpark.mock._connection import MockServerConnection

if isinstance(self._session._conn, MockServerConnection):
self.schema # to execute the plan and populate expr_to_alias
if isinstance(self._session._conn, MockServerConnection):
self.schema # to execute the plan and populate expr_to_alias

names.append(
self._plan.expr_to_alias.get(
c._expression.expr_id, c._expression.name
names.append(
self._plan.expr_to_alias.get(
c._expression.expr_id, c._expression.name
)
)
)
elif (
isinstance(c, Column)
and isinstance(c._expression, UnresolvedAttribute)
and c._expression.df_alias
elif (
isinstance(c, Column)
and isinstance(c._expression, UnresolvedAttribute)
and c._expression.df_alias
):
names.append(
self._plan.df_aliased_col_name_to_real_col_name.get(
c._expression.name, c._expression.name
)
)
elif isinstance(c, Column) and isinstance(
c._expression, NamedExpression
):
names.append(c._expression.name)
else:
raise SnowparkClientExceptionMessages.DF_CANNOT_DROP_COLUMN_NAME(
str(c)
)

normalized_names = {quote_name(n) for n in names}
existing_names = [attr.name for attr in self._output]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to not check existing names if using exclude? It will trigger a describe query

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as discussed offline, if we don't do this describe query, it may lead to a BCR since SELECT * EXCLUDE (non_existing_column) raises error SQL error while it raises no error today.

keep_col_names = [c for c in existing_names if c not in normalized_names]
if not keep_col_names:
raise SnowparkClientExceptionMessages.DF_CANNOT_DROP_ALL_COLUMNS()

if self._select_statement and self._session.conf.get(
"use_simplified_query_generation"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there other functionality that will use this flag?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah. all new query generation improvement for dataframe APIs are protected with this parameter. .distinct(), .random_split(), .stat.sample_by(), .drop() and planning to add .rename() under this.

):
names.append(
self._plan.df_aliased_col_name_to_real_col_name.get(
c._expression.name, c._expression.name
# Only drop the columns that exist in the DataFrame.
drop_normalized_names = [
name for name in normalized_names if name in existing_names
]
if not drop_normalized_names:
df = self._with_plan(self._select_statement)
else:
df = self._with_plan(
self._select_statement.exclude(
drop_normalized_names, keep_col_names
)
)
)
elif isinstance(c, Column) and isinstance(c._expression, NamedExpression):
names.append(c._expression.name)
else:
raise SnowparkClientExceptionMessages.DF_CANNOT_DROP_COLUMN_NAME(str(c))
df = self.select(list(keep_col_names), _emit_ast=False)

normalized_names = {quote_name(n) for n in names}
existing_names = [attr.name for attr in self._output]
keep_col_names = [c for c in existing_names if c not in normalized_names]
if not keep_col_names:
raise SnowparkClientExceptionMessages.DF_CANNOT_DROP_ALL_COLUMNS()
if self._session.conf.get("use_simplified_query_generation"):
add_api_call(
df,
"DataFrame.drop[exclude]",
resource_usage_collector.get_resource_usage(),
)
else:
df = self.select(list(keep_col_names), _emit_ast=False)
adjust_api_subcalls(
df,
"DataFrame.drop[select]",
len_subcalls=1,
resource_usage=resource_usage_collector.get_resource_usage(),
)

if _emit_ast:
df._ast_id = stmt.uid
if _emit_ast:
df._ast_id = stmt.uid

return df
return df

@df_api_usage
@publicapi
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/snowpark/mock/_nop_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def sort(self, cols: List[Expression]) -> "MockSelectStatement":
def distinct(self) -> "MockSelectStatement":
return self._make_nop_select_statement_copy(super().distinct())

def exclude(self, exclude_cols, keep_cols) -> "MockSelectStatement":
    # Delegate directly to the base implementation. NOTE(review): unlike
    # sort()/distinct() in this class, the result is not wrapped in
    # _make_nop_select_statement_copy() — confirm this is intentional.
    return super().exclude(exclude_cols, keep_cols)

def set_operator(
self,
*selectables: Union[
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,7 @@ def execute_mock_plan(
limit_: Optional[int] = source_plan.limit_
offset: Optional[int] = source_plan.offset
distinct_: bool = source_plan.distinct_
exclude_cols: List[str] = source_plan.exclude_cols

from_df = execute_mock_plan(from_, expr_to_alias)

Expand Down Expand Up @@ -1137,6 +1138,9 @@ def execute_mock_plan(
if distinct_:
result_df = result_df.drop_duplicates()

if exclude_cols:
result_df = result_df.drop(columns=exclude_cols)

return result_df
if isinstance(source_plan, MockSetStatement):
first_operand = source_plan.set_operands[0]
Expand Down
Loading
Loading