Skip to content

Commit 1cbebc1

Browse files
authored
SNOW-1794373: Support DataFrameWriter.insertInto/insert_into (#2835)
1 parent 9bcbd59 commit 1cbebc1

File tree

5 files changed

+175
-35
lines changed

5 files changed

+175
-35
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
- Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`.
4646
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.
47+
- Added support for `DataFrameWriter.insert_into/insertInto`. This method also supports local testing mode.
4748

4849
#### Improvements
4950

docs/source/snowpark/io.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ Input/Output
4444
DataFrameWriter.save
4545
DataFrameWriter.saveAsTable
4646
DataFrameWriter.save_as_table
47+
DataFrameWriter.insertInto
48+
DataFrameWriter.insert_into
4749
FileOperation.get
4850
FileOperation.get_stream
4951
FileOperation.put

src/snowflake/snowpark/dataframe_writer.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
4646
from snowflake.snowpark.column import Column, _to_col_if_str
47+
from snowflake.snowpark.exceptions import SnowparkClientException
4748
from snowflake.snowpark.functions import sql_expr
4849
from snowflake.snowpark.mock._connection import MockServerConnection
4950
from snowflake.snowpark.row import Row
@@ -933,4 +934,46 @@ def parquet(
933934
**copy_options,
934935
)
935936

937+
@publicapi
def insert_into(
    self, table_name: Union[str, Iterable[str]], overwrite: bool = False
) -> None:
    """
    Inserts the content of the DataFrame into the specified table.

    It requires that the schema of the DataFrame is the same as the schema of the table.

    Args:
        table_name: A string or list of strings representing table name.
            If input is a string, it represents the table name; if input is of type iterable of strings,
            it represents the fully-qualified object identifier (database name, schema name, and table name).
        overwrite: If True, the content of table will be overwritten.
            If False, the data will be appended to the table. Default is False.

    Raises:
        SnowparkClientException: If the target table does not exist or is not authorized.

    Example::

        >>> # create a temporary table, then append a second dataframe's rows to it
        >>> df = session.create_dataframe([["John", "Berry"]], schema = ["FIRST_NAME", "LAST_NAME"])
        >>> df.write.save_as_table("my_table", table_type="temporary")
        >>> df2 = session.create_dataframe([["Rick", "Berry"]], schema = ["FIRST_NAME", "LAST_NAME"])
        >>> df2.write.insert_into("my_table")
        >>> session.table("my_table").collect()
        [Row(FIRST_NAME='John', LAST_NAME='Berry'), Row(FIRST_NAME='Rick', LAST_NAME='Berry')]
    """
    if isinstance(table_name, str):
        full_table_name = table_name
        qualified_table_name = parse_table_name(table_name)
    else:
        # Materialize the iterable exactly once: a one-shot iterator (e.g. a
        # generator) would otherwise be exhausted by ".".join() and leave an
        # empty identifier for the subsequent existence check and write.
        qualified_table_name = list(table_name)
        full_table_name = ".".join(qualified_table_name)
    validate_object_name(full_table_name)
    if not self._dataframe._session._table_exists(qualified_table_name):
        raise SnowparkClientException(
            f"Table {full_table_name} does not exist or not authorized."
        )

    # "truncate" clears existing rows before writing; "append" adds to them.
    self.save_as_table(
        qualified_table_name, mode="truncate" if overwrite else "append"
    )

insertInto = insert_into
936979
saveAsTable = save_as_table

src/snowflake/snowpark/session.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
def _table_exists(self, raw_table_name: Iterable[str]):
    """Return True if the table identified by *raw_table_name* exists and is visible.

    Args:
        raw_table_name: Parsed identifier parts — 1 part ("table"),
            2 parts ("schema", "table"), or 3 parts ("database", "schema", "table").
            An empty schema part in the 3-part form means double-dot notation.

    Raises:
        SnowparkClientException: If more than 3 identifier parts are given.
    """
    # implementation based upon: https://docs.snowflake.com/en/sql-reference/name-resolution.html
    qualified_table_name = list(raw_table_name)
    if isinstance(self._conn, MockServerConnection):
        # Local testing mode: consult the in-memory entity registry instead of
        # issuing SHOW TABLES against a real Snowflake backend.
        return self._conn.entity_registry.is_existing_table(qualified_table_name)
    else:
        if len(qualified_table_name) == 1:
            # name in the form of "table"
            tables = self._run_query(
                f"show tables like '{strip_double_quotes_in_like_statement_in_table_name(qualified_table_name[0])}'"
            )
        elif len(qualified_table_name) == 2:
            # name in the form of "schema.table" omitting database
            # schema: qualified_table_name[0]
            # table: qualified_table_name[1]
            tables = self._run_query(
                f"show tables like '{strip_double_quotes_in_like_statement_in_table_name(qualified_table_name[1])}' in schema {qualified_table_name[0]}"
            )
        elif len(qualified_table_name) == 3:
            # name in the form of "database.schema.table"
            # database: qualified_table_name[0]
            # schema: qualified_table_name[1]
            # table: qualified_table_name[2]
            # special case: (''<database_name>..<object_name>''), by following
            # https://docs.snowflake.com/en/sql-reference/name-resolution#resolution-when-schema-omitted-double-dot-notation
            # The two dots indicate that the schema name is not specified.
            # The PUBLIC default schema is always referenced.
            condition = (
                f"schema {qualified_table_name[0]}.PUBLIC"
                if qualified_table_name[1] == ""
                else f"schema {qualified_table_name[0]}.{qualified_table_name[1]}"
            )
            tables = self._run_query(
                f"show tables like '{strip_double_quotes_in_like_statement_in_table_name(qualified_table_name[2])}' in {condition}"
            )
        else:
            # we do not support len(qualified_table_name) > 3 for now
            # NOTE(review): this re-iterates raw_table_name after it was already
            # consumed by list() above — fine for lists/tuples, but a one-shot
            # iterator would join to an empty string here; confirm callers only
            # pass re-iterable sequences.
            raise SnowparkClientExceptionMessages.GENERAL_INVALID_OBJECT_NAME(
                ".".join(raw_table_name)
            )
        # SHOW TABLES LIKE returns one row per match; any row means the table exists.
        return tables is not None and len(tables) > 0
41234125

41244126
def _explain_query(self, query: str) -> Optional[str]:
41254127
try:

tests/integ/scala/test_dataframe_writer_suite.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import pytest
99

1010
import snowflake.connector.errors
11+
12+
from snowflake.snowpark.exceptions import SnowparkClientException
1113
from snowflake.snowpark import Row
1214
from snowflake.snowpark._internal.utils import TempObjectType, parse_table_name
1315
from snowflake.snowpark.exceptions import SnowparkSQLException
@@ -874,3 +876,93 @@ def test_writer_parquet(session, tmpdir_factory, local_testing_mode):
874876
Utils.assert_rows_count(data4, ROWS_COUNT)
875877
finally:
876878
Utils.drop_stage(session, temp_stage)
879+
880+
881+
def test_insert_into(session, local_testing_mode):
    """
    Test the insert_into API with positive and negative test cases.

    Positive: append to an existing table; overwrite an existing table.
    Negative: schema mismatches (extra column, missing column, wrong type)
    and a non-existent target table.
    """
    table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)

    try:
        # Create a DataFrame with initial data
        df = session.create_dataframe(
            [["Alice", "Smith"], ["Bob", "Brown"]],
            schema=["FIRST_NAME", "LAST_NAME"],
        )
        df.write.save_as_table(table_name)

        # Positive Test: Append data to the table
        df_append = session.create_dataframe(
            [["Charlie", "White"]],
            schema=["FIRST_NAME", "LAST_NAME"],
        )
        df_append.write.insert_into(table_name)
        Utils.check_answer(
            session.table(table_name),
            [
                Row(FIRST_NAME="Alice", LAST_NAME="Smith"),
                Row(FIRST_NAME="Bob", LAST_NAME="Brown"),
                Row(FIRST_NAME="Charlie", LAST_NAME="White"),
            ],
        )

        # Positive Test: Overwrite data in the table
        df_overwrite = session.create_dataframe(
            [["David", "Green"]],
            schema=["FIRST_NAME", "LAST_NAME"],
        )
        df_overwrite.write.insert_into(table_name, overwrite=True)
        Utils.check_answer(
            session.table(table_name), [Row(FIRST_NAME="David", LAST_NAME="Green")]
        )

        # Negative Test: Schema mismatch, more columns
        # Live connection and local testing mode raise different error messages.
        df_more_columns = session.create_dataframe(
            [["Extra", "Column", 123]],
            schema=["FIRST_NAME", "LAST_NAME", "AGE"],
        )
        with pytest.raises(
            SnowparkSQLException,
            match="Insert value list does not match column list expecting 2 but got 3"
            if not local_testing_mode
            else "Cannot append because incoming data has different schema",
        ):
            df_more_columns.write.insert_into(table_name)

        # Negative Test: Schema mismatch, less columns
        df_less_column = session.create_dataframe(
            [["Column"]],
            schema=["FIRST_NAME"],
        )
        with pytest.raises(
            SnowparkSQLException,
            match="Insert value list does not match column list expecting 2 but got 1"
            if not local_testing_mode
            else "Cannot append because incoming data has different schema",
        ):
            df_less_column.write.insert_into(table_name)

        # Negative Test: Schema mismatch, type
        df_not_same_type = session.create_dataframe(
            [[[1, 2, 3, 4], False]],
            schema=["FIRST_NAME", "LAST_NAME"],
        )

        if not local_testing_mode:
            # SNOW-1890315: Local Testing missing type coercion check
            with pytest.raises(
                SnowparkSQLException,
                match="Expression type does not match column data type",
            ):
                df_not_same_type.write.insert_into(table_name)

        # Negative Test: Table does not exist
        with pytest.raises(
            SnowparkClientException,
            match="Table non_existent_table does not exist or not authorized.",
        ):
            df.write.insert_into("non_existent_table")

    finally:
        Utils.drop_table(session, table_name)

0 commit comments

Comments
 (0)