diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 0910a2a4aa..a989e1625f 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -36,6 +36,7 @@ from snowflake.connector.cursor import ResultMetadata from snowflake.connector.options import installed_pandas, pandas from snowflake.snowpark._internal.utils import quote_name +from snowflake.snowpark.row import Row from snowflake.snowpark.types import ( LTZ, NTZ, @@ -441,6 +442,8 @@ def infer_type(obj: Any) -> DataType: if key is not None and value is not None: return MapType(infer_type(key), infer_type(value)) return MapType(NullType(), NullType()) + elif isinstance(obj, Row) and context._should_use_structured_type_semantics(): + return infer_schema(obj) elif isinstance(obj, (list, tuple)): for v in obj: if v is not None: diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 88d2a4a32d..ba5cfdfbd0 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -36,6 +36,7 @@ import pkg_resources import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto +import snowflake.snowpark.context as context from snowflake.connector import ProgrammingError, SnowflakeConnection from snowflake.connector.options import installed_pandas, pandas from snowflake.connector.pandas_tools import write_pandas @@ -3306,6 +3307,14 @@ def convert_row_to_list( data_type, (MapType, StructType) ): converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder)) + elif ( + isinstance(value, Row) + and isinstance(data_type, StructType) + and context._should_use_structured_type_semantics() + ): + converted_row.append( + json.dumps(value.as_dict(), cls=PythonObjJSONEncoder) + ) elif isinstance(data_type, VariantType): converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder)) elif isinstance(data_type, GeographyType): diff --git 
a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py
index 9f08cedbde..935a8c829d 100644
--- a/tests/integ/scala/test_datatype_suite.py
+++ b/tests/integ/scala/test_datatype_suite.py
@@ -598,6 +598,52 @@ def finish(self) -> dict:
     assert MapCollector._return_type == MapType()
 
 
+@pytest.mark.skipif(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="local testing does not fully support structured types yet.",
+)
+def test_structured_type_infer(structured_type_session, structured_type_support):
+    """Check that dict, list and Row values infer to structured column types.
+
+    A single-row dataframe is created from plain Python values with only
+    column names supplied, and the resulting schema is compared with the
+    expected structured MapType / ArrayType / StructType columns. The
+    dataframe is then collected.
+    """
+    if not structured_type_support:
+        pytest.skip("Test requires structured type support.")
+
+    row_value = Row(f1="v1", f2=2)
+    source_rows = [({"key": "value"}, [1, 2, 3], row_value)]
+    df = structured_type_session.create_dataframe(
+        source_rows, schema=["map", "array", "obj"]
+    )
+
+    expected_obj_type = StructType(
+        [
+            StructField("f1", StringType(), nullable=True),
+            StructField("f2", LongType(), nullable=True),
+        ],
+        structured=True,
+    )
+    expected_schema = StructType(
+        [
+            StructField(
+                "MAP",
+                MapType(StringType(), StringType(), structured=True),
+                nullable=True,
+            ),
+            StructField("ARRAY", ArrayType(LongType(), structured=True), nullable=True),
+            StructField("OBJ", expected_obj_type, nullable=True),
+        ],
+        structured=True,
+    )
+    assert df.schema == expected_schema
+
+    # Run the query as well; the returned rows are not inspected.
+    df.collect()
+
+
 @pytest.mark.skipif(
     "config.getoption('local_testing_mode', default=False)",
     reason="local testing does not fully support structured types yet.",