Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#### New Features

- Added support for targeted delete-insert via the `overwrite_condition` parameter in `DataFrameWriter.save_as_table`

#### Bug Fixes

#### Improvements
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,12 @@ def do_resolve_with_resolved_children(
child_attributes=resolved_child.attributes,
iceberg_config=iceberg_config,
table_exists=logical_plan.table_exists,
overwrite_condition=self.analyze(
logical_plan.overwrite_condition,
df_aliased_col_name_to_real_col_name,
)
if logical_plan.overwrite_condition
else None,
)

if isinstance(logical_plan, Limit):
Expand Down
49 changes: 48 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,7 @@ def save_as_table(
child_attributes: Optional[List[Attribute]],
iceberg_config: Optional[dict] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[str] = None,
) -> SnowflakePlan:
"""Returns a SnowflakePlan to materialize the child plan into a table.

Expand Down Expand Up @@ -1417,6 +1418,47 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
referenced_ctes=child.referenced_ctes,
)

def get_overwrite_delete_insert_plan(child: SnowflakePlan) -> SnowflakePlan:
    """Build a plan for a targeted delete + insert wrapped in a transaction.

    Used for ``mode="overwrite"`` with an ``overwrite_condition``: rows in the
    target table matching the condition are deleted, then all rows produced by
    the source DataFrame are inserted. The DELETE and INSERT are wrapped in an
    explicit BEGIN TRANSACTION / COMMIT pair so the two steps are atomic and
    rows NOT matching the condition are preserved.

    Args:
        child: The resolved plan for the source DataFrame whose rows will be
            inserted. Its final query supplies the INSERT source SQL.

    Returns:
        A SnowflakePlan whose query list is: the child's setup queries,
        BEGIN TRANSACTION, the conditional DELETE, the INSERT, and COMMIT.
    """
    # Ensure the child's last query is a SELECT (wrap in RESULT_SCAN otherwise)
    # so its SQL can be embedded as the INSERT source.
    child = self.add_result_scan_if_not_select(child)

    return SnowflakePlan(
        [
            # Any preparatory queries from the child plan run before the
            # transaction is opened.
            *child.queries[:-1],
            Query("BEGIN TRANSACTION"),
            Query(
                # Delete only the rows matching the caller-supplied condition.
                delete_statement(
                    table_name=full_table_name,
                    condition=overwrite_condition,
                    # source_data=None: presumably a plain conditional DELETE
                    # with no joined source table — TODO confirm against
                    # delete_statement's signature.
                    source_data=None,
                ),
                # Reuse the bound parameters of the child's final query; the
                # condition itself may reference them.
                params=child.queries[-1].params,
                is_ddl_on_temp_object=is_temp_table_type,
            ),
            Query(
                # Insert every row produced by the source DataFrame, using the
                # child's final query text as the INSERT ... SELECT source.
                insert_into_statement(
                    table_name=full_table_name,
                    child=child.queries[-1].sql,
                    column_names=column_names,
                ),
                params=child.queries[-1].params,
                is_ddl_on_temp_object=is_temp_table_type,
            ),
            # NOTE(review): no explicit ROLLBACK query is emitted here — if a
            # query between BEGIN and COMMIT fails, presumably the session /
            # connector aborts the open transaction; confirm.
            Query("COMMIT"),
        ],
        # schema_query=None: presumably the output schema is not needed for a
        # write-only plan — TODO confirm.
        schema_query=None,
        post_actions=child.post_actions,
        expr_to_alias={},
        source_plan=source_plan,
        api_calls=child.api_calls,
        session=self.session,
        referenced_ctes=child.referenced_ctes,
    )

if mode == SaveMode.APPEND:
assert table_exists is not None
if table_exists:
Expand Down Expand Up @@ -1446,7 +1488,12 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
return get_create_table_as_select_plan(child, replace=True, error=True)

elif mode == SaveMode.OVERWRITE:
return get_create_table_as_select_plan(child, replace=True, error=True)
if overwrite_condition is not None and table_exists:
# Selective overwrite: delete matching rows, then insert
return get_overwrite_delete_insert_plan(child)
else:
# Default overwrite: drop and recreate table
return get_create_table_as_select_plan(child, replace=True, error=True)

elif mode == SaveMode.IGNORE:
return get_create_table_as_select_plan(child, replace=False, error=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __init__(
copy_grants: bool = False,
iceberg_config: Optional[dict] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[Expression] = None,
) -> None:
super().__init__()

Expand All @@ -267,6 +268,7 @@ def __init__(
# whether the table already exists in the database
# determines the compiled SQL for APPEND and TRUNCATE mode
self.table_exists = table_exists
self.overwrite_condition = overwrite_condition

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
Expand Down
32 changes: 17 additions & 15 deletions src/snowflake/snowpark/_internal/proto/ast.proto
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// N.B. This file is generated by `//ir-dsl-c`. DO NOT EDIT!
// Generated from `{[email protected]:snowflakedb/snowflake.git}/Snowpark/ast`.
// Generated from `{[email protected]:snowflake-eng/snowflake.git}/Snowpark/ast`.

syntax = "proto3";

Expand Down Expand Up @@ -987,7 +987,7 @@ message DataframeCollect {
repeated Tuple_String_String statement_params = 7;
}

// dataframe-io.ir:165
// dataframe-io.ir:167
message DataframeCopyIntoTable {
repeated Tuple_String_Expr copy_options = 1;
Expr df = 2;
Expand All @@ -1011,7 +1011,7 @@ message DataframeCount {
repeated Tuple_String_String statement_params = 4;
}

// dataframe-io.ir:148
// dataframe-io.ir:150
message DataframeCreateOrReplaceDynamicTable {
repeated Expr clustering_keys = 1;
google.protobuf.StringValue comment = 2;
Expand All @@ -1030,7 +1030,7 @@ message DataframeCreateOrReplaceDynamicTable {
string warehouse = 15;
}

// dataframe-io.ir:139
// dataframe-io.ir:141
message DataframeCreateOrReplaceView {
google.protobuf.StringValue comment = 1;
bool copy_grants = 2;
Expand Down Expand Up @@ -2685,7 +2685,7 @@ message WindowSpecRowsBetween {
WindowSpecExpr wnd = 4;
}

// dataframe-io.ir:116
// dataframe-io.ir:118
message WriteCopyIntoLocation {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
Expand All @@ -2704,7 +2704,7 @@ message WriteCopyIntoLocation {
Expr writer = 15;
}

// dataframe-io.ir:123
// dataframe-io.ir:125
message WriteCsv {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
Expand All @@ -2731,15 +2731,15 @@ message WriteFile {
}
}

// dataframe-io.ir:129
// dataframe-io.ir:131
message WriteInsertInto {
bool overwrite = 1;
SrcPosition src = 2;
NameRef table_name = 3;
Expr writer = 4;
}

// dataframe-io.ir:125
// dataframe-io.ir:127
message WriteJson {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
Expand Down Expand Up @@ -2773,7 +2773,7 @@ message WritePandas {
string table_type = 13;
}

// dataframe-io.ir:127
// dataframe-io.ir:129
message WriteParquet {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
Expand All @@ -2790,7 +2790,7 @@ message WriteParquet {
Expr writer = 13;
}

// dataframe-io.ir:121
// dataframe-io.ir:123
message WriteSave {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
Expand Down Expand Up @@ -2821,9 +2821,11 @@ message WriteTable {
repeated Tuple_String_Expr iceberg_config = 10;
google.protobuf.Int64Value max_data_extension_time = 11;
SaveMode mode = 12;
SrcPosition src = 13;
repeated Tuple_String_String statement_params = 14;
NameRef table_name = 15;
string table_type = 16;
Expr writer = 17;
Expr overwrite_condition = 13;
SrcPosition src = 14;
repeated Tuple_String_String statement_params = 15;
google.protobuf.BoolValue table_exists = 16;
NameRef table_name = 17;
string table_type = 18;
Expr writer = 19;
}
65 changes: 56 additions & 9 deletions src/snowflake/snowpark/dataframe_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
warning,
)
from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
from snowflake.snowpark.column import Column, _to_col_if_str
from snowflake.snowpark.column import Column, _to_col_if_str, _to_col_if_sql_expr
from snowflake.snowpark.exceptions import SnowparkClientException
from snowflake.snowpark.functions import sql_expr
from snowflake.snowpark.mock._connection import MockServerConnection
Expand Down Expand Up @@ -256,6 +256,7 @@ def save_as_table(
Dict[str, Union[str, Iterable[ColumnOrSqlExpr]]]
] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[ColumnOrSqlExpr] = None,
_emit_ast: bool = True,
**kwargs: Optional[Dict[str, Any]],
) -> Optional[AsyncJob]:
Expand All @@ -270,7 +271,9 @@ def save_as_table(

"append": Append data of this DataFrame to the existing table. Creates a table if it does not exist.

"overwrite": Overwrite the existing table by dropping old table.
"overwrite": Overwrite the existing table. By default, drops and recreates the table.
When ``overwrite_condition`` is specified, performs selective overwrite: deletes only
rows matching the condition, then inserts new data.

"truncate": Overwrite the existing table by truncating old table.

Expand Down Expand Up @@ -330,7 +333,12 @@ def save_as_table(
* iceberg_version: Overrides the version of iceberg to use. Defaults to 2 when unset.
table_exists: Optional parameter to specify if the table is known to exist or not.
Set to ``True`` if table exists, ``False`` if it doesn't, or ``None`` (default) for automatic detection.
Primarily useful for "append" and "truncate" modes to avoid running query for automatic detection.
Primarily useful for "append", "truncate", and "overwrite" with overwrite_condition modes to avoid running query for automatic detection.
overwrite_condition: Specifies the overwrite condition to perform atomic targeted delete-insert.
Can only be used when ``mode`` is "overwrite". When provided and the table exists, rows matching
the condition are atomically deleted and all rows from the DataFrame are inserted, preserving
non-matching rows. When not provided, the default "overwrite" behavior applies (drop and recreate table).
If the table does not exist, ``overwrite_condition`` is ignored and the table is created normally.


Example 1::
Expand Down Expand Up @@ -364,6 +372,21 @@ def save_as_table(
... "partition_by": ["a", bucket(3, col("b"))],
... }
>>> df.write.mode("overwrite").save_as_table("my_table", iceberg_config=iceberg_config) # doctest: +SKIP

Example 3::

Using overwrite_condition for targeted delete and insert:

>>> from snowflake.snowpark.functions import col
>>> df = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"]], schema=["id", "val"])
>>> df.write.mode("overwrite").save_as_table("my_table", table_type="temporary")
>>> session.table("my_table").order_by("id").collect()
[Row(ID=1, VAL='a'), Row(ID=2, VAL='b'), Row(ID=3, VAL='c')]

>>> new_df = session.create_dataframe([[2, "updated2"], [5, "updated5"]], schema=["id", "val"])
>>> new_df.write.mode("overwrite").save_as_table("my_table", overwrite_condition="id = 1 or val = 'b'")
>>> session.table("my_table").order_by("id").collect()
[Row(ID=2, VAL='updated2'), Row(ID=3, VAL='c'), Row(ID=5, VAL='updated5')]
"""

