Skip to content

Commit b19c5ff

Browse files
authored
SNOW-1980102: Support reading XML file with row tag (#3185)
1 parent 18cd1c0 commit b19c5ff

File tree

14 files changed

+1130
-0
lines changed

14 files changed

+1130
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ docs/source/modin/pandas_api/
143143
.idea/
144144
.vscode/
145145
*.code-workspace
146+
.run/
146147

147148
# performance test result
148149
tests/perf/results/

src/snowflake/snowpark/_internal/analyzer/analyzer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,7 @@ def do_resolve_with_resolved_children(
13371337
metadata_project=logical_plan.metadata_project,
13381338
metadata_schema=logical_plan.metadata_schema,
13391339
use_user_schema=logical_plan.use_user_schema,
1340+
xml_reader_udtf=logical_plan.xml_reader_udtf,
13401341
source_plan=logical_plan,
13411342
)
13421343

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
) # pragma: no cover
3838
import snowflake.snowpark.session
3939
import snowflake.snowpark.dataframe
40+
from snowflake.snowpark.udtf import UserDefinedTableFunction
4041

4142
import snowflake.connector
4243
import snowflake.snowpark
@@ -108,6 +109,8 @@
108109
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
109110
from snowflake.snowpark._internal.utils import (
110111
INFER_SCHEMA_FORMAT_TYPES,
112+
XML_ROW_TAG_STRING,
113+
XML_ROW_DATA_COLUMN_NAME,
111114
TempObjectType,
112115
generate_random_alphanumeric,
113116
get_copy_into_table_options,
@@ -1285,6 +1288,60 @@ def process_list(list_property):
12851288
)(setting["property_value"])
12861289
return new_options
12871290

def _create_xml_query(
    self,
    xml_reader_udtf: "UserDefinedTableFunction",
    file_path: str,
    options: Dict[str, str],
) -> str:
    """
    Build the SQL query that reads an XML file through ``xml_reader_udtf``.

    The query fans the file out over up to 16 workers, calls the UDTF once
    per worker, flattens each returned record into key/value pairs, and
    pivots the keys into columns.

    Args:
        xml_reader_udtf: The table function invoked as
            ``xml_reader_udtf(file_path, num_workers, row_tag, worker_id)``.
        file_path: Location of the XML file; it is also passed to ``ls`` to
            look up the file size.
        options: Reader options; must contain the ``XML_ROW_TAG_STRING`` key.

    Returns:
        A SQL string selecting every pivoted column except the helper
        ``WORKER`` and ``XML_ROW_NUMBER`` columns.

    Raises:
        KeyError: If ``options`` has no ``XML_ROW_TAG_STRING`` entry.
        ValueError: If ``ls`` returns no rows for ``file_path``.
    """
    from snowflake.snowpark.functions import lit, col, seq8, flatten
    from snowflake.snowpark._internal.xml_reader import DEFAULT_CHUNK_SIZE

    worker_column_name = "WORKER"
    xml_row_number_column_name = "XML_ROW_NUMBER"
    row_tag = options[XML_ROW_TAG_STRING]

    # TODO SNOW-1983360: make it a configurable option once the UDTF scalability issue is resolved.
    # Currently it's capped at 16.
    listed_files = self.session.sql(f"ls {file_path}", _emit_ast=False).collect()
    if not listed_files:
        # Fail with a message naming the path instead of a bare IndexError
        # from indexing an empty result.
        raise ValueError(f"{file_path} does not exist")
    # NOTE(review): only the first listed row is used; if file_path matches
    # several files, the worker count reflects the first one only -- confirm.
    file_size = int(listed_files[0]["size"])  # type: ignore
    num_workers = min(16, file_size // DEFAULT_CHUNK_SIZE + 1)

    # Create a range from 0 to N-1
    df = self.session.range(num_workers).to_df(worker_column_name)

    # Apply UDTF to the XML file and get each XML record as a Variant data,
    # and append a unique row number to each record.
    df = df.select(
        worker_column_name,
        seq8().as_(xml_row_number_column_name),
        xml_reader_udtf(
            lit(file_path),
            lit(num_workers),
            lit(row_tag),
            col(worker_column_name),
        ),
    )

    # Flatten the Variant data to get the key-value pairs
    df = df.select(
        worker_column_name,
        xml_row_number_column_name,
        flatten(XML_ROW_DATA_COLUMN_NAME),
    ).select(worker_column_name, xml_row_number_column_name, "key", "value")

    # Apply dynamic pivot to get the flat table with dynamic schema
    df = (
        df.pivot("key")
        .max("value")
        .sort(worker_column_name, xml_row_number_column_name)
    )

    # Exclude the worker and row number columns
    return f"SELECT * EXCLUDE ({worker_column_name}, {xml_row_number_column_name}) FROM ({df.queries['queries'][-1]})"

12881345
def read_file(
12891346
self,
12901347
path: str,
@@ -1296,9 +1353,23 @@ def read_file(
12961353
metadata_project: Optional[List[str]] = None,
12971354
metadata_schema: Optional[List[Attribute]] = None,
12981355
use_user_schema: bool = False,
1356+
xml_reader_udtf: Optional["UserDefinedTableFunction"] = None,
12991357
source_plan: Optional[ReadFileNode] = None,
13001358
) -> SnowflakePlan:
13011359
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
1360+
1361+
if xml_reader_udtf is not None:
1362+
xml_query = self._create_xml_query(xml_reader_udtf, path, options)
1363+
return SnowflakePlan(
1364+
[Query(xml_query)],
1365+
# the schema query of dynamic pivot must be the same as the original query
1366+
xml_query,
1367+
None,
1368+
{},
1369+
source_plan=source_plan,
1370+
session=self.session,
1371+
)
1372+
13021373
format_type_options, copy_options = get_copy_into_table_options(options)
13031374
format_type_options = self._merge_file_format_options(
13041375
format_type_options, options

src/snowflake/snowpark/_internal/analyzer/snowflake_plan_node.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
if TYPE_CHECKING:
2727
from snowflake.snowpark import Session
28+
from snowflake.snowpark.udtf import UserDefinedTableFunction
2829

2930

3031
class LogicalPlan:
@@ -317,6 +318,7 @@ def __init__(
317318
metadata_project: Optional[List[str]] = None,
318319
metadata_schema: Optional[List[Attribute]] = None,
319320
use_user_schema: bool = False,
321+
xml_reader_udtf: Optional["UserDefinedTableFunction"] = None,
320322
) -> None:
321323
super().__init__()
322324
self.path = path
@@ -328,6 +330,7 @@ def __init__(
328330
self.metadata_project = metadata_project
329331
self.metadata_schema = metadata_schema
330332
self.use_user_schema = use_user_schema
333+
self.xml_reader_udtf = xml_reader_udtf
331334

332335
@classmethod
333336
def from_read_file_node(cls, read_file_node: "ReadFileNode"):

src/snowflake/snowpark/_internal/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@
197197
"COPY_OPTIONS",
198198
}
199199

200+
XML_ROW_TAG_STRING = "ROWTAG"
201+
XML_ROW_DATA_COLUMN_NAME = "ROW_DATA"
202+
XML_READER_FILE_PATH = os.path.join(os.path.dirname(__file__), "xml_reader.py")
203+
200204
QUERY_TAG_STRING = "QUERY_TAG"
201205
SKIP_LEVELS_TWO = (
202206
2 # limit traceback to return up to 2 stack trace entries from traceback object tb

0 commit comments

Comments
 (0)