Skip to content

Commit 650b56e

Browse files
committed
Support user-provided schema for Parquet reads
Allow users to specify a custom schema when reading Parquet files via
`session.read.schema(schema).parquet(path)`. Previously only JSON and XML
supported user-provided schemas; Parquet was blocked by a ValueError gate.

- Add "parquet" to the user-schema format allowlist in _read_semi_structured_file
- Update schema() docstring to reflect all supported formats
- Remove stale ValueError assertion in test_read_parquet_with_no_schema
- Add test_read_parquet_user_input_schema covering matching schema, missing
  columns, extra columns, and wrong-type error cases

Made-with: Cursor
1 parent c29aa25 commit 650b56e

File tree

2 files changed

+64
-7
lines changed

2 files changed

+64
-7
lines changed

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,10 @@ def table(
642642

643643
@publicapi
644644
def schema(self, schema: StructType, _emit_ast: bool = True) -> "DataFrameReader":
645-
"""Define the schema for CSV or XML files that you want to read.
645+
"""Define the schema for CSV, JSON, Parquet, or XML files that you want to read.
646646
647647
Args:
648-
schema: Schema configuration for the CSV or XML file to be read.
648+
schema: Schema configuration for the file to be read.
649649
650650
Returns:
651651
a :class:`DataFrameReader` instance with the specified schema configuration for the data to be read.
@@ -1596,7 +1596,7 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
15961596
raise_error=NotImplementedError,
15971597
)
15981598

1599-
if self._user_schema and format.lower() not in ["json", "xml"]:
1599+
if self._user_schema and format.lower() not in ["json", "xml", "parquet"]:
16001600
raise ValueError(f"Read {format} does not support user schema")
16011601
if (
16021602
self._user_schema

tests/integ/scala/test_dataframe_reader_suite.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,10 +1407,6 @@ def test_read_parquet_with_no_schema(session, mode):
14071407
res = df1.where(col('"num"') > 1).collect()
14081408
assert res == [Row(str="str2", num=2)]
14091409

1410-
# assert user cannot input a schema to read json
1411-
with pytest.raises(ValueError):
1412-
get_reader(session, mode).schema(user_schema).parquet(path)
1413-
14141410
# user can input customized formatTypeOptions
14151411
df2 = get_reader(session, mode).option("COMPRESSION", "NONE").parquet(path)
14161412
res = df2.collect()
@@ -1420,6 +1416,67 @@ def test_read_parquet_with_no_schema(session, mode):
14201416
]
14211417

14221418

1419+
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="FEAT: parquet not supported",
)
def test_read_parquet_user_input_schema(session):
    """Reading Parquet with a user-provided schema.

    Covers four cases: an exact schema match, a schema naming a column the
    file lacks, a schema with an extra column beyond the file's columns, and
    a schema whose column type conflicts with the stored data (expected to
    fail at collect time).
    """
    test_file = f"@{tmp_stage_name1}/{test_file_parquet}"

    def read_with(fields):
        # Build a StructType from (name, datatype) pairs (all nullable) and
        # read the staged Parquet file with it as the user-provided schema.
        struct = StructType(
            [StructField(name, datatype, True) for name, datatype in fields]
        )
        return session.read.schema(struct).parquet(test_file)

    # Exact schema match: rows come back unchanged.
    Utils.check_answer(
        read_with([("str", StringType()), ("num", LongType())]),
        [Row(str="str1", num=1), Row(str="str2", num=2)],
    )

    # A requested column absent from the file is returned as None.
    Utils.check_answer(
        read_with([("str", StringType()), ("not_included_column", StringType())]),
        [
            Row(str="str1", not_included_column=None),
            Row(str="str2", not_included_column=None),
        ],
    )

    # An extra column beyond what the file holds is filled with None.
    Utils.check_answer(
        read_with(
            [
                ("str", StringType()),
                ("num", LongType()),
                ("extra_column", StringType()),
            ]
        ),
        [
            Row(str="str1", num=1, extra_column=None),
            Row(str="str2", num=2, extra_column=None),
        ],
    )

    # A column typed incompatibly with the stored data raises on collect.
    with pytest.raises(SnowparkSQLException, match="Failed to cast variant value"):
        read_with([("str", IntegerType()), ("num", LongType())]).collect()
1478+
1479+
14231480
@pytest.mark.skipif(
14241481
"config.getoption('local_testing_mode', default=False)",
14251482
reason="FEAT: parquet not supported",

0 commit comments

Comments (0)