statement_params = track_data_source_statement_params(
Expand Down Expand Up @@ -392,6 +415,8 @@ def save_as_table(
# change_tracking: Optional[bool] = None,
# copy_grants: bool = False,
# iceberg_config: Optional[dict] = None,
# table_exists: Optional[bool] = None,
# overwrite_condition: Optional[ColumnOrSqlExpr] = None,

build_table_name(expr.table_name, table_name)

Expand Down Expand Up @@ -433,6 +458,12 @@ def save_as_table(
t = expr.iceberg_config.add()
t._1 = k
build_expr_from_python_val(t._2, v)
if table_exists is not None:
expr.table_exists.value = table_exists
if overwrite_condition is not None:
build_expr_from_snowpark_column_or_sql_str(
expr.overwrite_condition, overwrite_condition
)

self._dataframe._session._ast_batch.eval(stmt)

Expand Down Expand Up @@ -486,18 +517,33 @@ def save_as_table(
f"Unsupported table type. Expected table types: {SUPPORTED_TABLE_TYPES}"
)

# overwrite_condition must be used with OVERWRITE mode only
if overwrite_condition is not None and save_mode != SaveMode.OVERWRITE:
raise ValueError(
f"'overwrite_condition' is only supported with mode='overwrite'. "
f"Got mode='{save_mode.value}'."
)

overwrite_condition_expr = (
_to_col_if_sql_expr(
overwrite_condition, "DataFrameWriter.save_as_table"
)._expression
if overwrite_condition is not None
else None
)

session = self._dataframe._session
needs_table_exists_check = save_mode in [
SaveMode.APPEND,
SaveMode.TRUNCATE,
] or (save_mode == SaveMode.OVERWRITE and overwrite_condition is not None)
if (
table_exists is None
and not isinstance(session._conn, MockServerConnection)
and save_mode
in [
SaveMode.APPEND,
SaveMode.TRUNCATE,
]
and needs_table_exists_check
):
# whether the table already exists in the database
# determines the compiled SQL for APPEND and TRUNCATE mode
# determines the compiled SQL for APPEND, TRUNCATE, and OVERWRITE with overwrite_condition
# if the table does not exist, we need to create it first;
# if the table exists, we can skip the creation step and insert data directly
table_exists = session._table_exists(table_name)
Expand All @@ -518,6 +564,7 @@ def save_as_table(
copy_grants,
iceberg_config,
table_exists,
overwrite_condition_expr,
)
snowflake_plan = session._analyzer.resolve(create_table_logic_plan)
result = session._conn.execute(
Expand Down
Loading