Skip to content

Commit 2bac9f6

Browse files
committed
add logic to user custom schema (minus type mapping)
1 parent c715863 commit 2bac9f6

File tree

4 files changed

+254
-35
lines changed

4 files changed

+254
-35
lines changed

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1911,7 +1911,7 @@ def read_file(
19111911
schema_string = attribute_to_schema_string(schema)
19121912
if xml_reader_udtf is not None:
19131913
xml_query = self._create_xml_query(
1914-
xml_reader_udtf, path, options, schema_string
1914+
xml_reader_udtf, path, options, schema_string if use_user_schema else ""
19151915
)
19161916
return SnowflakePlan(
19171917
[Query(xml_query)],

src/snowflake/snowpark/_internal/xml_reader.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
import re
77
import html.entities
88
import struct
9+
import copy
910
from typing import Optional, Dict, Any, Iterator, BinaryIO, Union, Tuple
1011

12+
from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
1113
from snowflake.snowpark._internal.type_utils import type_string_to_type_object
1214
from snowflake.snowpark.files import SnowflakeFile
13-
from snowflake.snowpark.types import StructType
15+
from snowflake.snowpark.types import StructType, ArrayType, DataType, MapType
1416

1517
# lxml is only a dev dependency so use try/except to import it if available
1618
try:
@@ -55,6 +57,36 @@ def replace_entity(match: re.Match) -> str:
5557
return match.group(0)
5658

5759

60+
def schema_string_to_result_dict_and_struct_type(
    schema_string: Optional[str],
) -> Optional[dict]:
    """Build the nested result template for a user-supplied schema string.

    Note: despite the function name, only the template dict is returned; the
    parsed struct type itself is not.

    Args:
        schema_string: Schema string to parse with ``type_string_to_type_object``.
            An empty string or ``None`` means "no custom schema".

    Returns:
        The dict produced by ``struct_type_to_result_template`` for the parsed
        schema, or ``None`` when no schema was given or the parsed type is not
        a ``StructType``.
    """
    # Treat both "" and None as "no schema" so the parser is never handed a
    # missing value.
    if not schema_string:
        return None

    schema = type_string_to_type_object(schema_string)
    # Only a struct can describe the columns of a row; any other type yields
    # no template.
    if not isinstance(schema, StructType):
        return None

    return struct_type_to_result_template(schema)
68+
69+
70+
def struct_type_to_result_template(dt: DataType) -> Optional[dict]:
    """Recursively turn a data type into a nested dict template.

    Args:
        dt: The data type to convert.

    Returns:
        * For a ``StructType``: a dict mapping each field name (with surrounding
          quotes removed via ``unquote_if_quoted``) to the template of that
          field's data type.
        * For an ``ArrayType`` / ``MapType`` whose element / value type is set:
          the template of that element / value type.
        * ``None`` for every other type.
    """
    if isinstance(dt, StructType):
        return {
            unquote_if_quoted(field.name): struct_type_to_result_template(
                field.datatype
            )
            for field in dt.fields
        }

    # Containers contribute the template of whatever they hold.
    nested_type = None
    if isinstance(dt, ArrayType):
        nested_type = dt.element_type
    elif isinstance(dt, MapType):
        nested_type = dt.value_type

    if nested_type is None:
        return None
    return struct_type_to_result_template(nested_type)
84+
85+
86+
def generate_norm_column_name_to_ori_column_name_dict(
    result: Dict[str, Any],
) -> Dict[str, str]:
    """Map each lower-cased key of ``result`` back to its original spelling.

    Args:
        result: Dict whose string keys are to be normalized.

    Returns:
        A dict of ``key.lower() -> key``. If several keys differ only by case,
        the one that appears last in ``result`` wins.
    """
    # Iterating a dict yields its keys directly; no .keys() call is needed.
    return {key.lower(): key for key in result}
88+
89+
5890
def get_file_size(filename: str) -> Optional[int]:
5991
"""
6092
Get the size of a file using a file object without reading its content.
@@ -273,10 +305,16 @@ def element_to_dict_or_str(
273305
value_tag: str = "_VALUE",
274306
null_value: str = "",
275307
ignore_surrounding_whitespace: bool = False,
308+
result_template: Optional[dict] = None,
276309
) -> Optional[Union[Dict[str, Any], str]]:
277310
"""
278311
Recursively converts an XML Element to a dictionary.
279312
"""
313+
norm_name_to_ori_name = (
314+
generate_norm_column_name_to_ori_column_name_dict(result_template)
315+
if result_template is not None
316+
else None
317+
)
280318

281319
def get_text(element: ET.Element) -> Optional[str]:
282320
"""Do not strip the text"""
@@ -292,28 +330,43 @@ def get_text(element: ET.Element) -> Optional[str]:
292330
# it's a value element with no attributes or excluded attributes, so return the text
293331
return get_text(element)
294332

295-
result = {}
333+
result = copy.deepcopy(result_template) if result_template is not None else {}
296334

297335
if not exclude_attributes:
298336
for attr_name, attr_value in element.attrib.items():
299337
if ignore_surrounding_whitespace:
300338
attr_value = attr_value.strip()
301-
result[f"{attribute_prefix}{attr_name}"] = (
302-
None if attr_value == null_value else attr_value
303-
)
339+
attribute_name = f"{attribute_prefix}{attr_name}"
340+
# when custom_schema exists, only attributes whose name matches a schema field (case-insensitively) are kept
341+
if result_template is None:
342+
result[attribute_name] = (
343+
None if attr_value == null_value else attr_value
344+
)
345+
elif attribute_name.lower() in norm_name_to_ori_name:
346+
result[norm_name_to_ori_name[attribute_name.lower()]] = (
347+
None if attr_value == null_value else attr_value
348+
)
304349

305350
if children:
306351
temp_dict = {}
307352
for child in children:
353+
tag = child.tag
354+
child_result_template = None
355+
if result_template is not None:
356+
# skip if not in custom schema
357+
if tag.lower() not in norm_name_to_ori_name:
358+
continue
359+
tag = norm_name_to_ori_name[tag.lower()]
360+
child_result_template = result_template[tag]
308361
child_dict = element_to_dict_or_str(
309362
child,
310363
attribute_prefix=attribute_prefix,
311364
exclude_attributes=exclude_attributes,
312365
value_tag=value_tag,
313366
null_value=null_value,
314367
ignore_surrounding_whitespace=ignore_surrounding_whitespace,
368+
result_template=child_result_template,
315369
)
316-
tag = child.tag
317370
if tag in temp_dict:
318371
if not isinstance(temp_dict[tag], list):
319372
temp_dict[tag] = [temp_dict[tag]]
@@ -345,7 +398,7 @@ def process_xml_range(
345398
ignore_surrounding_whitespace: bool,
346399
row_validation_xsd_path: str,
347400
chunk_size: int = DEFAULT_CHUNK_SIZE,
348-
custom_schema: Optional[StructType] = None,
401+
result_template: Optional[dict] = None,
349402
) -> Iterator[Optional[Dict[str, Any]]]:
350403
"""
351404
Processes an XML file within a given approximate byte range.
@@ -375,7 +428,7 @@ def process_xml_range(
375428
ignore_surrounding_whitespace (bool): Whether or not whitespaces surrounding values should be skipped.
376429
row_validation_xsd_path (str): Path to XSD file for row validation.
377430
chunk_size (int): Size of chunks to read.
378-
custom_schema(StructType): User input schema for xml, must be used together with row tag.
431+
result_template (dict): A result template generated from the user-provided schema.
379432
380433
Yields:
381434
Optional[Dict[str, Any]]: Dictionary representation of the parsed XML element.
@@ -506,6 +559,7 @@ def process_xml_range(
506559
value_tag=value_tag,
507560
null_value=null_value,
508561
ignore_surrounding_whitespace=ignore_surrounding_whitespace,
562+
result_template=copy.deepcopy(result_template),
509563
)
510564
if isinstance(result, dict):
511565
yield result
@@ -573,7 +627,7 @@ def process(
573627
approx_chunk_size = file_size // num_workers
574628
approx_start = approx_chunk_size * i
575629
approx_end = approx_chunk_size * (i + 1) if i < num_workers - 1 else file_size
576-
custom_schema = type_string_to_type_object(custom_schema)
630+
result_template = schema_string_to_result_dict_and_struct_type(custom_schema)
577631
for element in process_xml_range(
578632
filename,
579633
row_tag,
@@ -589,6 +643,6 @@ def process(
589643
charset,
590644
ignore_surrounding_whitespace,
591645
row_validation_xsd_path=row_validation_xsd_path,
592-
custom_schema=custom_schema,
646+
result_template=result_template,
593647
):
594648
yield (element,)

tests/integ/test_xml_reader_row_tag.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@
1212
SnowparkSQLException,
1313
)
1414
from snowflake.snowpark.functions import col, lit
15-
from snowflake.snowpark.types import (
16-
StructType,
17-
StructField,
18-
StringType,
19-
IntegerType,
20-
DecimalType,
21-
)
2215
from tests.utils import TestFiles, Utils
2316

2417

@@ -474,20 +467,3 @@ def test_read_xml_row_validation_xsd_path_failfast(session):
474467
session.read.option("rowTag", row_tag).option(
475468
"rowValidationXSDPath", f"@{tmp_stage_name}/{test_file_books_xsd}"
476469
).option("mode", "failfast").xml(f"@{tmp_stage_name}/{test_file_books_xml}")
477-
478-
479-
def test_read_xml_with_custom_schema(session):
480-
user_schema = StructType(
481-
[
482-
StructField("a", StringType()),
483-
StructField("b", IntegerType()),
484-
StructField("c", DecimalType()),
485-
]
486-
)
487-
df = (
488-
session.read.schema(user_schema)
489-
.option("rowTag", "book")
490-
.option("CACHERESULT", False)
491-
.xml(f"@{tmp_stage_name}/{test_file_books_xml}")
492-
)
493-
df.show()

0 commit comments

Comments
 (0)