Skip to content

Commit 24d0ae8

Browse files
committed
fix test
1 parent 1c84a9f commit 24d0ae8

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,12 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
13841384

13851385
if self._user_schema and format.lower() not in ["json", "xml"]:
13861386
raise ValueError(f"Read {format} does not support user schema")
1387+
if (
1388+
self._user_schema
1389+
and format.lower() == "xml"
1390+
and XML_ROW_TAG_STRING not in self._cur_options
1391+
):
1392+
raise ValueError("When read XML with user schema, rowtag must be set.")
13871393
path = _validate_stage_path(path)
13881394
self._file_path = path
13891395
self._file_type = format

tests/integ/scala/test_dataframe_reader_suite.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1626,10 +1626,6 @@ def test_read_xml_with_no_schema(session, mode, resources_path):
16261626
res = df1.where(sql_expr("xmlget($1, 'num', 0):\"$\"") > 1).collect()
16271627
assert res == [Row("<test>\n <num>2</num>\n <str>str2</str>\n</test>")]
16281628

1629-
# assert user cannot input a schema to read json
1630-
with pytest.raises(ValueError):
1631-
get_reader(session, mode).schema(user_schema).xml(path)
1632-
16331629
# assert local directory is invalid
16341630
with pytest.raises(
16351631
ValueError, match="DataFrameReader can only read files from stage locations."

tests/integ/test_xml_reader_row_tag.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,3 +609,19 @@ def test_read_xml_with_custom_schema(session):
609609
]
610610
Utils.check_answer(df, expected_result)
611611
assert df.schema == expected_schema
612+
613+
614+
def test_user_schema_without_rowtag(session):
615+
user_schema = StructType(
616+
[
617+
StructField("Author", StringType(), True),
618+
StructField("Title", StringType(), True),
619+
StructField("genre", StringType(), True),
620+
StructField("PRICE", DoubleType(), True),
621+
StructField("publish_Date", DateType(), True),
622+
]
623+
)
624+
with pytest.raises(
625+
ValueError, match="When read XML with user schema, rowtag must be set."
626+
):
627+
session.read.schema(user_schema).xml(f"@{tmp_stage_name}/{test_file_books_xml}")

0 commit comments

Comments
 (0)