Skip to content

Commit 98330fa

Browse files
authored
SNOW-1843881: Change StructType columns to return Row objects (#2820)
1 parent 7c1ed3e commit 98330fa

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

src/snowflake/snowpark/_internal/utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@
4545
)
4646

4747
import snowflake.snowpark
48+
from snowflake.connector.constants import FIELD_ID_TO_NAME
4849
from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor
4950
from snowflake.connector.description import OPERATING_SYSTEM, PLATFORM
5051
from snowflake.connector.options import MissingOptionalDependency, ModuleLikeObject
5152
from snowflake.connector.version import VERSION as connector_version
5253
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
54+
from snowflake.snowpark.context import _should_use_structured_type_semantics
5355
from snowflake.snowpark.row import Row
5456
from snowflake.snowpark.version import VERSION as snowpark_version
5557

@@ -698,19 +700,50 @@ def column_to_bool(col_):
698700
return bool(col_)
699701

700702

703+
def _parse_result_meta(
704+
result_meta: Union[List[ResultMetadata], List["ResultMetadataV2"]]
705+
) -> Tuple[Optional[List[str]], Optional[List[Callable]]]:
706+
"""
707+
Takes a list of result metadata objects and returns a list containing the names of all fields as
708+
well as a list of functions that wrap specific columns.
709+
710+
A column type may need to be wrapped if the connector is unable to provide the columns data in
711+
an expected format. For example StructType columns are returned as dict objects, but are better
712+
represented as Row objects.
713+
"""
714+
if not result_meta:
715+
return None, None
716+
col_names = []
717+
wrappers = []
718+
for col in result_meta:
719+
col_names.append(col.name)
720+
if (
721+
_should_use_structured_type_semantics()
722+
and FIELD_ID_TO_NAME[col.type_code] == "OBJECT"
723+
and col.fields is not None
724+
):
725+
wrappers.append(lambda x: Row(**x))
726+
else:
727+
wrappers.append(None)
728+
return col_names, wrappers
729+
730+
701731
def result_set_to_rows(
702732
result_set: List[Any],
703733
result_meta: Optional[Union[List[ResultMetadata], List["ResultMetadataV2"]]] = None,
704734
case_sensitive: bool = True,
705735
) -> List[Row]:
706-
col_names = [col.name for col in result_meta] if result_meta else None
736+
col_names, wrappers = _parse_result_meta(result_meta or [])
707737
rows = []
708738
row_struct = Row
709739
if col_names:
710740
row_struct = (
711741
Row._builder.build(*col_names).set_case_sensitive(case_sensitive).to_row()
712742
)
713743
for data in result_set:
744+
if wrappers:
745+
data = [wrap(d) if wrap else d for wrap, d in zip(wrappers, data)]
746+
714747
if data is None:
715748
raise ValueError("Result returned from Python connector is None")
716749
row = row_struct(*data)
@@ -723,7 +756,7 @@ def result_set_to_iter(
723756
result_meta: Optional[List[ResultMetadata]] = None,
724757
case_sensitive: bool = True,
725758
) -> Iterator[Row]:
726-
col_names = [col.name for col in result_meta] if result_meta else None
759+
col_names, wrappers = _parse_result_meta(result_meta)
727760
row_struct = Row
728761
if col_names:
729762
row_struct = (
@@ -732,6 +765,8 @@ def result_set_to_iter(
732765
for data in result_set:
733766
if data is None:
734767
raise ValueError("Result returned from Python connector is None")
768+
if wrappers:
769+
data = [wrap(d) if wrap else d for wrap, d in zip(wrappers, data)]
735770
row = row_struct(*data)
736771
yield row
737772

tests/integ/scala/test_datatype_suite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,8 +876,8 @@ def test_structured_dtypes_iceberg_create_from_values(
876876
_, __, expected_schema = _create_example(True)
877877
table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}"
878878
data = [
879-
({"x": 1}, {"A": "a", "b": 1}, [1, 1, 1]),
880-
({"x": 2}, {"A": "b", "b": 2}, [2, 2, 2]),
879+
({"x": 1}, Row(A="a", b=1), [1, 1, 1]),
880+
({"x": 2}, Row(A="b", b=2), [2, 2, 2]),
881881
]
882882
try:
883883
create_df = structured_type_session.create_dataframe(
@@ -1043,7 +1043,7 @@ def test_structured_dtypes_cast(structured_type_session, structured_type_support
10431043
)
10441044
assert cast_df.schema == expected_structured_schema
10451045
assert cast_df.collect() == [
1046-
Row([1, 2, 3], {"k1": 1, "k2": 2}, {"A": 1.0, "B": "foobar"})
1046+
Row([1, 2, 3], {"k1": 1, "k2": 2}, Row(A=1.0, B="foobar"))
10471047
]
10481048

10491049

0 commit comments

Comments
 (0)