2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@

#### New Features

- Added support for targeted delete-insert via the `overwrite_condition` parameter in `DataFrameWriter.save_as_table`

#### Bug Fixes

#### Improvements
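For context, a minimal sketch of the behavior this changelog entry describes (the table name, data, and an active Snowpark `session` are assumptions for illustration):

```python
from snowflake.snowpark.functions import col

df = session.create_dataframe([[1, "a"], [2, "b"]], schema=["id", "val"])

# Default overwrite: drop and recreate the whole table.
df.write.mode("overwrite").save_as_table("my_table")

# Targeted delete-insert: delete only the rows matching the condition,
# then insert all rows from df, atomically.
df.write.mode("overwrite").save_as_table(
    "my_table", overwrite_condition=col("id") == 1
)
```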
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -1228,6 +1228,12 @@ def do_resolve_with_resolved_children(
child_attributes=resolved_child.attributes,
iceberg_config=iceberg_config,
table_exists=logical_plan.table_exists,
overwrite_condition=self.analyze(
logical_plan.overwrite_condition,
df_aliased_col_name_to_real_col_name,
)
if logical_plan.overwrite_condition
else None,
)

if isinstance(logical_plan, Limit):
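Note the type hand-off across layers: the public API accepts a `ColumnOrSqlExpr`, the logical plan stores an `Expression`, and the hunk above renders that expression to SQL text for the plan builder. A hedged sketch of the flow (the rendered SQL string is hypothetical; its exact form depends on the analyzer's quoting rules):

```python
from snowflake.snowpark.functions import col

# dataframe_writer.py: the user passes a Column or a SQL string.
user_condition = col("id") == 1            # ColumnOrSqlExpr
expression = user_condition._expression    # Expression stored on the logical plan

# analyzer.py (this hunk): the Expression is rendered to SQL text,
# e.g. something like '("ID" = 1)':
#   condition_sql = self.analyze(expression, df_aliased_col_name_to_real_col_name)

# snowflake_plan.py: the plan builder receives the plain SQL string and
# splices it into the DELETE statement.
```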
49 changes: 48 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -1267,6 +1267,7 @@ def save_as_table(
child_attributes: Optional[List[Attribute]],
iceberg_config: Optional[dict] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[str] = None,
) -> SnowflakePlan:
"""Returns a SnowflakePlan to materialize the child plan into a table.

@@ -1417,6 +1418,47 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
referenced_ctes=child.referenced_ctes,
)

def get_overwrite_delete_insert_plan(child: SnowflakePlan):
"""Build a plan for targeted delete + insert with transaction.

Deletes rows matching the overwrite_condition condition, then inserts
all rows from the source DataFrame. Wrapped in a transaction for atomicity.
"""
child = self.add_result_scan_if_not_select(child)

return SnowflakePlan(
[
*child.queries[:-1],
Query("BEGIN TRANSACTION"),
Query(
delete_statement(
table_name=full_table_name,
condition=overwrite_condition,
source_data=None,
),
params=child.queries[-1].params,
is_ddl_on_temp_object=is_temp_table_type,
),
Query(
insert_into_statement(
table_name=full_table_name,
child=child.queries[-1].sql,
column_names=column_names,
),
params=child.queries[-1].params,
is_ddl_on_temp_object=is_temp_table_type,
),
Query("COMMIT"),
],
schema_query=None,
post_actions=child.post_actions,
expr_to_alias={},
source_plan=source_plan,
api_calls=child.api_calls,
session=self.session,
referenced_ctes=child.referenced_ctes,
)

if mode == SaveMode.APPEND:
assert table_exists is not None
if table_exists:
@@ -1446,7 +1488,12 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
return get_create_table_as_select_plan(child, replace=True, error=True)

elif mode == SaveMode.OVERWRITE:
return get_create_table_as_select_plan(child, replace=True, error=True)
if overwrite_condition is not None and table_exists:
# Selective overwrite: delete matching rows, then insert
return get_overwrite_delete_insert_plan(child)
else:
# Default overwrite: drop and recreate table
return get_create_table_as_select_plan(child, replace=True, error=True)

elif mode == SaveMode.IGNORE:
return get_create_table_as_select_plan(child, replace=False, error=False)
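Taken together, the `SaveMode.OVERWRITE` branch compiles a selective overwrite into a transactional query sequence. A rough sketch of what `get_overwrite_delete_insert_plan` emits (table name, condition, and exact SQL text are illustrative; the real statements come from `delete_statement` and `insert_into_statement`):

```python
# Illustrative query sequence for a hypothetical call:
#   new_df.write.mode("overwrite").save_as_table(
#       "my_table", overwrite_condition="id = 1 or val = 'b'"
#   )
queries = [
    "BEGIN TRANSACTION",
    "DELETE FROM my_table WHERE id = 1 or val = 'b'",       # targeted delete
    "INSERT INTO my_table SELECT * FROM (<new_df query>)",  # insert all source rows
    "COMMIT",                                               # atomic commit
]
```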
@@ -243,6 +243,7 @@ def __init__(
copy_grants: bool = False,
iceberg_config: Optional[dict] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[Expression] = None,
) -> None:
super().__init__()

@@ -267,6 +268,7 @@ def __init__(
# whether the table already exists in the database
# determines the compiled SQL for APPEND and TRUNCATE mode
self.table_exists = table_exists
self.overwrite_condition = overwrite_condition

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
32 changes: 17 additions & 15 deletions src/snowflake/snowpark/_internal/proto/ast.proto
@@ -1,5 +1,5 @@
// N.B. This file is generated by `//ir-dsl-c`. DO NOT EDIT!
// Generated from `{[email protected]:snowflakedb/snowflake.git}/Snowpark/ast`.
// Generated from `{[email protected]:snowflake-eng/snowflake.git}/Snowpark/ast`.

syntax = "proto3";

@@ -987,7 +987,7 @@ message DataframeCollect {
repeated Tuple_String_String statement_params = 7;
}

// dataframe-io.ir:165
// dataframe-io.ir:167
message DataframeCopyIntoTable {
repeated Tuple_String_Expr copy_options = 1;
Expr df = 2;
@@ -1011,7 +1011,7 @@ message DataframeCount {
repeated Tuple_String_String statement_params = 4;
}

// dataframe-io.ir:148
// dataframe-io.ir:150
message DataframeCreateOrReplaceDynamicTable {
repeated Expr clustering_keys = 1;
google.protobuf.StringValue comment = 2;
@@ -1030,7 +1030,7 @@ message DataframeCreateOrReplaceDynamicTable {
string warehouse = 15;
}

// dataframe-io.ir:139
// dataframe-io.ir:141
message DataframeCreateOrReplaceView {
google.protobuf.StringValue comment = 1;
bool copy_grants = 2;
@@ -2685,7 +2685,7 @@ message WindowSpecRowsBetween {
WindowSpecExpr wnd = 4;
}

// dataframe-io.ir:116
// dataframe-io.ir:118
message WriteCopyIntoLocation {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
@@ -2704,7 +2704,7 @@ message WriteCopyIntoLocation {
Expr writer = 15;
}

// dataframe-io.ir:123
// dataframe-io.ir:125
message WriteCsv {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
@@ -2731,15 +2731,15 @@ message WriteFile {
}
}

// dataframe-io.ir:129
// dataframe-io.ir:131
message WriteInsertInto {
bool overwrite = 1;
SrcPosition src = 2;
NameRef table_name = 3;
Expr writer = 4;
}

// dataframe-io.ir:125
// dataframe-io.ir:127
message WriteJson {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
@@ -2773,7 +2773,7 @@ message WritePandas {
string table_type = 13;
}

// dataframe-io.ir:127
// dataframe-io.ir:129
message WriteParquet {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
@@ -2790,7 +2790,7 @@ message WriteParquet {
Expr writer = 13;
}

// dataframe-io.ir:121
// dataframe-io.ir:123
message WriteSave {
bool block = 1;
repeated Tuple_String_Expr copy_options = 2;
@@ -2821,9 +2821,11 @@ message WriteTable {
repeated Tuple_String_Expr iceberg_config = 10;
google.protobuf.Int64Value max_data_extension_time = 11;
SaveMode mode = 12;
SrcPosition src = 13;
repeated Tuple_String_String statement_params = 14;
NameRef table_name = 15;
string table_type = 16;
Expr writer = 17;
Expr overwrite_condition = 13;
SrcPosition src = 14;
repeated Tuple_String_String statement_params = 15;
google.protobuf.BoolValue table_exists = 16;
NameRef table_name = 17;
string table_type = 18;
Expr writer = 19;
}
59 changes: 51 additions & 8 deletions src/snowflake/snowpark/dataframe_writer.py
@@ -48,7 +48,7 @@
warning,
)
from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
from snowflake.snowpark.column import Column, _to_col_if_str
from snowflake.snowpark.column import Column, _to_col_if_str, _to_col_if_sql_expr
from snowflake.snowpark.exceptions import SnowparkClientException
from snowflake.snowpark.functions import sql_expr
from snowflake.snowpark.mock._connection import MockServerConnection
@@ -256,6 +256,7 @@ def save_as_table(
Dict[str, Union[str, Iterable[ColumnOrSqlExpr]]]
] = None,
table_exists: Optional[bool] = None,
overwrite_condition: Optional[ColumnOrSqlExpr] = None,
_emit_ast: bool = True,
**kwargs: Optional[Dict[str, Any]],
) -> Optional[AsyncJob]:
@@ -330,7 +331,10 @@
* iceberg_version: Overrides the version of iceberg to use. Defaults to 2 when unset.
table_exists: Optional parameter to specify if the table is known to exist or not.
Set to ``True`` if table exists, ``False`` if it doesn't, or ``None`` (default) for automatic detection.
Primarily useful for "append" and "truncate" modes to avoid running query for automatic detection.
Primarily useful for "append", "truncate", and "overwrite" with overwrite_condition modes to avoid running query for automatic detection.
overwrite_condition: Specifies the condition used to perform an atomic, targeted delete-insert.
Can only be used when ``mode`` is "overwrite" and the table exists. Rows matching the
condition are deleted from the target table, then all rows from the DataFrame are inserted.


Example 1::
@@ -364,6 +368,21 @@
... "partition_by": ["a", bucket(3, col("b"))],
... }
>>> df.write.mode("overwrite").save_as_table("my_table", iceberg_config=iceberg_config) # doctest: +SKIP

Example 3::

Using overwrite_condition for targeted delete and insert:

>>> from snowflake.snowpark.functions import col
>>> df = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"]], schema=["id", "val"])
>>> df.write.mode("overwrite").save_as_table("my_table", table_type="temporary")
>>> session.table("my_table").order_by("id").collect()
[Row(ID=1, VAL='a'), Row(ID=2, VAL='b'), Row(ID=3, VAL='c')]

>>> new_df = session.create_dataframe([[2, "updated2"], [5, "updated5"]], schema=["id", "val"])
>>> new_df.write.mode("overwrite").save_as_table("my_table", overwrite_condition="id = 1 or val = 'b'")
>>> session.table("my_table").order_by("id").collect()
[Row(ID=2, VAL='updated2'), Row(ID=3, VAL='c'), Row(ID=5, VAL='updated5')]
"""

statement_params = track_data_source_statement_params(
@@ -392,6 +411,8 @@ def save_as_table(
# change_tracking: Optional[bool] = None,
# copy_grants: bool = False,
# iceberg_config: Optional[dict] = None,
# table_exists: Optional[bool] = None,
# overwrite_condition: Optional[ColumnOrSqlExpr] = None,

build_table_name(expr.table_name, table_name)

@@ -433,6 +454,12 @@ def save_as_table(
t = expr.iceberg_config.add()
t._1 = k
build_expr_from_python_val(t._2, v)
if table_exists is not None:
expr.table_exists.value = table_exists
if overwrite_condition is not None:
build_expr_from_snowpark_column_or_sql_str(
expr.overwrite_condition, overwrite_condition
)

self._dataframe._session._ast_batch.eval(stmt)

@@ -486,18 +513,33 @@ def save_as_table(
f"Unsupported table type. Expected table types: {SUPPORTED_TABLE_TYPES}"
)

# overwrite_condition must be used with OVERWRITE mode only
if overwrite_condition is not None and save_mode != SaveMode.OVERWRITE:
raise ValueError(
f"'overwrite_condition' is only supported with mode='overwrite'. "
f"Got mode='{save_mode.value}'."
)

overwrite_condition_expr = (
_to_col_if_sql_expr(
overwrite_condition, "DataFrameWriter.save_as_table"
)._expression
if overwrite_condition is not None
else None
)

session = self._dataframe._session
needs_table_exists_check = save_mode in [
SaveMode.APPEND,
SaveMode.TRUNCATE,
] or (save_mode == SaveMode.OVERWRITE and overwrite_condition is not None)
if (
table_exists is None
and not isinstance(session._conn, MockServerConnection)
and save_mode
in [
SaveMode.APPEND,
SaveMode.TRUNCATE,
]
and needs_table_exists_check
):
# whether the table already exists in the database
# determines the compiled SQL for APPEND and TRUNCATE mode
# determines the compiled SQL for APPEND, TRUNCATE, and OVERWRITE with overwrite_condition
# if the table does not exist, we need to create it first;
# if the table exists, we can skip the creation step and insert data directly
table_exists = session._table_exists(table_name)
@@ -518,6 +560,7 @@
copy_grants,
iceberg_config,
table_exists,
overwrite_condition_expr,
)
snowflake_plan = session._analyzer.resolve(create_table_logic_plan)
result = session._conn.execute(
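Finally, a brief sketch of the new guard rails in `save_as_table` (reuses the `my_table` setup from Example 3 in the docstring; an active `session` is assumed):

```python
from snowflake.snowpark.functions import col

new_df = session.create_dataframe([[2, "updated2"]], schema=["id", "val"])

# overwrite_condition is rejected outside mode="overwrite".
try:
    new_df.write.mode("append").save_as_table(
        "my_table", overwrite_condition="id = 2"
    )
except ValueError as e:
    print(e)  # 'overwrite_condition' is only supported with mode='overwrite'. Got mode='append'.

# Column expressions work as well as SQL strings, and table_exists=True
# skips the automatic existence check (one fewer metadata query).
new_df.write.mode("overwrite").save_as_table(
    "my_table",
    overwrite_condition=col("id") == 2,
    table_exists=True,
)
```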