2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@

#### New Features

- Added support for targeted delete-insert via the `override_condition` parameter in `DataFrameWriter.save_as_table`

#### Bug Fixes

#### Improvements
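The changelog entry above boils down to one new keyword argument on `DataFrameWriter.save_as_table`. A minimal usage sketch, assuming an existing table `my_table` and a DataFrame `new_df` with matching columns (names and condition are illustrative; the full doctest lives in `dataframe_writer.py` below):

# Delete rows matching the condition, then append all rows of new_df,
# atomically in one transaction.
new_df.write.mode("append").save_as_table(
    "my_table", override_condition="id = 2 or val = 'b'"
)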
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -1228,6 +1228,12 @@ def do_resolve_with_resolved_children(
child_attributes=resolved_child.attributes,
iceberg_config=iceberg_config,
table_exists=logical_plan.table_exists,
override_condition=self.analyze(
logical_plan.override_condition,
df_aliased_col_name_to_real_col_name,
)
if logical_plan.override_condition
else None,
)

if isinstance(logical_plan, Limit):
65 changes: 56 additions & 9 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -1267,6 +1267,7 @@ def save_as_table(
child_attributes: Optional[List[Attribute]],
iceberg_config: Optional[dict] = None,
table_exists: Optional[bool] = None,
override_condition: Optional[str] = None,
) -> SnowflakePlan:
"""Returns a SnowflakePlan to materialize the child plan into a table.

@@ -1417,19 +1418,65 @@ def get_create_and_insert_plan(child: SnowflakePlan, replace, error):
referenced_ctes=child.referenced_ctes,
)

def get_override_delete_insert_plan(child: SnowflakePlan):
"""Build a plan for targeted delete + insert with transaction.

Deletes rows matching ``override_condition``, then inserts all rows
from the source DataFrame. Both statements run in a single transaction
for atomicity.
"""
child = self.add_result_scan_if_not_select(child)

return SnowflakePlan(
[
*child.queries[:-1],
Query("BEGIN TRANSACTION"),
Query(
delete_statement(
table_name=full_table_name,
condition=override_condition,
source_data=None,
),
params=child.queries[-1].params,
is_ddl_on_temp_object=is_temp_table_type,
),
Query(
insert_into_statement(
table_name=full_table_name,
child=child.queries[-1].sql,
column_names=column_names,
),
params=child.queries[-1].params,
is_ddl_on_temp_object=is_temp_table_type,
),
Query("COMMIT"),
],
schema_query=None,
post_actions=child.post_actions,
expr_to_alias={},
source_plan=source_plan,
api_calls=child.api_calls,
session=self.session,
referenced_ctes=child.referenced_ctes,
)

if mode == SaveMode.APPEND:
assert table_exists is not None
if table_exists:
return self.build(
lambda x: insert_into_statement(
table_name=full_table_name,
child=x,
column_names=column_names,
),
child,
source_plan,
)
if override_condition is not None:
return get_override_delete_insert_plan(child)
else:
# Normal append without override_condition
return self.build(
lambda x: insert_into_statement(
table_name=full_table_name,
child=x,
column_names=column_names,
),
child,
source_plan,
)
else:
# Table doesn't exist; just create and insert (override_condition is a no-op)
return get_create_and_insert_plan(child, replace=False, error=False)

elif mode == SaveMode.TRUNCATE:
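A rough sketch of the query sequence `get_override_delete_insert_plan` assembles, assuming a target table `my_table` with columns `(id, val)` and the condition `id = 1`; the literal SQL is produced by `delete_statement` and `insert_into_statement`, so the exact text may differ:

# Sketch only (assumed names); "<source query>" stands in for the
# compiled SQL of the child plan's final query.
queries = [
    "BEGIN TRANSACTION",
    "DELETE FROM my_table WHERE id = 1",
    "INSERT INTO my_table (id, val) SELECT id, val FROM (<source query>)",
    "COMMIT",
]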
@@ -243,6 +243,7 @@ def __init__(
copy_grants: bool = False,
iceberg_config: Optional[dict] = None,
table_exists: Optional[bool] = None,
override_condition: Optional[Expression] = None,
) -> None:
super().__init__()

@@ -267,6 +268,7 @@ def __init__(
# whether the table already exists in the database
# determines the compiled SQL for APPEND and TRUNCATE mode
self.table_exists = table_exists
self.override_condition = override_condition

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
37 changes: 36 additions & 1 deletion src/snowflake/snowpark/dataframe_writer.py
@@ -48,7 +48,7 @@
warning,
)
from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
from snowflake.snowpark.column import Column, _to_col_if_str
from snowflake.snowpark.column import Column, _to_col_if_str, _to_col_if_sql_expr
from snowflake.snowpark.exceptions import SnowparkClientException
from snowflake.snowpark.functions import sql_expr
from snowflake.snowpark.mock._connection import MockServerConnection
@@ -256,6 +256,7 @@ def save_as_table(
Dict[str, Union[str, Iterable[ColumnOrSqlExpr]]]
] = None,
table_exists: Optional[bool] = None,
override_condition: Optional[ColumnOrSqlExpr] = None,
_emit_ast: bool = True,
**kwargs: Optional[Dict[str, Any]],
) -> Optional[AsyncJob]:
@@ -331,6 +332,9 @@
table_exists: Optional parameter to specify if the table is known to exist or not.
Set to ``True`` if table exists, ``False`` if it doesn't, or ``None`` (default) for automatic detection.
Primarily useful for "append" and "truncate" modes to avoid running a query for automatic detection.
override_condition: Specifies the condition for an atomic targeted delete-insert.
Can only be used when ``mode`` is "append" and the table exists. Rows matching the
condition are deleted from the target table, then all rows from the DataFrame are inserted.


Example 1::
@@ -364,6 +368,21 @@
... "partition_by": ["a", bucket(3, col("b"))],
... }
>>> df.write.mode("overwrite").save_as_table("my_table", iceberg_config=iceberg_config) # doctest: +SKIP

Example 3::

Using override_condition for targeted delete and insert:

>>> from snowflake.snowpark.functions import col
>>> df = session.create_dataframe([[1, "a"], [2, "b"], [3, "c"]], schema=["id", "val"])
>>> df.write.mode("overwrite").save_as_table("my_table", table_type="temporary")
>>> session.table("my_table").order_by("id").collect()
[Row(ID=1, VAL='a'), Row(ID=2, VAL='b'), Row(ID=3, VAL='c')]

>>> new_df = session.create_dataframe([[2, "updated2"], [5, "updated5"]], schema=["id", "val"])
>>> new_df.write.mode("append").save_as_table("my_table", override_condition="id = 1 or val = 'b'")
>>> session.table("my_table").order_by("id").collect()
[Row(ID=2, VAL='updated2'), Row(ID=3, VAL='c'), Row(ID=5, VAL='updated5')]
"""

statement_params = track_data_source_statement_params(
@@ -486,6 +505,21 @@ def save_as_table(
f"Unsupported table type. Expected table types: {SUPPORTED_TABLE_TYPES}"
)

# override_condition is only supported with APPEND mode
if override_condition is not None and save_mode != SaveMode.APPEND:
raise ValueError(
f"'override_condition' is only supported with mode='append'. "
f"Got mode='{save_mode.value}'."
)

override_condition_expr = (
_to_col_if_sql_expr(
override_condition, "DataFrameWriter.save_as_table"
)._expression
if override_condition is not None
else None
)

session = self._dataframe._session
if (
table_exists is None
Expand Down Expand Up @@ -518,6 +552,7 @@ def save_as_table(
copy_grants,
iceberg_config,
table_exists,
override_condition_expr,
)
snowflake_plan = session._analyzer.resolve(create_table_logic_plan)
result = session._conn.execute(
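A minimal sketch of the writer-level behavior, assuming a session and a DataFrame `df` whose schema includes columns `id` and `val` (all names are illustrative):

from snowflake.snowpark.functions import col

# Both forms are accepted because override_condition is a ColumnOrSqlExpr.
df.write.mode("append").save_as_table("my_table", override_condition="id = 1")
df.write.mode("append").save_as_table("my_table", override_condition=col("id") == 1)

# Any other mode fails fast with:
# ValueError: 'override_condition' is only supported with mode='append'. Got mode='overwrite'.
df.write.mode("overwrite").save_as_table("my_table", override_condition="id = 1")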
148 changes: 148 additions & 0 deletions tests/integ/test_dataframe.py
@@ -4566,6 +4566,154 @@ def test_write_table_with_clustering_keys_and_comment(
Utils.drop_table(session, table_name3)


@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="override_condition is a SQL feature",
run=False,
)
def test_write_table_with_override_condition(session):
"""Test override_condition parameter for targeted delete + insert."""
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
try:
# Setup and verify initial data
initial_df = session.create_dataframe(
[[1, "a"], [2, "b"], [3, "c"]],
schema=StructType(
[StructField("id", IntegerType()), StructField("val", StringType())]
),
)
initial_df.write.mode("overwrite").save_as_table(table_name)
result = session.table(table_name).order_by("id").collect()
assert result == [
Row(ID=1, VAL="a"),
Row(ID=2, VAL="b"),
Row(ID=3, VAL="c"),
]

# Test 1: override_condition with SQL string expr
new_df1 = session.create_dataframe(
[[2, "updated2"], [5, "new5"]],
schema=StructType(
[StructField("id", IntegerType()), StructField("val", StringType())]
),
)
new_df1.write.mode("append").save_as_table(
table_name, override_condition="id = 1 or val = 'b'"
)
result = session.table(table_name).order_by("id").collect()
# id=1 and id=2 (val='b') deleted, new rows inserted
assert result == [
Row(ID=2, VAL="updated2"),
Row(ID=3, VAL="c"),
Row(ID=5, VAL="new5"),
]

# Test 2: override_condition with Column expr
new_df2 = session.create_dataframe(
[[2, "replaced2"], [4, "new4"]],
schema=StructType(
[StructField("id", IntegerType()), StructField("val", StringType())]
),
)
new_df2.write.mode("append").save_as_table(
table_name, override_condition=col("id") == 2
)
result = session.table(table_name).order_by("id").collect()
# id=2 deleted, new rows inserted
assert result == [
Row(ID=2, VAL="replaced2"),
Row(ID=3, VAL="c"),
Row(ID=4, VAL="new4"),
Row(ID=5, VAL="new5"),
]

# Test 3: override_condition with multiple Column expr
new_df3 = session.create_dataframe(
[[6, "new6"]],
schema=StructType(
[StructField("id", IntegerType()), StructField("val", StringType())]
),
)
new_df3.write.mode("append").save_as_table(
table_name, override_condition=(col("id") > 4) | (col("val") == "c")
)
result = session.table(table_name).order_by("id").collect()
# id=3 (val='c') and id=5 (id > 4) deleted, id=4 remains (4 is not > 4), new row inserted
assert result == [
Row(ID=2, VAL="replaced2"),
Row(ID=4, VAL="new4"),
Row(ID=6, VAL="new6"),
]

# Test 4: override_condition that matches all rows
new_df4 = session.create_dataframe(
[[10, "new"]],
schema=StructType(
[StructField("id", IntegerType()), StructField("val", StringType())]
),
)
new_df4.write.mode("append").save_as_table(
table_name, override_condition="id > 0"
)
result = session.table(table_name).collect()
assert result == [Row(ID=10, VAL="new")]

# Test 5: override_condition that matches no rows
new_df5 = session.create_dataframe(
[[20, "another"]],
schema=StructType(
[StructField("id", IntegerType()), StructField("val", StringType())]
),
)
new_df5.write.mode("append").save_as_table(
table_name, override_condition="id = 999"
)
result = session.table(table_name).order_by("id").collect()
assert result == [
Row(ID=10, VAL="new"),
Row(ID=20, VAL="another"),
]

finally:
Utils.drop_table(session, table_name)


@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="override_condition is a SQL feature",
run=False,
)
@pytest.mark.parametrize(
"invalid_mode", ["overwrite", "truncate", "errorifexists", "ignore"]
)
def test_write_table_with_override_condition_edge_cases(session, invalid_mode):
"""Test override_condition edge cases: table not exists, and invalid modes."""
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
try:
# Edge case 1: Table doesn't exist - override_condition is no-op
df = session.create_dataframe(
[[1, "a"], [2, "b"]],
schema=StructType(
[StructField("id", IntegerType()), StructField("val", StringType())]
),
)
df.write.mode("append").save_as_table(table_name, override_condition="id = 999")
result = session.table(table_name).order_by("id").collect()
assert result == [Row(ID=1, VAL="a"), Row(ID=2, VAL="b")]

# Edge case 2: Invalid mode raises ValueError
with pytest.raises(
ValueError,
match="'override_condition' is only supported with mode='append'",
):
df.write.mode(invalid_mode).save_as_table(
table_name, override_condition="id = 1"
)

finally:
Utils.drop_table(session, table_name)


@pytest.mark.xfail(
"config.getoption('local_testing_mode', default=False)",
reason="Clustering is a SQL feature",