66import re
77import html .entities
88import struct
9+ import copy
910from typing import Optional , Dict , Any , Iterator , BinaryIO , Union , Tuple
1011
12+ from snowflake .snowpark ._internal .analyzer .analyzer_utils import unquote_if_quoted
1113from snowflake .snowpark ._internal .type_utils import type_string_to_type_object
1214from snowflake .snowpark .files import SnowflakeFile
13- from snowflake .snowpark .types import StructType
15+ from snowflake .snowpark .types import StructType , ArrayType , DataType , MapType
1416
1517# lxml is only a dev dependency so use try/except to import it if available
1618try :
@@ -55,6 +57,36 @@ def replace_entity(match: re.Match) -> str:
5557 return match .group (0 )
5658
5759
def schema_string_to_result_dict_and_struct_type(
    schema_string: Optional[str],
) -> Optional[dict]:
    """
    Build a nested result template from a schema type string.

    Args:
        schema_string: Schema type string to parse. ``None`` or ``""`` means
            that no schema was supplied.

    Returns:
        The dict produced by ``struct_type_to_result_template`` when the string
        parses to a ``StructType``. ``None`` when no schema was supplied or when
        the parsed type is not a ``StructType``.
    """
    # Treat None like the empty string so a missing schema never reaches the parser.
    if not schema_string:
        return None

    schema = type_string_to_type_object(schema_string)
    # Only a struct has named fields to build a template from.
    if not isinstance(schema, StructType):
        return None

    return struct_type_to_result_template(schema)
68+
69+
def struct_type_to_result_template(dt: DataType) -> Optional[dict]:
    """
    Recursively convert a data type into a nested dict template.

    A ``StructType`` becomes a dict keyed by each field's unquoted name, whose
    values are the templates of the field data types. An ``ArrayType`` yields
    the template of its element type and a ``MapType`` yields the template of
    its value type. Every other type, and an array or map whose inner type is
    ``None``, yields ``None``.
    """
    if isinstance(dt, StructType):
        return {
            unquote_if_quoted(field.name): struct_type_to_result_template(
                field.datatype
            )
            for field in dt.fields
        }

    # Containers contribute the template of the type they wrap.
    if isinstance(dt, ArrayType):
        inner_type = dt.element_type
    elif isinstance(dt, MapType):
        inner_type = dt.value_type
    else:
        return None

    if inner_type is None:
        return None
    return struct_type_to_result_template(inner_type)
84+
85+
def generate_norm_column_name_to_ori_column_name_dict(result: dict) -> Dict[str, str]:
    """
    Map the lower-cased form of every key in ``result`` to the original key.

    Args:
        result: Dict whose keys are strings.

    Returns:
        A dict from ``key.lower()`` to ``key``. Keys that differ only by case
        collapse into one entry, and the key iterated last wins.
    """
    return {key.lower(): key for key in result}
