Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

- Invoking snowflake system procedures does not invoke an additional `describe procedure` call to check the return type of the procedure.
- Added support for `Session.create_dataframe()` with the stage URL and FILE data type.
- Added support for different modes for dealing with corrupt XML records when reading an XML file using `session.read.option('rowTag', <tag_name>).xml(<stage_file_path>)`. Currently `PERMISSIVE`, `DROPMALFORMED` and `FAILFAST` are supported.

#### Bug Fixes

Expand Down
7 changes: 7 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,12 @@ def _create_xml_query(
worker_column_name = "WORKER"
xml_row_number_column_name = "XML_ROW_NUMBER"
row_tag = options[XML_ROW_TAG_STRING]
mode = options.get("MODE", "PERMISSIVE").upper()

if mode not in {"PERMISSIVE", "DROPMALFORMED", "FAILFAST"}:
raise ValueError(
f"Invalid mode: {mode}. Must be one of PERMISSIVE, DROPMALFORMED, FAILFAST."
)

# TODO SNOW-1983360: make it an configurable option once the UDTF scalability issue is resolved.
# Currently it's capped at 16.
Expand All @@ -1395,6 +1401,7 @@ def _create_xml_query(
lit(num_workers),
lit(row_tag),
col(worker_column_name),
lit(mode),
),
)

Expand Down
69 changes: 47 additions & 22 deletions src/snowflake/snowpark/_internal/xml_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import os
import re
import html.entities
import logging
import struct
import xml.etree.ElementTree as ET
from typing import Optional, Dict, Any, Iterator, BinaryIO, Union, Tuple
from snowflake.snowpark.files import SnowflakeFile


DEFAULT_CHUNK_SIZE: int = 1024
VARIANT_COLUMN_SIZE_LIMIT: int = 16 * 1024 * 1024
COLUMN_NAME_OF_CORRUPT_RECORD = "columnNameOfCorruptRecord"


