Skip to content

Commit 7413be9

Browse files
authored
SNOW-2082332: Support mode for dealing corrupt XML records (#3337)
1 parent 6846ef3 commit 7413be9

File tree

10 files changed

+218
-22
lines changed

10 files changed

+218
-22
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
- Invoking snowflake system procedures does not invoke an additional `describe procedure` call to check the return type of the procedure.
1010
- Added support for `Session.create_dataframe()` with the stage URL and FILE data type.
11+
- Added support for different modes for dealing with corrupt XML records when reading an XML file using `session.read.option('rowTag', <tag_name>).xml(<stage_file_path>)`. Currently `PERMISSIVE`, `DROPMALFORMED` and `FAILFAST` are supported.
1112
- Improved query generation for `Dataframe.drop` to use `SELECT * EXCLUDE ()` to exclude the dropped columns. To enable this feature, set `session.conf.set("use_simplified_query_generation", True)`.
1213

1314
#### Bug Fixes

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,15 @@ def _create_xml_query(
13731373
worker_column_name = "WORKER"
13741374
xml_row_number_column_name = "XML_ROW_NUMBER"
13751375
row_tag = options[XML_ROW_TAG_STRING]
1376+
mode = options.get("MODE", "PERMISSIVE").upper()
1377+
column_name_of_corrupt_record = options.get(
1378+
"COLUMNNAMEOFCORRUPTRECORD", "_corrupt_record"
1379+
)
1380+
1381+
if mode not in {"PERMISSIVE", "DROPMALFORMED", "FAILFAST"}:
1382+
raise ValueError(
1383+
f"Invalid mode: {mode}. Must be one of PERMISSIVE, DROPMALFORMED, FAILFAST."
1384+
)
13761385

13771386
# TODO SNOW-1983360: make it an configurable option once the UDTF scalability issue is resolved.
13781387
# Currently it's capped at 16.
@@ -1395,6 +1404,8 @@ def _create_xml_query(
13951404
lit(num_workers),
13961405
lit(row_tag),
13971406
col(worker_column_name),
1407+
lit(mode),
1408+
lit(column_name_of_corrupt_record),
13981409
),
13991410
)
14001411

src/snowflake/snowpark/_internal/xml_reader.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import os
66
import re
77
import html.entities
8-
import logging
98
import struct
109
import xml.etree.ElementTree as ET
1110
from typing import Optional, Dict, Any, Iterator, BinaryIO, Union, Tuple
1211
from snowflake.snowpark.files import SnowflakeFile
1312

1413

1514
DEFAULT_CHUNK_SIZE: int = 1024
15+
VARIANT_COLUMN_SIZE_LIMIT: int = 16 * 1024 * 1024
1616

1717

1818
def replace_entity(match: re.Match) -> str:
@@ -76,7 +76,7 @@ def tag_is_self_closing(
7676
chunk_start_pos = file_obj.tell()
7777
chunk = file_obj.read(chunk_size)
7878
if not chunk:
79-
raise EOFError("EOF reached before end of opening tag")
79+
raise EOFError("Reached end of file but the tag is not closed")
8080

8181
for idx, b in enumerate(struct.unpack(f"{len(chunk)}c", chunk)):
8282
# '>' inside quote should not be considered as the end of the tag
@@ -216,7 +216,7 @@ def find_next_opening_tag_pos(
216216
# Calculate the absolute position. Note that `data` starts at (current_pos - len(overlap)).
217217
absolute_pos = current_pos + pos - len(overlap)
218218
if absolute_pos >= end_limit:
219-
raise EOFError("Found tag beyond end limit")
219+
raise EOFError("Exceeded end limit before finding opening tag")
220220
file_obj.seek(absolute_pos)
221221
return absolute_pos
222222

@@ -298,6 +298,8 @@ def process_xml_range(
298298
tag_name: str,
299299
approx_start: int,
300300
approx_end: int,
301+
mode: str,
302+
column_name_of_corrupt_record: str,
301303
chunk_size: int = DEFAULT_CHUNK_SIZE,
302304
) -> Iterator[Optional[Dict[str, Any]]]:
303305
"""
@@ -316,6 +318,9 @@ def process_xml_range(
316318
tag_name (str): The tag that delimits records (e.g., "row").
317319
approx_start (int): Approximate start byte position.
318320
approx_end (int): Approximate end byte position.
321+
mode (str): The mode for dealing with corrupt records.
322+
"PERMISSIVE", "DROPMALFORMED" and "FAILFAST" are supported.
323+
column_name_of_corrupt_record (str): The name of the column for corrupt records.
319324
chunk_size (int): Size of chunks to read.
320325
321326
Yields:
@@ -351,8 +356,19 @@ def process_xml_range(
351356
# decide whether the row element is self‑closing
352357
try:
353358
is_self_close, tag_end = tag_is_self_closing(f)
354-
except EOFError:
355-
# malformed XML record
359+
# encountering an EOFError means the XML record isn't self-closing or
360+
# doesn't have a closing tag after reaching the end of the file
361+
except EOFError as e:
362+
if mode == "PERMISSIVE":
363+
# read util the end of file or util variant column size limit
364+
record_bytes = f.read(VARIANT_COLUMN_SIZE_LIMIT)
365+
record_str = record_bytes.decode("utf-8", errors="replace")
366+
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
367+
yield {column_name_of_corrupt_record: record_str}
368+
elif mode == "FAILFAST":
369+
raise EOFError(
370+
f"Malformed XML record at bytes {record_start}-EOF: {e}"
371+
) from e
356372
break
357373

358374
if is_self_close:
@@ -361,31 +377,37 @@ def process_xml_range(
361377
f.seek(tag_end)
362378
try:
363379
record_end = find_next_closing_tag_pos(f, closing_tag, chunk_size)
364-
except EOFError:
365-
# incomplete XML record
380+
# encountering an EOFError means the XML record isn't self-closing or
381+
# doesn't have a closing tag after reaching the end of the file
382+
except EOFError as e:
383+
if mode == "PERMISSIVE":
384+
# read util the end of file or util variant column size limit
385+
record_bytes = f.read(VARIANT_COLUMN_SIZE_LIMIT)
386+
record_str = record_bytes.decode("utf-8", errors="replace")
387+
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
388+
yield {column_name_of_corrupt_record: record_str}
389+
elif mode == "FAILFAST":
390+
raise EOFError(
391+
f"Malformed XML record at bytes {record_start}-EOF: {e}"
392+
) from e
366393
break
367394

368395
# Read the complete XML record.
369396
f.seek(record_start)
370397
record_bytes = f.read(record_end - record_start)
371-
try:
372-
record_str = record_bytes.decode("utf-8")
373-
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
374-
except UnicodeDecodeError as e:
375-
logging.warning(
376-
f"Unicode decode error at bytes {record_start}-{record_end}: {e}"
377-
)
378-
f.seek(record_end)
379-
continue
398+
record_str = record_bytes.decode("utf-8", errors="replace")
399+
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
380400

381401
try:
382402
element = ET.fromstring(record_str)
383403
yield element_to_dict(strip_namespaces(element))
384404
except ET.ParseError as e:
385-
logging.warning(
386-
f"XML parse error at bytes {record_start}-{record_end}: {e}"
387-
)
388-
logging.warning(f"Record content: {record_str}")
405+
if mode == "PERMISSIVE":
406+
yield {column_name_of_corrupt_record: record_str}
407+
elif mode == "FAILFAST":
408+
raise RuntimeError(
409+
f"Malformed XML record at bytes {record_start}-{record_end}: {e}"
410+
)
389411

390412
if record_end > approx_end:
391413
break
@@ -395,7 +417,15 @@ def process_xml_range(
395417

396418

397419
class XMLReader:
398-
def process(self, filename: str, num_workers: int, row_tag: str, i: int):
420+
def process(
421+
self,
422+
filename: str,
423+
num_workers: int,
424+
row_tag: str,
425+
i: int,
426+
mode: str,
427+
column_name_of_corrupt_record: str,
428+
):
399429
"""
400430
Splits the file into byte ranges—one per worker—by starting with an even
401431
file size division and then moving each boundary to the end of a record,
@@ -406,10 +436,20 @@ def process(self, filename: str, num_workers: int, row_tag: str, i: int):
406436
num_workers (int): Number of workers/chunks.
407437
row_tag (str): The tag name that delimits records (e.g., "row").
408438
i (int): The worker id.
439+
mode (str): The mode for dealing with corrupt records.
440+
"PERMISSIVE", "DROPMALFORMED" and "FAILFAST" are supported.
441+
column_name_of_corrupt_record (str): The name of the column for corrupt records.
409442
"""
410443
file_size = get_file_size(filename)
411444
approx_chunk_size = file_size // num_workers
412445
approx_start = approx_chunk_size * i
413446
approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size
414-
for element in process_xml_range(filename, row_tag, approx_start, approx_end):
447+
for element in process_xml_range(
448+
filename,
449+
row_tag,
450+
approx_start,
451+
approx_end,
452+
mode,
453+
column_name_of_corrupt_record,
454+
):
415455
yield (element,)

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,34 @@ def xml(self, path: str, _emit_ast: bool = True) -> DataFrame:
846846
847847
Returns:
848848
a :class:`DataFrame` that is set up to load data from the specified XML file(s) in a Snowflake stage.
849+
850+
Notes about reading XML files using a row tag:
851+
852+
- We support reading XML by specifying the element tag that represents a single record using the ``rowTag``
853+
option. See Example 13 in :class:`DataFrameReader`.
854+
855+
- Each XML record is flattened into a single row, with each XML element or attribute mapped to a column.
856+
All columns are represented with the variant type to accommodate heterogeneous or nested data. Therefore,
857+
every column value has a size limit due to the variant type.
858+
859+
- The column names are derived from the XML element names. It will always be wrapped by single quotes.
860+
861+
- To parse the nested XML under a row tag, you can use dot notation ``.`` to query the nested fields in
862+
a DataFrame. See Example 13 in :class:`DataFrameReader`.
863+
864+
- When ``rowTag`` is specified, the following options are supported for reading XML files
865+
via :meth:`option()` or :meth:`options()`:
866+
867+
+ ``mode``: Specifies the mode for dealing with corrupt XML records. The default value is ``PERMISSIVE``. The supported values are:
868+
869+
- ``PERMISSIVE``: When it encounters a corrupt record, it sets all fields to null and includes a `columnNameOfCorruptRecord` column.
870+
871+
- ``DROPMALFORMED``: Ignores the whole record that cannot be parsed correctly.
872+
873+
- ``FAILFAST``: When it encounters a corrupt record, it raises an exception immediately.
874+
875+
+ ``columnNameOfCorruptRecord``: Specifies the name of the column that contains the corrupt record.
876+
The default value is '_corrupt_record'.
849877
"""
850878
df = self._read_semi_structured_file(path, "XML")
851879

tests/integ/scala/test_dataframe_reader_suite.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@
6969
test_file_house_large_xml = "fias_house.large.xml"
7070
test_file_xxe_xml = "xxe.xml"
7171
test_file_nested_xml = "nested.xml"
72+
test_file_malformed_no_closing_tag_xml = "malformed_no_closing_tag.xml"
73+
test_file_malformed_not_self_closing_xml = "malformed_not_self_closing.xml"
74+
test_file_malformed_record_xml = "malformed_record.xml"
7275

7376

7477
# In the tests below, we test both scenarios: SELECT & COPY
@@ -261,6 +264,24 @@ def setup(session, resources_path, local_testing_mode):
261264
Utils.upload_to_stage(
262265
session, "@" + tmp_stage_name1, test_files.test_nested_xml, compress=False
263266
)
267+
Utils.upload_to_stage(
268+
session,
269+
"@" + tmp_stage_name1,
270+
test_files.test_malformed_no_closing_tag_xml,
271+
compress=False,
272+
)
273+
Utils.upload_to_stage(
274+
session,
275+
"@" + tmp_stage_name1,
276+
test_files.test_malformed_not_self_closing_xml,
277+
compress=False,
278+
)
279+
Utils.upload_to_stage(
280+
session,
281+
"@" + tmp_stage_name1,
282+
test_files.test_malformed_record_xml,
283+
compress=False,
284+
)
264285
Utils.upload_to_stage(
265286
session, "@" + tmp_stage_name2, test_files.test_file_csv, compress=False
266287
)
@@ -2019,3 +2040,56 @@ def test_read_xml_non_existing_file(session):
20192040
session.read.option("rowTag", row_tag).xml(
20202041
f"@{tmp_stage_name1}/non_existing_file.xml"
20212042
)
2043+
2044+
2045+
@pytest.mark.skipif(
2046+
"config.getoption('local_testing_mode', default=False)",
2047+
reason="xml not supported in local testing mode",
2048+
)
2049+
@pytest.mark.skipif(
2050+
IS_IN_STORED_PROC,
2051+
reason="SNOW-2044853: Flaky in stored procedure test",
2052+
)
2053+
@pytest.mark.parametrize(
2054+
"file",
2055+
(
2056+
test_file_malformed_no_closing_tag_xml,
2057+
test_file_malformed_not_self_closing_xml,
2058+
test_file_malformed_record_xml,
2059+
),
2060+
)
2061+
def test_read_malformed_xml(session, file):
2062+
row_tag = "record"
2063+
2064+
# permissive mode
2065+
df = (
2066+
session.read.option("rowTag", row_tag)
2067+
.option("mode", "permissive")
2068+
.xml(f"@{tmp_stage_name1}/{file}")
2069+
)
2070+
result = df.collect()
2071+
assert len(result) == 2
2072+
assert len(result[0]) == 4 # has another column '_corrupt_record'
2073+
assert (
2074+
result[0]["'_corrupt_record'"] is not None
2075+
or result[1]["'_corrupt_record'"] is not None
2076+
)
2077+
2078+
# dropmalformed mode
2079+
df = (
2080+
session.read.option("rowTag", row_tag)
2081+
.option("mode", "dropmalformed")
2082+
.xml(f"@{tmp_stage_name1}/{test_file_malformed_no_closing_tag_xml}")
2083+
)
2084+
result = df.collect()
2085+
assert len(result) == 1
2086+
assert len(result[0]) == 3
2087+
2088+
# failfast mode
2089+
df = (
2090+
session.read.option("rowTag", row_tag)
2091+
.option("mode", "failfast")
2092+
.xml(f"@{tmp_stage_name1}/{test_file_malformed_no_closing_tag_xml}")
2093+
)
2094+
with pytest.raises(SnowparkSQLException, match="Malformed XML record at bytes"):
2095+
df.collect()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<test>
2+
<record>
3+
<id>41</id>
4+
<name>Joe Biden</name>
5+
<email>joe@example.com</email>
6+
</record>
7+
<record>
8+
<id>42</id>
9+
<name>Jane Doe</name>
10+
<email>jane@example.com</email>
11+
</test>
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<test>
2+
<record id="41" name="Joe Biden" email="joe@example.com" />
3+
<record id="42" name="Jane Doe" email="jane@example.com"
4+
</test>
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<test>
2+
<record>
3+
<id>41</id>
4+
<name>Joe Biden</name>
5+
<email>joe@example.com</email>
6+
</record>
7+
<record>
8+
<id>42</id>
9+
<name>Jane Doe</name>
10+
<email>jane@example.com</email
11+
</record>
12+
</test>

tests/unit/scala/test_utils_suite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
284284
"resources/fias_house.xml",
285285
"resources/fias_house.large.xml",
286286
"resources/iris.csv",
287+
"resources/malformed_no_closing_tag.xml",
288+
"resources/malformed_not_self_closing.xml",
289+
"resources/malformed_record.xml",
287290
"resources/nested.xml",
288291
"resources/test.avro",
289292
"resources/test.orc",

tests/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,6 +1554,18 @@ def test_xxe_xml(self):
15541554
def test_nested_xml(self):
15551555
return os.path.join(self.resources_path, "nested.xml")
15561556

1557+
@property
1558+
def test_malformed_no_closing_tag_xml(self):
1559+
return os.path.join(self.resources_path, "malformed_no_closing_tag.xml")
1560+
1561+
@property
1562+
def test_malformed_not_self_closing_xml(self):
1563+
return os.path.join(self.resources_path, "malformed_not_self_closing.xml")
1564+
1565+
@property
1566+
def test_malformed_record_xml(self):
1567+
return os.path.join(self.resources_path, "malformed_record.xml")
1568+
15571569
@property
15581570
def test_dog_image(self):
15591571
return os.path.join(self.resources_path, "dog.jpg")

0 commit comments

Comments
 (0)