88+
89+
5890def get_file_size (filename : str ) -> Optional [int ]:
5991 """
6092 Get the size of a file using a file object without reading its content.
@@ -273,10 +305,16 @@ def element_to_dict_or_str(
273305 value_tag : str = "_VALUE" ,
274306 null_value : str = "" ,
275307 ignore_surrounding_whitespace : bool = False ,
308+ result_template : Optional [dict ] = None ,
276309) -> Optional [Union [Dict [str , Any ], str ]]:
277310 """
278311 Recursively converts an XML Element to a dictionary.
279312 """
313+ norm_name_to_ori_name = (
314+ generate_norm_column_name_to_ori_column_name_dict (result_template )
315+ if result_template is not None
316+ else None
317+ )
280318
281319 def get_text (element : ET .Element ) -> Optional [str ]:
282320 """Do not strip the text"""
@@ -292,28 +330,43 @@ def get_text(element: ET.Element) -> Optional[str]:
292330 # it's a value element with no attributes or excluded attributes, so return the text
293331 return get_text (element )
294332
295- result = {}
333+ result = copy . deepcopy ( result_template ) if result_template is not None else {}
296334
297335 if not exclude_attributes :
298336 for attr_name , attr_value in element .attrib .items ():
299337 if ignore_surrounding_whitespace :
300338 attr_value = attr_value .strip ()
301- result [f"{ attribute_prefix } { attr_name } " ] = (
302- None if attr_value == null_value else attr_value
303- )
339+ attribute_name = f"{ attribute_prefix } { attr_name } "
340+ # when custom_schema exists, only exact match is allowed
341+ if result_template is None :
342+ result [attribute_name ] = (
343+ None if attr_value == null_value else attr_value
344+ )
345+ elif attribute_name .lower () in norm_name_to_ori_name :
346+ result [norm_name_to_ori_name [attribute_name .lower ()]] = (
347+ None if attr_value == null_value else attr_value
348+ )
304349
305350 if children :
306351 temp_dict = {}
307352 for child in children :
353+ tag = child .tag
354+ child_result_template = None
355+ if result_template is not None :
356+ # skip if not in custom schema
357+ if tag .lower () not in norm_name_to_ori_name :
358+ continue
359+ tag = norm_name_to_ori_name [tag .lower ()]
360+ child_result_template = result_template [tag ]
308361 child_dict = element_to_dict_or_str (
309362 child ,
310363 attribute_prefix = attribute_prefix ,
311364 exclude_attributes = exclude_attributes ,
312365 value_tag = value_tag ,
313366 null_value = null_value ,
314367 ignore_surrounding_whitespace = ignore_surrounding_whitespace ,
368+ result_template = child_result_template ,
315369 )
316- tag = child .tag
317370 if tag in temp_dict :
318371 if not isinstance (temp_dict [tag ], list ):
319372 temp_dict [tag ] = [temp_dict [tag ]]
@@ -345,7 +398,7 @@ def process_xml_range(
345398 ignore_surrounding_whitespace : bool ,
346399 row_validation_xsd_path : str ,
347400 chunk_size : int = DEFAULT_CHUNK_SIZE ,
348- custom_schema : Optional [StructType ] = None ,
401+ result_template : Optional [dict ] = None ,
349402) -> Iterator [Optional [Dict [str , Any ]]]:
350403 """
351404 Processes an XML file within a given approximate byte range.
@@ -375,7 +428,7 @@ def process_xml_range(
375428 ignore_surrounding_whitespace (bool): Whether or not whitespaces surrounding values should be skipped.
376429 row_validation_xsd_path (str): Path to XSD file for row validation.
377430 chunk_size (int): Size of chunks to read.
378- custom_schema(StructType ): User input schema for xml, must be used together with row tag.
431+ result_template(dict ): a result template generated from the user input schema
379432
380433 Yields:
381434 Optional[Dict[str, Any]]: Dictionary representation of the parsed XML element.
@@ -506,6 +559,7 @@ def process_xml_range(
506559 value_tag = value_tag ,
507560 null_value = null_value ,
508561 ignore_surrounding_whitespace = ignore_surrounding_whitespace ,
562+ result_template = copy .deepcopy (result_template ),
509563 )
510564 if isinstance (result , dict ):
511565 yield result
@@ -573,7 +627,7 @@ def process(
573627 approx_chunk_size = file_size // num_workers
574628 approx_start = approx_chunk_size * i
575629 approx_end = approx_chunk_size * (i + 1 ) if i < num_workers - 1 else file_size
576- custom_schema = type_string_to_type_object (custom_schema )
630+ result_template = schema_string_to_result_dict_and_struct_type (custom_schema )
577631 for element in process_xml_range (
578632 filename ,
579633 row_tag ,
@@ -589,6 +643,6 @@ def process(
589643 charset ,
590644 ignore_surrounding_whitespace ,
591645 row_validation_xsd_path = row_validation_xsd_path ,
592- custom_schema = custom_schema ,
646+ result_template = result_template ,
593647 ):
594648 yield (element ,)
0 commit comments