def replace_entity(match: re.Match) -> str:
Expand Down Expand Up @@ -76,7 +77,7 @@ def tag_is_self_closing(
chunk_start_pos = file_obj.tell()
chunk = file_obj.read(chunk_size)
if not chunk:
raise EOFError("EOF reached before end of opening tag")
raise EOFError("Reached end of file but the tag is not closed")

for idx, b in enumerate(struct.unpack(f"{len(chunk)}c", chunk)):
# '>' inside quote should not be considered as the end of the tag
Expand Down Expand Up @@ -216,7 +217,7 @@ def find_next_opening_tag_pos(
# Calculate the absolute position. Note that `data` starts at (current_pos - len(overlap)).
absolute_pos = current_pos + pos - len(overlap)
if absolute_pos >= end_limit:
raise EOFError("Found tag beyond end limit")
raise EOFError("Exceeded end limit before finding opening tag")
file_obj.seek(absolute_pos)
return absolute_pos

Expand Down Expand Up @@ -298,6 +299,7 @@ def process_xml_range(
tag_name: str,
approx_start: int,
approx_end: int,
mode: str,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> Iterator[Optional[Dict[str, Any]]]:
"""
Expand All @@ -316,6 +318,8 @@ def process_xml_range(
tag_name (str): The tag that delimits records (e.g., "row").
approx_start (int): Approximate start byte position.
approx_end (int): Approximate end byte position.
mode (str): The mode for dealing with corrupt records.
"PERMISSIVE", "DROPMALFORMED" and "FAILFAST" are supported.
chunk_size (int): Size of chunks to read.

Yields:
Expand Down Expand Up @@ -351,8 +355,19 @@ def process_xml_range(
# decide whether the row element is self‑closing
try:
is_self_close, tag_end = tag_is_self_closing(f)
except EOFError:
# malformed XML record
# encountering an EOFError means the XML record isn't self-closing or
# doesn't have a closing tag after reaching the end of the file
except EOFError as e:
if mode == "PERMISSIVE":
# read until the end of file or until the variant column size limit is reached
record_bytes = f.read(VARIANT_COLUMN_SIZE_LIMIT)
record_str = record_bytes.decode("utf-8", errors="replace")
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
yield {COLUMN_NAME_OF_CORRUPT_RECORD: record_str}
elif mode == "FAILFAST":
raise EOFError(
f"Malformed XML record at bytes {record_start}-EOF: {e}"
) from e
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears the DROPMALFORMED mode handling is missing in this error handling block. When mode == "DROPMALFORMED", the code should simply break or continue without yielding anything to properly skip the malformed record. This would align with the behavior described in the documentation where "DROPMALFORMED: Ignores the whole record that cannot be parsed correctly."

Suggested change
except EOFError as e:
if mode == "PERMISSIVE":
# read util the end of file or util variant column size limit
record_bytes = f.read(VARIANT_COLUMN_SIZE_LIMIT)
record_str = record_bytes.decode("utf-8", errors="replace")
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
yield {COLUMN_NAME_OF_CORRUPT_RECORD: record_str}
elif mode == "FAILFAST":
raise EOFError(
f"Malformed XML record at bytes {record_start}-EOF: {e}"
) from e
except EOFError as e:
if mode == "PERMISSIVE":
# read util the end of file or util variant column size limit
record_bytes = f.read(VARIANT_COLUMN_SIZE_LIMIT)
record_str = record_bytes.decode("utf-8", errors="replace")
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
yield {COLUMN_NAME_OF_CORRUPT_RECORD: record_str}
elif mode == "DROPMALFORMED":
# Skip the malformed record
continue
elif mode == "FAILFAST":
raise EOFError(
f"Malformed XML record at bytes {record_start}-EOF: {e}"
) from e

Spotted by Diamond

Is this helpful? React 👍 or 👎 to let us know.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed, as we simply ignore the malformed record here in DROPMALFORMED mode

break

if is_self_close:
Expand All @@ -361,31 +376,37 @@ def process_xml_range(
f.seek(tag_end)
try:
record_end = find_next_closing_tag_pos(f, closing_tag, chunk_size)
except EOFError:
# incomplete XML record
# encountering an EOFError means the XML record isn't self-closing or
# doesn't have a closing tag after reaching the end of the file
except EOFError as e:
if mode == "PERMISSIVE":
# read until the end of file or until the variant column size limit is reached
record_bytes = f.read(VARIANT_COLUMN_SIZE_LIMIT)
record_str = record_bytes.decode("utf-8", errors="replace")
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
yield {COLUMN_NAME_OF_CORRUPT_RECORD: record_str}
elif mode == "FAILFAST":
raise EOFError(
f"Malformed XML record at bytes {record_start}-EOF: {e}"
) from e
break

# Read the complete XML record.
f.seek(record_start)
record_bytes = f.read(record_end - record_start)
try:
record_str = record_bytes.decode("utf-8")
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
except UnicodeDecodeError as e:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually don't need to handle UnicodeDecodeError because we can simply replace the char that isn't supported by charset. We will have another PR to support different charset other than utf-8.

logging.warning(
f"Unicode decode error at bytes {record_start}-{record_end}: {e}"
)
f.seek(record_end)
continue
record_str = record_bytes.decode("utf-8", errors="replace")
record_str = re.sub(r"&(\w+);", replace_entity, record_str)

try:
element = ET.fromstring(record_str)
yield element_to_dict(strip_namespaces(element))
except ET.ParseError as e:
logging.warning(
f"XML parse error at bytes {record_start}-{record_end}: {e}"
)
logging.warning(f"Record content: {record_str}")
if mode == "PERMISSIVE":
yield {COLUMN_NAME_OF_CORRUPT_RECORD: record_str}
elif mode == "FAILFAST":
raise RuntimeError(
f"Malformed XML record at bytes {record_start}-{record_end}: {e}"
)

if record_end > approx_end:
break
Expand All @@ -395,7 +416,7 @@ def process_xml_range(


class XMLReader:
def process(self, filename: str, num_workers: int, row_tag: str, i: int):
def process(self, filename: str, num_workers: int, row_tag: str, i: int, mode: str):
"""
Splits the file into byte ranges—one per worker—by starting with an even
file size division and then moving each boundary to the end of a record,
Expand All @@ -406,10 +427,14 @@ def process(self, filename: str, num_workers: int, row_tag: str, i: int):
num_workers (int): Number of workers/chunks.
row_tag (str): The tag name that delimits records (e.g., "row").
i (int): The worker id.
mode (str): The mode for dealing with corrupt records.
"PERMISSIVE", "DROPMALFORMED" and "FAILFAST" are supported.
"""
file_size = get_file_size(filename)
approx_chunk_size = file_size // num_workers
approx_start = approx_chunk_size * i
approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size
for element in process_xml_range(filename, row_tag, approx_start, approx_end):
for element in process_xml_range(
filename, row_tag, approx_start, approx_end, mode
):
yield (element,)
25 changes: 25 additions & 0 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,31 @@ def xml(self, path: str, _emit_ast: bool = True) -> DataFrame:

Returns:
a :class:`DataFrame` that is set up to load data from the specified XML file(s) in a Snowflake stage.

Notes about reading XML files using a row tag:

- We support reading XML by specifying the element tag that represents a single record using the ``rowTag``
option. See Example 13 in :class:`DataFrameReader`.

- Each XML record is flattened into a single row, with each XML element or attribute mapped to a column.
All columns are represented with the variant type to accommodate heterogeneous or nested data. Therefore,
every column value has a size limit due to the variant type.

- The column names are derived from the XML element names. They will always be wrapped in single quotes.

- To parse the nested XML under a row tag, you can use dot notation ``.`` to query the nested fields in
a DataFrame. See Example 13 in :class:`DataFrameReader`.

- When ``rowTag`` is specified, the following options are supported for reading XML files
via :meth:`option()` or :meth:`options()`:

+ ``mode``: Specifies the mode for dealing with corrupt XML records. The default value is ``PERMISSIVE``. The supported values are:

- ``PERMISSIVE``: When it encounters a corrupt record, it sets all fields to null and includes a 'columnNameOfCorruptRecord' column.

- ``DROPMALFORMED``: Ignores the whole record that cannot be parsed correctly.

- ``FAILFAST``: When it encounters a corrupt record, it raises an exception immediately.
"""
df = self._read_semi_structured_file(path, "XML")

Expand Down
74 changes: 74 additions & 0 deletions tests/integ/scala/test_dataframe_reader_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@
test_file_house_large_xml = "fias_house.large.xml"
test_file_xxe_xml = "xxe.xml"
test_file_nested_xml = "nested.xml"
test_file_malformed_no_closing_tag_xml = "malformed_no_closing_tag.xml"
test_file_malformed_not_self_closing_xml = "malformed_not_self_closing.xml"
test_file_malformed_record_xml = "malformed_record.xml"


# In the tests below, we test both scenarios: SELECT & COPY
Expand Down Expand Up @@ -261,6 +264,24 @@ def setup(session, resources_path, local_testing_mode):
Utils.upload_to_stage(
session, "@" + tmp_stage_name1, test_files.test_nested_xml, compress=False
)
Utils.upload_to_stage(
session,
"@" + tmp_stage_name1,
test_files.test_malformed_no_closing_tag_xml,
compress=False,
)
Utils.upload_to_stage(
session,
"@" + tmp_stage_name1,
test_files.test_malformed_not_self_closing_xml,
compress=False,
)
Utils.upload_to_stage(
session,
"@" + tmp_stage_name1,
test_files.test_malformed_record_xml,
compress=False,
)
Utils.upload_to_stage(
session, "@" + tmp_stage_name2, test_files.test_file_csv, compress=False
)
Expand Down Expand Up @@ -2019,3 +2040,56 @@ def test_read_xml_non_existing_file(session):
session.read.option("rowTag", row_tag).xml(
f"@{tmp_stage_name1}/non_existing_file.xml"
)


# Integration test for the XML reader's corrupt-record handling modes
# (PERMISSIVE / DROPMALFORMED / FAILFAST), parametrized over three
# deliberately malformed fixture files uploaded to the test stage in setup.
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="xml not supported in local testing mode",
)
@pytest.mark.skipif(
    IS_IN_STORED_PROC,
    reason="SNOW-2044853: Flaky in stored procedure test",
)
@pytest.mark.parametrize(
    "file",
    (
        test_file_malformed_no_closing_tag_xml,
        test_file_malformed_not_self_closing_xml,
        test_file_malformed_record_xml,
    ),
)
def test_read_malformed_xml(session, file):
    row_tag = "record"

    # permissive mode: each fixture holds one good and one corrupt record; the
    # corrupt one is kept as a row whose raw text lands in the extra
    # 'columnNameOfCorruptRecord' column (hence 4 columns instead of 3).
    df = (
        session.read.option("rowTag", row_tag)
        .option("mode", "permissive")
        .xml(f"@{tmp_stage_name1}/{file}")
    )
    result = df.collect()
    assert len(result) == 2
    assert len(result[0]) == 4  # has another column 'columnNameOfCorruptRecord'
    assert (
        result[0]["'columnNameOfCorruptRecord'"] is not None
        or result[1]["'columnNameOfCorruptRecord'"] is not None
    )

    # dropmalformed mode: the corrupt record is silently skipped, leaving only
    # the single well-formed record with its 3 regular columns.
    # NOTE(review): this section and the failfast one below always read
    # test_file_malformed_no_closing_tag_xml instead of the parametrized
    # ``file`` — presumably deliberate (expected row counts / error positions
    # may differ per fixture), but it means these two modes are re-run on the
    # same file for every parameter value. TODO confirm intent.
    df = (
        session.read.option("rowTag", row_tag)
        .option("mode", "dropmalformed")
        .xml(f"@{tmp_stage_name1}/{test_file_malformed_no_closing_tag_xml}")
    )
    result = df.collect()
    assert len(result) == 1
    assert len(result[0]) == 3

    # failfast mode: the first corrupt record aborts the read; the worker's
    # EOFError surfaces to the client as a SnowparkSQLException.
    df = (
        session.read.option("rowTag", row_tag)
        .option("mode", "failfast")
        .xml(f"@{tmp_stage_name1}/{test_file_malformed_no_closing_tag_xml}")
    )
    with pytest.raises(SnowparkSQLException, match="Malformed XML record at bytes"):
        df.collect()
11 changes: 11 additions & 0 deletions tests/resources/malformed_no_closing_tag.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<test>
<record>
<id>41</id>
<name>Joe Biden</name>
<email>joe@example.com</email>
</record>
<record>
<id>42</id>
<name>Jane Doe</name>
<email>jane@example.com</email>
</test>
4 changes: 4 additions & 0 deletions tests/resources/malformed_not_self_closing.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<test>
<record id="41" name="Joe Biden" email="joe@example.com" />
<record id="42" name="Jane Doe" email="jane@example.com"
</test>
12 changes: 12 additions & 0 deletions tests/resources/malformed_record.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<test>
<record>
<id>41</id>
<name>Joe Biden</name>
<email>joe@example.com</email>
</record>
<record>
<id>42</id>
<name>Jane Doe</name>
<email>jane@example.com</email
</record>
</test>
3 changes: 3 additions & 0 deletions tests/unit/scala/test_utils_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
"resources/fias_house.xml",
"resources/fias_house.large.xml",
"resources/iris.csv",
"resources/malformed_no_closing_tag.xml",
"resources/malformed_not_self_closing.xml",
"resources/malformed_record.xml",
"resources/nested.xml",
"resources/test.avro",
"resources/test.orc",
Expand Down
12 changes: 12 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,18 @@ def test_xxe_xml(self):
def test_nested_xml(self):
return os.path.join(self.resources_path, "nested.xml")

@property
def test_malformed_no_closing_tag_xml(self):
    """Absolute path of the malformed fixture whose last record is missing its closing tag."""
    fixture_name = "malformed_no_closing_tag.xml"
    return os.path.join(self.resources_path, fixture_name)

@property
def test_malformed_not_self_closing_xml(self):
    """Absolute path of the malformed fixture with an unterminated (non-self-closing) tag."""
    fixture_name = "malformed_not_self_closing.xml"
    return os.path.join(self.resources_path, fixture_name)

@property
def test_malformed_record_xml(self):
    """Absolute path of the fixture containing one syntactically broken XML record."""
    fixture_name = "malformed_record.xml"
    return os.path.join(self.resources_path, fixture_name)

@property
def test_dog_image(self):
return os.path.join(self.resources_path, "dog.jpg")
Expand Down
Loading