Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#### New Features

- Added support for targeted delete-insert via the `overwrite_condition` parameter in `DataFrameWriter.save_as_table`

#### Bug Fixes

#### Improvements
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,12 @@ def do_resolve_with_resolved_children(
child_attributes=resolved_child.attributes,
iceberg_config=iceberg_config,
table_exists=logical_plan.table_exists,
overwrite_condition=self.analyze(
logical_plan.overwrite_condition,
df_aliased_col_name_to_real_col_name,
)
if logical_plan.overwrite_condition
else None,
)

if isinstance(logical_plan, Limit):
Expand Down
72 changes: 62 additions & 10 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,7 @@ def save_as_table(
child_attributes: Optional[List[Attribute]],
iceberg_config: Optional[dict] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[str] = None,
) -> SnowflakePlan:
"""Returns a SnowflakePlan to materialize the child plan into a table.

Expand Down Expand Up @@ -1417,19 +1418,65 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
referenced_ctes=child.referenced_ctes,
)

def get_overwrite_delete_insert_plan(child: SnowflakePlan):
"""Build a plan for targeted delete + insert with transaction.

Deletes rows matching the overwrite_condition condition, then inserts
all rows from the source DataFrame. Wrapped in a transaction for atomicity.
"""
child = self.add_result_scan_if_not_select(child)

return SnowflakePlan(
[
*child.queries[:-1],
Query("BEGIN TRANSACTION"),
Query(
delete_statement(
table_name=full_table_name,
condition=overwrite_condition,
source_data=None,
),
params=child.queries[-1].params,
is_ddl_on_temp_object=is_temp_table_type,
),
Query(
insert_into_statement(
table_name=full_table_name,
child=child.queries[-1].sql,
column_names=column_names,
),
params=child.queries[-1].params,
is_ddl_on_temp_object=is_temp_table_type,
),
Query("COMMIT"),
],
schema_query=None,
post_actions=child.post_actions,
expr_to_alias={},
source_plan=source_plan,
api_calls=child.api_calls,
session=self.session,
referenced_ctes=child.referenced_ctes,
)

if mode == SaveMode.APPEND:
assert table_exists is not None
if table_exists:
return self.build(
lambda x: insert_into_statement(
table_name=full_table_name,
child=x,
column_names=column_names,
),
child,
source_plan,
)
if overwrite_condition is not None:
return get_overwrite_delete_insert_plan(child)
else:
# Normal append without overwrite_condition
return self.build(
lambda x: insert_into_statement(
table_name=full_table_name,
child=x,
column_names=column_names,
),
child,
source_plan,
)
else:
# Table doesn't exist, just create and insert (overwrite_condition is no-op)
return get_create_and_insert_plan(child, replace=False, error=False)

elif mode == SaveMode.TRUNCATE:
Expand All @@ -1446,7 +1493,12 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
return get_create_table_as_select_plan(child, replace=True, error=True)

elif mode == SaveMode.OVERWRITE:
return get_create_table_as_select_plan(child, replace=True, error=True)
if overwrite_condition is not None and table_exists:
# Selective overwrite: delete matching rows, then insert
return get_overwrite_delete_insert_plan(child)
else:
# Default overwrite: drop and recreate table
return get_create_table_as_select_plan(child, replace=True, error=True)

elif mode == SaveMode.IGNORE:
return get_create_table_as_select_plan(child, replace=False, error=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __init__(
copy_grants: bool = False,
iceberg_config: Optional[dict] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[Expression] = None,
) -> None:
super().__init__()

Expand All @@ -267,6 +268,7 @@ def __init__(
# whether the table already exists in the database
# determines the compiled SQL for APPEND and TRUNCATE mode
self.table_exists = table_exists
self.overwrite_condition = overwrite_condition

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
Expand Down
54 changes: 46 additions & 8 deletions src/snowflake/snowpark/dataframe_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
warning,
)
from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
from snowflake.snowpark.column import Column, _to_col_if_str
from snowflake.snowpark.column import Column, _to_col_if_str, _to_col_if_sql_expr
from snowflake.snowpark.exceptions import SnowparkClientException
from snowflake.snowpark.functions import sql_expr
from snowflake.snowpark.mock._connection import MockServerConnection
Expand Down Expand Up @@ -256,6 +256,7 @@ def save_as_table(
Dict[str, Union[str, Iterable[ColumnOrSqlExpr]]]
] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[ColumnOrSqlExpr] = None,
_emit_ast: bool = True,
**kwargs: Optional[Dict[str, Any]],
) -> Optional[AsyncJob]:
Expand Down Expand Up @@ -330,7 +331,10 @@ def save_as_table(
* iceberg_version: Overrides the version of iceberg to use. Defaults to 2 when unset.
table_exists: Optional parameter to specify if the table is known to exist or not.
Set to ``True`` if table exists, ``False`` if it doesn't, or ``None`` (default) for automatic detection.
Primarily useful for "append" and "truncate" modes to avoid running query for automatic detection.
Primarily useful for "append", "truncate", and "overwrite" with overwrite_condition modes to avoid running query for automatic detection.
overwrite_condition: Specifies the overwrite condition to perform atomic targeted delete-insert.
Can be used when ``mode`` is "append" or "overwrite" when the table exists. Rows matching the
condition are deleted from the target table, then all rows from the DataFrame are inserted.


Example 1::
Expand Down Expand Up @@ -364,6 +368,21 @@ def save_as_table(
... "partition_by": ["a", bucket(3, col("b"))],
... }
>>> df.write.mode("overwrite").save_as_table("my_table", iceberg_config=iceberg_config) # doctest: +SKIP

Example 3::

Using overwrite_condition for targeted delete and insert:

>>> from snowflake.snowpark.functions import col
>>> df = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"]], schema=["id", "val"])
>>> df.write.mode("overwrite").save_as_table("my_table", table_type="temporary")
>>> session.table("my_table").order_by("id").collect()
[Row(ID=1, VAL='a'), Row(ID=2, VAL='b'), Row(ID=3, VAL='c')]

>>> new_df = session.create_dataframe([[2, "updated2"], [5, "updated5"]], schema=["id", "val"])
>>> new_df.write.mode("append").save_as_table("my_table", overwrite_condition="id = 1 or val = 'b'")
>>> session.table("my_table").order_by("id").collect()
[Row(ID=2, VAL='updated2'), Row(ID=3, VAL='c'), Row(ID=5, VAL='updated5')]
"""

