Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import tempfile
import warnings
from array import array
from functools import reduce
from functools import partial, reduce
from logging import getLogger
from threading import RLock
from types import ModuleType
Expand All @@ -39,9 +39,11 @@
from snowflake.connector import ProgrammingError, SnowflakeConnection
from snowflake.connector.options import installed_pandas, pandas
from snowflake.connector.pandas_tools import write_pandas
from snowflake.snowpark._internal.analyzer import analyzer_utils
from snowflake.snowpark._internal.analyzer.analyzer import Analyzer
from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
attribute_to_schema_string,
result_scan_statement,
)
from snowflake.snowpark._internal.analyzer.datatype_mapper import str_to_sql
from snowflake.snowpark._internal.analyzer.expression import Attribute
from snowflake.snowpark._internal.analyzer.select_statement import (
Expand Down Expand Up @@ -2975,7 +2977,9 @@ def write_pandas(
raise pe

if success:
table = self.table(location, _emit_ast=False)
table = self.table(
location, is_temp_table_for_cleanup=True, _emit_ast=False
)
set_api_call_source(table, "Session.write_pandas")

# AST.
Expand Down Expand Up @@ -3028,6 +3032,31 @@ def write_pandas(
str(ci_output)
)

def _initialize_temp_table_with_schema(
    self, temp_table_name: str, schema: StructType
) -> bool:
    """Create a scoped temp table whose columns match the user-provided schema.

    Args:
        temp_table_name: name of the temp table to create (internally
            generated, e.g. via ``random_name_for_temp_object``).
        schema: user-provided ``StructType`` describing the desired columns.

    Returns:
        True if the table was created successfully, else False (the caller is
        expected to fall back to inferring the schema from the data).
    """
    try:
        schema_string = attribute_to_schema_string(schema._to_attributes())
        # SCOPED is a no-op outside a stored procedure; inside a stored
        # procedure it limits the lifetime/visibility of this internally
        # created temp object to the procedure's scope.
        self._run_query(
            f"CREATE SCOPED TEMP TABLE {temp_table_name} ({schema_string})"
        )
    except ProgrammingError as e:
        # Creation can legitimately fail (e.g. a type in the schema is not
        # expressible as a DDL column); report at debug level and signal the
        # caller to fall back rather than raising.
        _logger.debug(
            f"Cannot create temp table for specified schema, fall back to inferring "
            f"schema string from select query. Exception: {str(e)}"
        )
        return False
    return True

@publicapi
def create_dataframe(
self,
Expand Down Expand Up @@ -3130,16 +3159,36 @@ def create_dataframe(
)
sf_schema = self._conn._get_current_parameter("schema", quoted=False)

table = self.write_pandas(
# If the user specifies schema for their dataframe, we try out best to match
# it by create a temp table with the specified schema, and load the data into
# the temp table. If we fail, go back to old method using infer schema.
write_pandas_partial = partial(
self.write_pandas,
data,
temp_table_name,
database=sf_database,
schema=sf_schema,
quote_identifiers=True,
auto_create_table=True,
table_type="temporary",
use_logical_type=self._use_logical_type_for_create_df,
)
if isinstance(
schema, StructType
) and self._initialize_temp_table_with_schema(temp_table_name, schema):
try:
table = write_pandas_partial()
except ProgrammingError as e:
self._run_query(f"drop table if exists {temp_table_name}")
_logger.warning(
f"Cannot create dataframe using specified schema for database."
f"Falling back to inferring schema from pandas dataframe. Exception: {e}"
)
table = write_pandas_partial(
auto_create_table=True, table_type="temporary"
)
else:
table = write_pandas_partial(
auto_create_table=True, table_type="temporary"
)
set_api_call_source(table, "Session.create_dataframe[pandas]")

if _emit_ast:
Expand Down Expand Up @@ -3172,19 +3221,8 @@ def create_dataframe(
and all([field.datatype.is_primitive() for field in schema.fields])
):
temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
schema_string = analyzer_utils.attribute_to_schema_string(
schema._to_attributes()
)
try:
self._run_query(
f"CREATE SCOPED TEMP TABLE {temp_table_name} ({schema_string})"
)
if self._initialize_temp_table_with_schema(temp_table_name, schema):
schema_query = f"SELECT * FROM {self.get_fully_qualified_name_if_possible(temp_table_name)}"
except ProgrammingError as e:
_logger.debug(
f"Cannot create temp table for specified non-nullable schema, fall back to using schema "
f"string from select query. Exception: {str(e)}"
)
else:
if not data:
raise ValueError("Cannot infer schema from empty data")
Expand Down
43 changes: 43 additions & 0 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,49 @@ def test_create_dataframe_with_pandas_df(session):
assert df.schema[2].datatype == TimestampType(TimestampTimeZone.LTZ)


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="TODO: support local testing.",
)
@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_create_dataframe_with_pandas_df_enforce_schema(session):
    # Three string columns: col1 -> ["a1", "b1"], col2 -> ["a2", "b2"], ...
    pandas_df = pd.DataFrame({f"col{i}": [f"a{i}", f"b{i}"] for i in (1, 2, 3)})

    # A schema the data can satisfy: create_dataframe should honor it verbatim.
    user_schema = StructType(
        [
            StructField('"col1"', StringType()),
            StructField('"col2"', StringType(20)),
            StructField('"col3"', VariantType(), nullable=False),
        ]
    )
    result_df = session.create_dataframe(pandas_df, schema=user_schema)
    assert result_df.schema == user_schema

    # present a schema that does not match given data and
    # ensure we fall back to old behavior w/o failures or side-effects
    bad_schema = StructType(
        [
            StructField('"col1"', IntegerType()),
            StructField('"col2"', StringType(20)),
            StructField('"col3"', VariantType(), nullable=False),
        ]
    )
    expected_schema = StructType(
        [
            StructField('"col1"', StringType()),
            StructField('"col2"', StringType()),
            StructField('"col3"', StringType()),
        ]
    )
    result_df = session.create_dataframe(pandas_df, schema=bad_schema)
    assert result_df.schema == expected_schema

def test_create_dataframe_with_dict(session):
data = {f"snow_{idx + 1}": idx**3 for idx in range(5)}
expected_names = [name.upper() for name in data.keys()]
Expand Down