Skip to content
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
- `try_to_binary`

- Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`.
- Added support for specifying a user-defined schema when reading a JSON file on a stage.
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.

#### Improvements
Expand Down
9 changes: 5 additions & 4 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,7 @@ def read_file(
transformations: Optional[List[str]] = None,
metadata_project: Optional[List[str]] = None,
metadata_schema: Optional[List[Attribute]] = None,
use_user_schema: bool = False,
):
format_type_options, copy_options = get_copy_into_table_options(options)
format_type_options = self._merge_file_format_options(
Expand All @@ -1215,11 +1216,11 @@ def read_file(
pattern = options.get("PATTERN")
# Can only infer the schema for parquet, orc and avro
# csv and json in preview
infer_schema = (
schema_available = (
options.get("INFER_SCHEMA", True)
if format in INFER_SCHEMA_FORMAT_TYPES
else False
)
) or use_user_schema
# tracking usage of pattern, will refactor this function in future
if pattern:
self.session._conn._telemetry_client.send_copy_pattern_telemetry()
Expand Down Expand Up @@ -1262,7 +1263,7 @@ def read_file(
)
)

if infer_schema:
if schema_available:
assert schema_to_cast is not None
schema_project: List[str] = schema_cast_named(schema_to_cast)
else:
Expand Down Expand Up @@ -1302,7 +1303,7 @@ def read_file(
# If we have inferred the schema, we want to use those column names
temp_table_schema = (
schema
if infer_schema
if schema_available
else [
Attribute(f'"COL{index}"', att.datatype, att.nullable)
for index, att in enumerate(schema)
Expand Down
50 changes: 47 additions & 3 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import set_api_call_source
from snowflake.snowpark._internal.type_utils import ColumnOrName, convert_sf_to_sp_type
from snowflake.snowpark._internal.type_utils import (
ColumnOrName,
convert_sf_to_sp_type,
convert_sp_to_sf_type,
)
from snowflake.snowpark._internal.utils import (
INFER_SCHEMA_FORMAT_TYPES,
SNOWFLAKE_PATH_PREFIXES,
Expand Down Expand Up @@ -907,6 +911,32 @@ def _infer_schema_for_file_format(

return new_schema, schema_to_cast, read_file_transformations, None

def _get_schema_from_user_input(
    self, user_schema: StructType
) -> Tuple[List, List, List]:
    """Convert a user-provided ``StructType`` into the schema artifacts needed
    to read a semi-structured file.

    For each field, builds a quoted ``Attribute``, a ``$1:<name>::<type>`` cast
    expression paired with the original field name, and the corresponding
    ``sql_expr`` transformation. Also records the normalized schema and the
    transformations on this reader instance as a side effect.

    Returns:
        A tuple ``(new_schema, schema_to_cast, read_file_transformations)``.
    """
    new_schema = []
    schema_to_cast = []
    transformations = []
    for field in user_schema.fields:
        quoted_name = quote_name_without_upper_casing(field._name)
        new_schema.append(
            Attribute(quoted_name, field.datatype, field.nullable)
        )
        cast_expr = f"$1:{quoted_name}::{convert_sp_to_sf_type(field.datatype)}"
        schema_to_cast.append((cast_expr, field._name))
        transformations.append(sql_expr(cast_expr))
    self._user_schema = StructType._from_attributes(new_schema)
    # NOTE(review): this unconditionally overwrites any previously set
    # transformations — confirm whether user-set transformations should
    # take precedence instead.
    self._infer_schema_transformations = transformations
    self._infer_schema_target_columns = self._user_schema.names
    read_file_transformations = [t._expression.sql for t in transformations]
    return new_schema, schema_to_cast, read_file_transformations

def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
if isinstance(self._session._conn, MockServerConnection):
if self._session._conn.is_closed():
Expand All @@ -922,7 +952,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
raise_error=NotImplementedError,
)

if self._user_schema:
if self._user_schema and format.lower() != "json":
raise ValueError(f"Read {format} does not support user schema")
path = _validate_stage_path(path)
self._file_path = path
Expand All @@ -931,7 +961,19 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
schema = [Attribute('"$1"', VariantType())]
read_file_transformations = None
schema_to_cast = None
if self._infer_schema:
use_user_schema = False

if self._user_schema:
(
new_schema,
schema_to_cast,
read_file_transformations,
) = self._get_schema_from_user_input(self._user_schema)
schema = new_schema
self._cur_options["INFER_SCHEMA"] = False
use_user_schema = True

elif self._infer_schema:
(
new_schema,
schema_to_cast,
Expand All @@ -957,6 +999,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
transformations=read_file_transformations,
metadata_project=metadata_project,
metadata_schema=metadata_schema,
use_user_schema=use_user_schema,
),
analyzer=self._session._analyzer,
),
Expand All @@ -975,6 +1018,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
transformations=read_file_transformations,
metadata_project=metadata_project,
metadata_schema=metadata_schema,
use_user_schema=use_user_schema,
),
)
df._reader = self
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/mock/_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def read_file(
transformations: Optional[List[str]] = None,
metadata_project: Optional[List[str]] = None,
metadata_schema: Optional[List[Attribute]] = None,
use_user_schema: bool = False,
) -> MockExecutionPlan:
if format.lower() not in SUPPORT_READ_OPTIONS.keys():
LocalTestOOBTelemetryService.get_instance().log_not_supported_error(
Expand Down
69 changes: 57 additions & 12 deletions tests/integ/scala/test_dataframe_reader_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,10 +1086,6 @@ def test_read_json_with_no_schema(session, mode, resources_path):
Row('{\n "color": "Red",\n "fruit": "Apple",\n "size": "Large"\n}')
]

# assert user cannot input a schema to read json
with pytest.raises(ValueError):
get_reader(session, mode).schema(user_schema).json(json_path)

# user can input customized formatTypeOptions
df2 = get_reader(session, mode).option("FILE_EXTENSION", "json").json(json_path)
assert df2.collect() == [
Expand All @@ -1116,10 +1112,6 @@ def test_read_json_with_no_schema(session, mode, resources_path):
Row('{\n "color": "Red",\n "fruit": "Apple",\n "size": "Large"\n}')
]

# assert user cannot input a schema to read json
with pytest.raises(ValueError):
get_reader(session, mode).schema(user_schema).json(json_path)

# assert local directory is invalid
with pytest.raises(
ValueError, match="DataFrameReader can only read files from stage locations."
Expand All @@ -1139,10 +1131,6 @@ def test_read_json_with_infer_schema(session, mode):
res = df1.where(col('"color"') == lit("Red")).collect()
assert res == [Row(color="Red", fruit="Apple", size="Large")]

# assert user cannot input a schema to read json
with pytest.raises(ValueError):
get_reader(session, mode).schema(user_schema).json(json_path)

# user can input customized formatTypeOptions
df2 = (
get_reader(session, mode)
Expand Down Expand Up @@ -1777,3 +1765,60 @@ def test_filepath_with_single_quote(session):
)

assert result1 == result2


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="read json not supported in localtesting",
)
def test_read_json_user_input_schema(session):
    """Reading a staged JSON file with a user-supplied schema: matching,
    missing, extra, and mistyped columns."""
    test_file = f"@{tmp_stage_name1}/{test_file_json}"

    def read_with(fields):
        # Build a nullable StructType from (name, datatype) pairs and read.
        user_schema = StructType(
            [StructField(name, datatype, True) for name, datatype in fields]
        )
        return session.read.schema(user_schema).json(test_file)

    # schema that exactly matches the columns in the file
    df = read_with(
        [("fruit", StringType()), ("size", StringType()), ("color", StringType())]
    )
    Utils.check_answer(df, [Row(fruit="Apple", size="Large", color="Red")])

    # schema mixing columns present in the file with one that is absent;
    # the absent column comes back as NULL
    df = read_with(
        [
            ("fruit", StringType()),
            ("size", StringType()),
            ("not_included_column", StringType()),
        ]
    )
    Utils.check_answer(df, [Row(fruit="Apple", size="Large", not_included_column=None)])

    # schema with an extra trailing column; it comes back as NULL
    df = read_with(
        [
            ("fruit", StringType()),
            ("size", StringType()),
            ("color", StringType()),
            ("extra_column", StringType()),
        ]
    )
    Utils.check_answer(
        df, [Row(fruit="Apple", size="Large", color="Red", extra_column=None)]
    )

    # schema with an incompatible datatype; the cast fails at collect time
    with pytest.raises(SnowparkSQLException, match="Failed to cast variant value"):
        read_with(
            [("fruit", StringType()), ("size", StringType()), ("color", IntegerType())]
        ).collect()
Loading