Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#### New Features

- Added support for a user-specified schema when reading Parquet files from a stage.

#### Bug Fixes

#### Improvements
Expand Down
6 changes: 3 additions & 3 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,10 +642,10 @@ def table(

@publicapi
def schema(self, schema: StructType, _emit_ast: bool = True) -> "DataFrameReader":
"""Define the schema for CSV or XML files that you want to read.
"""Define the schema for CSV, JSON, Parquet, or XML files that you want to read.
Args:
schema: Schema configuration for the CSV or XML file to be read.
schema: Schema configuration for the file to be read.
Returns:
a :class:`DataFrameReader` instance with the specified schema configuration for the data to be read.
Expand Down Expand Up @@ -1596,7 +1596,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
raise_error=NotImplementedError,
)

if self._user_schema and format.lower() not in ["json", "xml"]:
if self._user_schema and format.lower() not in ["json", "xml", "parquet"]:
raise ValueError(f"Read {format} does not support user schema")
if (
self._user_schema
Expand Down
66 changes: 62 additions & 4 deletions tests/integ/scala/test_dataframe_reader_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,10 +1407,6 @@ def test_read_parquet_with_no_schema(session, mode):
res = df1.where(col('"num"') > 1).collect()
assert res == [Row(str="str2", num=2)]

# assert user cannot input a schema to read json
with pytest.raises(ValueError):
get_reader(session, mode).schema(user_schema).parquet(path)

# user can input customized formatTypeOptions
df2 = get_reader(session, mode).option("COMPRESSION", "NONE").parquet(path)
res = df2.collect()
Expand All @@ -1420,6 +1416,68 @@ def test_read_parquet_with_no_schema(session, mode):
]


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="FEAT: parquet not supported",
)
@pytest.mark.parametrize("mode", ["select", "copy"])
def test_read_parquet_user_input_schema(session, mode):
    """Read a staged Parquet file with several user-supplied schemas.

    Four scenarios are checked, in order:

    1. A schema that matches the file returns the file's rows.
    2. A schema naming a column the file does not contain yields ``None``
       for that column.
    3. A schema that lists the file's columns plus one more yields ``None``
       for the additional column.
    4. A schema declaring the string column as an integer is expected to
       raise ``SnowparkSQLException`` ("Failed to cast variant value") when
       the DataFrame is built and collected.
    """
    staged_parquet = f"@{tmp_stage_name1}/{test_file_parquet}"

    def read_with(user_schema):
        # A new reader is obtained for every scenario, as each one applies
        # its own schema before reading the same staged file.
        return get_reader(session, mode).schema(user_schema).parquet(staged_parquet)

    # Scenario 1: schema lines up with the file contents.
    matching_schema = StructType(
        [
            StructField("str", StringType(), True),
            StructField("num", LongType(), True),
        ]
    )
    matching_df = read_with(matching_schema)
    expected_matching = [Row(str="str1", num=1), Row(str="str2", num=2)]
    Utils.check_answer(matching_df, expected_matching)

    # Scenario 2: one requested column is absent from the file -> None values.
    absent_column_schema = StructType(
        [
            StructField("str", StringType(), True),
            StructField("not_included_column", StringType(), True),
        ]
    )
    absent_column_df = read_with(absent_column_schema)
    expected_absent = [
        Row(str="str1", not_included_column=None),
        Row(str="str2", not_included_column=None),
    ]
    Utils.check_answer(absent_column_df, expected_absent)

    # Scenario 3: all file columns plus one additional column -> None values.
    extra_column_schema = StructType(
        [
            StructField("str", StringType(), True),
            StructField("num", LongType(), True),
            StructField("extra_column", StringType(), True),
        ]
    )
    extra_column_df = read_with(extra_column_schema)
    expected_extra = [
        Row(str="str1", num=1, extra_column=None),
        Row(str="str2", num=2, extra_column=None),
    ]
    Utils.check_answer(extra_column_df, expected_extra)

    # Scenario 4: the string column is declared as an integer; the read is
    # expected to fail with a cast error.
    mismatched_type_schema = StructType(
        [
            StructField("str", IntegerType(), True),
            StructField("num", LongType(), True),
        ]
    )
    with pytest.raises(SnowparkSQLException, match="Failed to cast variant value"):
        read_with(mismatched_type_schema).collect()


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="FEAT: parquet not supported",
Expand Down
Loading