Skip to content

Commit 7e70380

Browse files
authored
Merge branch 'main' into SNOW-1955847-feat-add-support-to-postgresql
2 parents 8aa988e + e93727f commit 7e70380

File tree

4 files changed

+86
-10
lines changed

4 files changed

+86
-10
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
- Invoking snowflake system procedures does not invoke an additional `describe procedure` call to check the return type of the procedure.
1414
- Added support for `Session.create_dataframe()` with the stage URL and FILE data type.
15-
- Added support for different modes for dealing with corrupt XML records when reading an XML file using `session.read.option('rowTag', <tag_name>).xml(<stage_file_path>)`. Currently `PERMISSIVE`, `DROPMALFORMED` and `FAILFAST` are supported.
15+
- Added support for different modes for dealing with corrupt XML records when reading an XML file using `session.read.option('mode', <mode>), option('rowTag', <tag_name>).xml(<stage_file_path>)`. Currently `PERMISSIVE`, `DROPMALFORMED` and `FAILFAST` are supported.
16+
- Improved the error message of the XML reader when the specified row tag is not found in the file.
1617
- Improved query generation for `Dataframe.drop` to use `SELECT * EXCLUDE ()` to exclude the dropped columns. To enable this feature, set `session.conf.set("use_simplified_query_generation", True)`.
1718

1819
#### Bug Fixes

src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ def wrap(*args, **kwargs):
150150
try:
151151
return func(*args, **kwargs)
152152
except snowflake.connector.errors.ProgrammingError as e:
153+
from snowflake.snowpark._internal.analyzer.select_statement import (
154+
Selectable,
155+
)
156+
153157
query = getattr(e, "query", None)
154158
tb = sys.exc_info()[2]
155159
assert e.msg is not None
@@ -209,10 +213,6 @@ def wrap(*args, **kwargs):
209213
)
210214
raise ne.with_traceback(tb) from None
211215
else:
212-
from snowflake.snowpark._internal.analyzer.select_statement import (
213-
Selectable,
214-
)
215-
216216
# We need the potential double quotes for invalid identifier
217217
match = SnowflakePlan.Decorator.__wrap_exception_regex_match_with_double_quotes.match(
218218
e.msg
@@ -277,11 +277,53 @@ def add_single_quote(string: str) -> str:
277277
e
278278
)
279279
raise ne.with_traceback(tb) from None
280-
else:
281-
ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
282-
e
283-
)
284-
raise ne.with_traceback(tb) from None
280+
elif e.sqlstate == "42601" and "SELECT with no columns" in e.msg:
281+
# This is a special case when the select statement has no columns,
282+
# and it's a reading XML query.
283+
284+
def search_read_file_node(
285+
node: Union[SnowflakePlan, Selectable]
286+
) -> Optional[ReadFileNode]:
287+
for child in node.children_plan_nodes:
288+
source_plan = (
289+
child.source_plan
290+
if isinstance(child, SnowflakePlan)
291+
else child.snowflake_plan.source_plan
292+
)
293+
if isinstance(source_plan, ReadFileNode):
294+
return source_plan
295+
result = search_read_file_node(child)
296+
if result:
297+
return result
298+
return None
299+
300+
for arg in args:
301+
if isinstance(arg, SnowflakePlan):
302+
read_file_node = search_read_file_node(arg)
303+
if (
304+
read_file_node
305+
and read_file_node.xml_reader_udtf is not None
306+
):
307+
row_tag = read_file_node.options.get(
308+
XML_ROW_TAG_STRING
309+
)
310+
file_path = read_file_node.path
311+
ne = SnowparkClientExceptionMessages.DF_XML_ROW_TAG_NOT_FOUND(
312+
row_tag, file_path
313+
)
314+
raise ne.with_traceback(tb) from None
315+
# when the describe query fails, the arg is a query string
316+
elif isinstance(arg, str):
317+
if f'"{XML_ROW_DATA_COLUMN_NAME}"' in arg:
318+
ne = (
319+
SnowparkClientExceptionMessages.DF_XML_ROW_TAG_NOT_FOUND()
320+
)
321+
raise ne.with_traceback(tb) from None
322+
323+
ne = SnowparkClientExceptionMessages.SQL_EXCEPTION_FROM_PROGRAMMING_ERROR(
324+
e
325+
)
326+
raise ne.with_traceback(tb) from None
285327

286328
return wrap
287329

src/snowflake/snowpark/_internal/error_message.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,17 @@ def DF_COPY_INTO_CANNOT_CREATE_TABLE(
127127
f"Cannot create the target table {table_name} because Snowpark cannot determine the column names to use. You should create the table before calling copy_into_table()."
128128
)
129129

130+
@staticmethod
131+
def DF_XML_ROW_TAG_NOT_FOUND(
132+
row_tag: Optional[str] = None,
133+
file_path: Optional[str] = None,
134+
) -> SnowparkDataframeReaderException:
135+
if row_tag is not None and file_path is not None:
136+
msg = f"Cannot find the row tag '{row_tag}' in the XML file {file_path}."
137+
else:
138+
msg = "Cannot find the row tag in the XML file."
139+
return SnowparkDataframeReaderException(msg)
140+
130141
@staticmethod
131142
def DF_CROSS_TAB_COUNT_TOO_LARGE(
132143
count: int, max_count: int

tests/integ/scala/test_dataframe_reader_suite.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,3 +2093,25 @@ def test_read_malformed_xml(session, file):
20932093
)
20942094
with pytest.raises(SnowparkSQLException, match="Malformed XML record at bytes"):
20952095
df.collect()
2096+
2097+
2098+
@pytest.mark.skipif(
2099+
"config.getoption('local_testing_mode', default=False)",
2100+
reason="xml not supported in local testing mode",
2101+
)
2102+
def test_read_xml_row_tag_not_found(session):
2103+
row_tag = "non-existing-tag"
2104+
df = session.read.option("rowTag", row_tag).xml(
2105+
f"@{tmp_stage_name1}/{test_file_books_xml}"
2106+
)
2107+
2108+
with pytest.raises(
2109+
SnowparkDataframeReaderException, match="Cannot find the row tag"
2110+
):
2111+
df.collect()
2112+
2113+
# also works for nested query plan
2114+
with pytest.raises(
2115+
SnowparkDataframeReaderException, match="Cannot find the row tag"
2116+
):
2117+
df.filter(lit(True)).collect()

0 commit comments

Comments
 (0)