statement_params = track_data_source_statement_params(
Expand Down Expand Up @@ -486,18 +505,36 @@ def save_as_table(
f"Unsupported table type. Expected table types: {SUPPORTED_TABLE_TYPES}"
)

# overwrite_condition must be used with APPEND or OVERWRITE mode
if overwrite_condition is not None and save_mode not in (
SaveMode.APPEND,
SaveMode.OVERWRITE,
):
raise ValueError(
f"'overwrite_condition' is only supported with mode='append' or mode='overwrite'. "
f"Got mode='{save_mode.value}'."
)

overwrite_condition_expr = (
_to_col_if_sql_expr(
overwrite_condition, "DataFrameWriter.save_as_table"
)._expression
if overwrite_condition is not None
else None
)

session = self._dataframe._session
needs_table_exists_check = save_mode in [
SaveMode.APPEND,
SaveMode.TRUNCATE,
] or (save_mode == SaveMode.OVERWRITE and overwrite_condition is not None)
if (
table_exists is None
and not isinstance(session._conn, MockServerConnection)
and save_mode
in [
SaveMode.APPEND,
SaveMode.TRUNCATE,
]
and needs_table_exists_check
):
# whether the table already exists in the database
# determines the compiled SQL for APPEND and TRUNCATE mode
# determines the compiled SQL for APPEND, TRUNCATE, and OVERWRITE with overwrite_condition
# if the table does not exist, we need to create it first;
# if the table exists, we can skip the creation step and insert data directly
table_exists = session._table_exists(table_name)
Expand All @@ -518,6 +555,7 @@ def save_as_table(
copy_grants,
iceberg_config,
table_exists,
overwrite_condition_expr,
)
snowflake_plan = session._analyzer.resolve(create_table_logic_plan)
result = session._conn.execute(
Expand Down
Loading
Loading