Skip to content
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
- `nullifzero`
- `snowflake_cortex_sentiment`
- Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`.
- Added support for providing a user-defined schema when reading JSON files on a stage.

#### Improvements

Expand Down
46 changes: 44 additions & 2 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import set_api_call_source
from snowflake.snowpark._internal.type_utils import ColumnOrName, convert_sf_to_sp_type
from snowflake.snowpark._internal.type_utils import (
ColumnOrName,
convert_sf_to_sp_type,
convert_sp_to_sf_type,
)
from snowflake.snowpark._internal.utils import (
INFER_SCHEMA_FORMAT_TYPES,
SNOWFLAKE_PATH_PREFIXES,
Expand Down Expand Up @@ -907,6 +911,35 @@ def _infer_schema_for_file_format(

return new_schema, schema_to_cast, read_file_transformations, None

def _infer_schema_from_user_input(
    self, user_schema: StructType, format: str
) -> Tuple[List, List, List]:
    """Derive the read schema and projection expressions from a user-provided schema.

    Currently only the JSON format accepts a user schema; any other format
    raises ``ValueError``. Each user field becomes a ``$1:<name>::<sf_type>``
    cast expression against the variant column, and the reader's inference
    bookkeeping (``_infer_schema_transformations`` /
    ``_infer_schema_target_columns``) is populated from the user schema.

    Returns a tuple of (attributes, (expression, name) pairs, SQL strings of
    the cast transformations).
    """
    if format.lower() != "json":
        raise ValueError(
            f"Currently only support user schema for JSON format, got {format} instead"
        )
    attributes = []
    cast_pairs = []
    cast_columns = []
    for field in user_schema.fields:
        quoted_name = quote_name_without_upper_casing(field._name)
        attributes.append(Attribute(quoted_name, field.datatype, field.nullable))
        # Project the field out of the variant column and cast it to the
        # Snowflake type corresponding to the user-declared Snowpark type.
        cast_sql = f"$1:{quoted_name}::{convert_sp_to_sf_type(field.datatype)}"
        cast_pairs.append((cast_sql, field._name))
        cast_columns.append(sql_expr(cast_sql))
    self._user_schema = StructType._from_attributes(attributes)
    # NOTE(review): these assignments replace any previously stored
    # inference transformations — verify callers rely on that.
    self._infer_schema_transformations = cast_columns
    self._infer_schema_target_columns = self._user_schema.names
    read_file_transformations = [col._expression.sql for col in cast_columns]
    return attributes, cast_pairs, read_file_transformations

def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
if isinstance(self._session._conn, MockServerConnection):
if self._session._conn.is_closed():
Expand All @@ -922,7 +955,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
raise_error=NotImplementedError,
)

if self._user_schema:
if self._user_schema and format.lower() != "json":
raise ValueError(f"Read {format} does not support user schema")
path = _validate_stage_path(path)
self._file_path = path
Expand All @@ -941,6 +974,15 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
if new_schema:
schema = new_schema

if self._user_schema and not self._infer_schema:
(
new_schema,
schema_to_cast,
read_file_transformations,
) = self._infer_schema_from_user_input(self._user_schema, format)
schema = new_schema
self._cur_options["INFER_SCHEMA"] = True

metadata_project, metadata_schema = self._get_metadata_project_and_schema()

if self._session.sql_simplifier_enabled:
Expand Down
43 changes: 31 additions & 12 deletions tests/integ/scala/test_dataframe_reader_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,10 +1086,6 @@ def test_read_json_with_no_schema(session, mode, resources_path):
Row('{\n "color": "Red",\n "fruit": "Apple",\n "size": "Large"\n}')
]

# assert user cannot input a schema to read json
with pytest.raises(ValueError):
get_reader(session, mode).schema(user_schema).json(json_path)

# user can input customized formatTypeOptions
df2 = get_reader(session, mode).option("FILE_EXTENSION", "json").json(json_path)
assert df2.collect() == [
Expand All @@ -1116,10 +1112,6 @@ def test_read_json_with_no_schema(session, mode, resources_path):
Row('{\n "color": "Red",\n "fruit": "Apple",\n "size": "Large"\n}')
]

# assert user cannot input a schema to read json
with pytest.raises(ValueError):
get_reader(session, mode).schema(user_schema).json(json_path)

# assert local directory is invalid
with pytest.raises(
ValueError, match="DataFrameReader can only read files from stage locations."
Expand All @@ -1139,10 +1131,6 @@ def test_read_json_with_infer_schema(session, mode):
res = df1.where(col('"color"') == lit("Red")).collect()
assert res == [Row(color="Red", fruit="Apple", size="Large")]

# assert user cannot input a schema to read json
with pytest.raises(ValueError):
get_reader(session, mode).schema(user_schema).json(json_path)

# user can input customized formatTypeOptions
df2 = (
get_reader(session, mode)
Expand Down Expand Up @@ -1777,3 +1765,34 @@ def test_filepath_with_single_quote(session):
)

assert result1 == result2


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="read json not supported in localtesting",
)
def test_read_json_user_input_schema(session):
    """Reading JSON with a user-provided schema projects the named fields.

    Also verifies that a schema column absent from the file is returned as
    NULL instead of raising an error.
    """
    test_file = f"@{tmp_stage_name1}/{test_file_json}"

    schema = StructType(
        [
            StructField("fruit", StringType(), True),
            StructField("size", StringType(), True),
            StructField("color", StringType(), True),
        ]
    )

    df = session.read.schema(schema).json(test_file)
    Utils.check_answer(df, [Row(fruit="Apple", size="Large", color="Red")])

    # A column declared in the schema but missing from the file should
    # surface as NULL, not fail the read.
    schema = StructType(
        [
            StructField("fruit", StringType(), True),
            StructField("size", StringType(), True),
            StructField("not_included_column", StringType(), True),
        ]
    )

    df = session.read.schema(schema).json(test_file)
    Utils.check_answer(df, [Row(fruit="Apple", size="Large", not_included_column=None)])
Loading