3737 ) # pragma: no cover
3838 import snowflake .snowpark .session
3939 import snowflake .snowpark .dataframe
40+ from snowflake .snowpark .udtf import UserDefinedTableFunction
4041
4142import snowflake .connector
4243import snowflake .snowpark
108109from snowflake .snowpark ._internal .error_message import SnowparkClientExceptionMessages
109110from snowflake .snowpark ._internal .utils import (
110111 INFER_SCHEMA_FORMAT_TYPES ,
112+ XML_ROW_TAG_STRING ,
113+ XML_ROW_DATA_COLUMN_NAME ,
111114 TempObjectType ,
112115 generate_random_alphanumeric ,
113116 get_copy_into_table_options ,
@@ -1285,6 +1288,60 @@ def process_list(list_property):
12851288 )(setting ["property_value" ])
12861289 return new_options
12871290
def _create_xml_query(
    self,
    xml_reader_udtf: "UserDefinedTableFunction",
    file_path: str,
    options: Dict[str, str],
) -> str:
    """
    Build the SQL text of a query that reads an XML file through a UDTF.

    The query fans the file out over a small range of workers, lets the UDTF
    emit each XML record as Variant data, flattens every record into
    key/value pairs and pivots on ``key`` so that each distinct key becomes
    a column. The helper worker and row-number columns are excluded from the
    final projection.

    Args:
        xml_reader_udtf: The table function applied to the file. It is called
            with the file path, the number of workers, the row tag and the
            worker id, in that order.
        file_path: Location of the XML file; it is interpolated into an
            ``ls`` command to look up the file size.
        options: Reader options; must contain ``XML_ROW_TAG_STRING``.

    Returns:
        The SQL query string (not a DataFrame).

    Raises:
        KeyError: If ``options`` has no ``XML_ROW_TAG_STRING`` entry.
        ValueError: If listing ``file_path`` returns no rows.
    """
    from snowflake.snowpark.functions import lit, col, seq8, flatten
    from snowflake.snowpark._internal.xml_reader import DEFAULT_CHUNK_SIZE

    worker_column_name = "WORKER"
    xml_row_number_column_name = "XML_ROW_NUMBER"
    row_tag = options[XML_ROW_TAG_STRING]

    # An empty listing would otherwise surface as an opaque IndexError below,
    # so fail early with a message that names the offending path.
    listing = self.session.sql(f"ls {file_path}", _emit_ast=False).collect()
    if not listing:
        raise ValueError(f"{file_path} does not exist")

    # TODO SNOW-1983360: make it a configurable option once the UDTF scalability issue is resolved.
    # Currently it's capped at 16.
    # NOTE: only the first listed entry's size is used to derive the worker count.
    file_size = int(listing[0]["size"])  # type: ignore
    num_workers = min(16, file_size // DEFAULT_CHUNK_SIZE + 1)

    # Create a range from 0 to N-1
    df = self.session.range(num_workers).to_df(worker_column_name)

    # Apply UDTF to the XML file and get each XML record as a Variant data,
    # and append a unique row number to each record.
    df = df.select(
        worker_column_name,
        seq8().as_(xml_row_number_column_name),
        xml_reader_udtf(
            lit(file_path),
            lit(num_workers),
            lit(row_tag),
            col(worker_column_name),
        ),
    )

    # Flatten the Variant data to get the key-value pairs
    df = df.select(
        worker_column_name,
        xml_row_number_column_name,
        flatten(XML_ROW_DATA_COLUMN_NAME),
    ).select(worker_column_name, xml_row_number_column_name, "key", "value")

    # Apply dynamic pivot to get the flat table with dynamic schema
    df = (
        df.pivot("key")
        .max("value")
        .sort(worker_column_name, xml_row_number_column_name)
    )

    # Exclude the worker and row number columns; only the last generated
    # query of the DataFrame is wrapped.
    return f"SELECT * EXCLUDE ({worker_column_name}, {xml_row_number_column_name}) FROM ({df.queries['queries'][-1]})"
1344+
12881345 def read_file (
12891346 self ,
12901347 path : str ,
@@ -1296,9 +1353,23 @@ def read_file(
12961353 metadata_project : Optional [List [str ]] = None ,
12971354 metadata_schema : Optional [List [Attribute ]] = None ,
12981355 use_user_schema : bool = False ,
1356+ xml_reader_udtf : Optional ["UserDefinedTableFunction" ] = None ,
12991357 source_plan : Optional [ReadFileNode ] = None ,
13001358 ) -> SnowflakePlan :
13011359 thread_safe_session_enabled = self .session ._conn ._thread_safe_session_enabled
1360+
1361+ if xml_reader_udtf is not None :
1362+ xml_query = self ._create_xml_query (xml_reader_udtf , path , options )
1363+ return SnowflakePlan (
1364+ [Query (xml_query )],
1365+ # the schema query of dynamic pivot must be the same as the original query
1366+ xml_query ,
1367+ None ,
1368+ {},
1369+ source_plan = source_plan ,
1370+ session = self .session ,
1371+ )
1372+
13021373 format_type_options , copy_options = get_copy_into_table_options (options )
13031374 format_type_options = self ._merge_file_format_options (
13041375 format_type_options , options
0 commit comments