Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@
convert_sf_to_sp_type,
convert_sp_to_sf_type,
)
from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints
from snowflake.snowpark._internal.utils import (
STAGE_PREFIX,
XML_ROW_TAG_STRING,
XML_ROW_DATA_COLUMN_NAME,
XML_READER_FILE_PATH,
XML_READER_API_SIGNATURE,
XML_READER_SQL_COMMENT,
INFER_SCHEMA_FORMAT_TYPES,
SNOWFLAKE_PATH_PREFIXES,
TempObjectType,
Expand All @@ -70,6 +73,7 @@
private_preview,
random_name_for_temp_object,
warning,
is_in_stored_procedure,
)
from snowflake.snowpark.column import METADATA_COLUMN_TYPES, Column, _to_col_if_str
from snowflake.snowpark.dataframe import DataFrame
Expand Down Expand Up @@ -1106,13 +1110,40 @@ def _read_semi_structured_file(self, path: str, format: str) -> DataFrame:
"rowTag",
"rowTag for reading XML file is in private preview since 1.31.0. Do not use it in production.",
)

if is_in_stored_procedure(): # pragma: no cover
# create a temp stage for udtf import files
# we have to use "temp" object instead of "scoped temp" object in stored procedure
# so we need to upload the file to the temp stage first to use register_from_file
temp_stage = random_name_for_temp_object(TempObjectType.STAGE)
sql_create_temp_stage = f"create temp stage if not exists {temp_stage} {XML_READER_SQL_COMMENT}"
self._session.sql(sql_create_temp_stage, _emit_ast=False).collect(
_emit_ast=False
)
self._session._conn.upload_file(
XML_READER_FILE_PATH,
temp_stage,
compress_data=False,
overwrite=True,
skip_upload_on_content_match=True,
)
python_file_path = f"{STAGE_PREFIX}{temp_stage}/{os.path.basename(XML_READER_FILE_PATH)}"
else:
python_file_path = XML_READER_FILE_PATH

# create udtf
handler_name = "XMLReader"
_, input_types = get_types_from_type_hints(
(XML_READER_FILE_PATH, handler_name), TempObjectType.TABLE_FUNCTION
Comment on lines +1136 to +1137
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There appears to be a path inconsistency in the type hint resolution. The code uses `XML_READER_FILE_PATH` to extract type hints, but then passes `python_file_path` to `register_from_file()`. In the stored procedure case these two paths differ: `XML_READER_FILE_PATH` is a local filesystem path, while `python_file_path` is a stage path.

For consistency and correctness, the same path should be used in both places. Consider modifying the code to use python_file_path for both the type hint extraction and the registration.

Suggested change
- _, input_types = get_types_from_type_hints(
-     (XML_READER_FILE_PATH, handler_name), TempObjectType.TABLE_FUNCTION
+ _, input_types = get_types_from_type_hints(
+     (python_file_path, handler_name), TempObjectType.TABLE_FUNCTION

Spotted by Diamond

Is this helpful? React 👍 or 👎 to let us know.

)
output_schema = StructType(
[StructField(XML_ROW_DATA_COLUMN_NAME, VariantType(), True)]
)
xml_reader_udtf = self._session.udtf.register_from_file(
XML_READER_FILE_PATH,
"XMLReader",
python_file_path,
handler_name,
output_schema=output_schema,
input_types=input_types,
packages=["snowflake-snowpark-python"],
replace=True,
)
Expand Down
3 changes: 0 additions & 3 deletions tests/integ/scala/test_dataframe_reader_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ def get_df_from_reader_and_file_format(reader, file_format):

@pytest.fixture(scope="module", autouse=True)
def setup(session, resources_path, local_testing_mode):
# TODO SNOW-2098847: remove this workaround after fixing the issue
session._use_scoped_temp_objects = False

test_files = TestFiles(resources_path)
if not local_testing_mode:
Utils.create_stage(session, tmp_stage_name1, is_temporary=True)
Expand Down
Loading