Skip to content

Commit 95f2792

Browse files
committed
column name
1 parent 114e815 commit 95f2792

File tree

4 files changed

+31
-10
lines changed

4 files changed

+31
-10
lines changed

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,6 +1374,9 @@ def _create_xml_query(
13741374
xml_row_number_column_name = "XML_ROW_NUMBER"
13751375
row_tag = options[XML_ROW_TAG_STRING]
13761376
mode = options.get("MODE", "PERMISSIVE").upper()
1377+
column_name_of_corrupt_record = options.get(
1378+
"COLUMNNAMEOFCORRUPTRECORD", "_corrupt_record"
1379+
)
13771380

13781381
if mode not in {"PERMISSIVE", "DROPMALFORMED", "FAILFAST"}:
13791382
raise ValueError(
@@ -1402,6 +1405,7 @@ def _create_xml_query(
14021405
lit(row_tag),
14031406
col(worker_column_name),
14041407
lit(mode),
1408+
lit(column_name_of_corrupt_record),
14051409
),
14061410
)
14071411

src/snowflake/snowpark/_internal/xml_reader.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
DEFAULT_CHUNK_SIZE: int = 1024
1515
VARIANT_COLUMN_SIZE_LIMIT: int = 16 * 1024 * 1024
16-
COLUMN_NAME_OF_CORRUPT_RECORD = "columnNameOfCorruptRecord"
1716

1817

1918
def replace_entity(match: re.Match) -> str:
@@ -300,6 +299,7 @@ def process_xml_range(
300299
approx_start: int,
301300
approx_end: int,
302301
mode: str,
302+
column_name_of_corrupt_record: str,
303303
chunk_size: int = DEFAULT_CHUNK_SIZE,
304304
) -> Iterator[Optional[Dict[str, Any]]]:
305305
"""
@@ -320,6 +320,7 @@ def process_xml_range(
320320
approx_end (int): Approximate end byte position.
321321
mode (str): The mode for dealing with corrupt records.
322322
"PERMISSIVE", "DROPMALFORMED" and "FAILFAST" are supported.
323+
column_name_of_corrupt_record (str): The name of the column for corrupt records.
323324
chunk_size (int): Size of chunks to read.
324325
325326
Yields:
@@ -363,7 +364,7 @@ def process_xml_range(
363364
record_bytes = f.read(VARIANT_COLUMN_SIZE_LIMIT)
364365
record_str = record_bytes.decode("utf-8", errors="replace")
365366
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
366-
yield {COLUMN_NAME_OF_CORRUPT_RECORD: record_str}
367+
yield {column_name_of_corrupt_record: record_str}
367368
elif mode == "FAILFAST":
368369
raise EOFError(
369370
f"Malformed XML record at bytes {record_start}-EOF: {e}"
@@ -384,7 +385,7 @@ def process_xml_range(
384385
record_bytes = f.read(VARIANT_COLUMN_SIZE_LIMIT)
385386
record_str = record_bytes.decode("utf-8", errors="replace")
386387
record_str = re.sub(r"&(\w+);", replace_entity, record_str)
387-
yield {COLUMN_NAME_OF_CORRUPT_RECORD: record_str}
388+
yield {column_name_of_corrupt_record: record_str}
388389
elif mode == "FAILFAST":
389390
raise EOFError(
390391
f"Malformed XML record at bytes {record_start}-EOF: {e}"
@@ -402,7 +403,7 @@ def process_xml_range(
402403
yield element_to_dict(strip_namespaces(element))
403404
except ET.ParseError as e:
404405
if mode == "PERMISSIVE":
405-
yield {COLUMN_NAME_OF_CORRUPT_RECORD: record_str}
406+
yield {column_name_of_corrupt_record: record_str}
406407
elif mode == "FAILFAST":
407408
raise RuntimeError(
408409
f"Malformed XML record at bytes {record_start}-{record_end}: {e}"
@@ -416,7 +417,15 @@ def process_xml_range(
416417

417418

418419
class XMLReader:
419-
def process(self, filename: str, num_workers: int, row_tag: str, i: int, mode: str):
420+
def process(
421+
self,
422+
filename: str,
423+
num_workers: int,
424+
row_tag: str,
425+
i: int,
426+
mode: str,
427+
column_name_of_corrupt_record: str,
428+
):
420429
"""
421430
Splits the file into byte ranges—one per worker—by starting with an even
422431
file size division and then moving each boundary to the end of a record,
@@ -429,12 +438,18 @@ def process(self, filename: str, num_workers: int, row_tag: str, i: int, mode: s
429438
i (int): The worker id.
430439
mode (str): The mode for dealing with corrupt records.
431440
"PERMISSIVE", "DROPMALFORMED" and "FAILFAST" are supported.
441+
column_name_of_corrupt_record (str): The name of the column for corrupt records.
432442
"""
433443
file_size = get_file_size(filename)
434444
approx_chunk_size = file_size // num_workers
435445
approx_start = approx_chunk_size * i
436446
approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size
437447
for element in process_xml_range(
438-
filename, row_tag, approx_start, approx_end, mode
448+
filename,
449+
row_tag,
450+
approx_start,
451+
approx_end,
452+
mode,
453+
column_name_of_corrupt_record,
439454
):
440455
yield (element,)

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,13 +864,15 @@ def xml(self, path: str, _emit_ast: bool = True) -> DataFrame:
864864
- When ``rowTag`` is specified, the following options are supported for reading XML files
865865
via :meth:`option()` or :meth:`options()`:
866866
867-
+ ``mode``: Specifies the mode of for dealing with corrupt XML records. The default value is ``PERMISSIVE``. The supported values are:
867+
+ ``mode``: Specifies the mode for dealing with corrupt XML records. The default value is ``PERMISSIVE``. The supported values are:
868868
869869
- ``PERMISSIVE``: When it encounters a corrupt record, it sets all fields to null and includes a 'columnNameOfCorruptRecord' column.
870870
871871
- ``DROPMALFORMED``: Ignores the whole record that cannot be parsed correctly.
872872
873873
- ``FAILFAST``: When it encounters a corrupt record, it raises an exception immediately.
874+
+ ``columnNameOfCorruptRecord``: Specifies the name of the column that contains the corrupt record.
875+
The default value is '_corrupt_record'.
874876
"""
875877
df = self._read_semi_structured_file(path, "XML")
876878

tests/integ/scala/test_dataframe_reader_suite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,10 +2069,10 @@ def test_read_malformed_xml(session, file):
20692069
)
20702070
result = df.collect()
20712071
assert len(result) == 2
2072-
assert len(result[0]) == 4 # has another column 'columnNameOfCorruptRecord'
2072+
assert len(result[0]) == 4 # has another column '_corrupt_record'
20732073
assert (
2074-
result[0]["'columnNameOfCorruptRecord'"] is not None
2075-
or result[1]["'columnNameOfCorruptRecord'"] is not None
2074+
result[0]["'_corrupt_record'"] is not None
2075+
or result[1]["'_corrupt_record'"] is not None
20762076
)
20772077

20782078
# dropmalformed mode

0 commit comments

Comments
 (0)