diff --git a/src/snowflake/snowpark/_internal/ast/batch.py b/src/snowflake/snowpark/_internal/ast/batch.py index 0e1b8582b9..960bdb7f42 100644 --- a/src/snowflake/snowpark/_internal/ast/batch.py +++ b/src/snowflake/snowpark/_internal/ast/batch.py @@ -72,6 +72,10 @@ def reset_id_gen(self) -> None: """Resets the ID generator.""" self._id_gen = itertools.count(start=1) + def reset_callables(self) -> None: + """Resets the callables.""" + self._callables = {} + def assign(self, symbol: Optional[str] = None) -> proto.Assign: """ Creates a new assignment statement. diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index aac64d32c5..476fcd8140 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -423,7 +423,7 @@ def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame: _apply_in_pandas_udtf = self._dataframe._session.udtf.register( _ApplyInPandas, output_schema=output_schema, - _emit_ast=_emit_ast, + _emit_ast=False, **kwargs, ) partition_by = [Column(expr, _emit_ast=False) for expr in self._grouping_exprs] diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index fbf40871fd..d9830b3fe3 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -784,7 +784,7 @@ def json_value(self) -> Dict[str, Any]: def _fill_ast(self, ast: proto.SpDataType) -> None: ast.sp_struct_type.structured = self.structured for field in self.fields: - field._fill_ast(ast.sp_struct_type.fields.add()) + field._fill_ast(ast.sp_struct_type.fields.list.add()) class VariantType(DataType): diff --git a/tests/ast/data/RelationalGroupedDataFrame.test b/tests/ast/data/RelationalGroupedDataFrame.test index d5d689b719..1d018c5a26 100644 --- a/tests/ast/data/RelationalGroupedDataFrame.test +++ b/tests/ast/data/RelationalGroupedDataFrame.test @@ -45,8 +45,6 @@ res8 = df.group_by("a").count() df = session.create_dataframe([("SF", 21.0), ("SF", 17.5), ("SF", 24.0), ("NY", 30.9), ("NY", 33.6)], schema=["location", "temp_c"]) -res10 = udtf("_ApplyInPandas", output_schema=PandasDataFrameType(StringType(), FloatType(), FloatType(), "LOCATION", "TEMP_C", "TEMP_F"), input_types=[StringType(), FloatType()], copy_grants=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".SNOWPARK_TEMP_TABLE_FUNCTION_xxx") - df.group_by("location").apply_in_pandas(convert, StructType([StructField("location", StringType(), nullable=True), StructField("temp_c", FloatType(), nullable=True), StructField("temp_f", FloatType(), nullable=True)], structured=False), input_types=[StringType(), FloatType()], input_names=["LOCATION", "TEMP_C"]).sort("temp_c", ascending=None).collect() df = session.create_dataframe([(1, "A", 10000, "JAN"), (1, "B", 400, "JAN"), (1, "B", 5000, "FEB")], schema=["empid", "team", "amount", "month"]) @@ -713,84 +711,11 @@ body { } } } -body { - assign { - expr { - udtf { - handler { - name: "_ApplyInPandas" - object_name { - sp_table_name_flat { - name: "\"MOCK_DATABASE\".\"MOCK_SCHEMA\".SNOWPARK_TEMP_TABLE_FUNCTION_xxx" - } - } - } - input_types { - list { - sp_string_type { - length { - } - } - } - list { - sp_float_type: true - } - } - kwargs { - _1: "copy_grants" - _2 { - bool_val { - src { - file: "SRC_POSITION_TEST_MODE" - start_line: 43 - } - } - } - } - output_schema { - udtf_schema__type { - return_type { - sp_pandas_data_frame_type { - col_names: "LOCATION" - col_names: "TEMP_C" - col_names: "TEMP_F" - col_types { - sp_string_type { - length { - } - } - } - col_types { - sp_float_type: true - } - col_types { - sp_float_type: true - } - } - } - } - } - parallel: 4 - src { - file: "SRC_POSITION_TEST_MODE" - start_line: 43 - } - } - } - symbol { - } - uid: 14 - var_id { - bitfield1: 14 - } - } -} body { assign { expr { sp_relational_grouped_dataframe_apply_in_pandas { func { - id: 1 name: "convert" } grouped_df { @@ -867,34 +792,36 @@ body { } output_schema { fields { - column_identifier { - name: "location" - } - data_type { - sp_string_type { - length { + list { + column_identifier { + name: "location" + } + data_type { + sp_string_type { + length { + } } } + nullable: true } - nullable: true - } - fields { - column_identifier { - name: "temp_c" - } - data_type { - sp_float_type: true - } - nullable: true - } - fields { - column_identifier { - name: "temp_f" + list { + column_identifier { + name: "temp_c" + } + data_type { + sp_float_type: true + } + nullable: true } - data_type { - sp_float_type: true + list { + column_identifier { + name: "temp_f" + } + data_type { + sp_float_type: true + } + nullable: true } - nullable: true } } src { @@ -905,9 +832,9 @@ body { } symbol { } - uid: 15 + uid: 14 var_id { - bitfield1: 15 + bitfield1: 14 } } } @@ -936,7 +863,7 @@ body { df { sp_dataframe_ref { id { - bitfield1: 15 + bitfield1: 14 } } } @@ -948,9 +875,9 @@ body { } symbol { } - uid: 16 + uid: 15 var_id { - bitfield1: 16 + bitfield1: 15 } } } @@ -961,7 +888,7 @@ body { block: true case_sensitive: true id { - bitfield1: 16 + bitfield1: 15 } src { file: "SRC_POSITION_TEST_MODE" @@ -971,17 +898,17 @@ body { } symbol { } - uid: 17 + uid: 16 var_id { - bitfield1: 17 + bitfield1: 16 } } } body { eval { - uid: 18 + uid: 17 var_id { - bitfield1: 17 + bitfield1: 16 } } } @@ -1142,9 +1069,9 @@ body { symbol { value: "df" } - uid: 19 + uid: 18 var_id { - bitfield1: 19 + bitfield1: 18 } } } @@ -1167,7 +1094,7 @@ body { df { sp_dataframe_ref { id { - bitfield1: 19 + bitfield1: 18 } } } @@ -1179,9 +1106,9 @@ body { } symbol { } - uid: 20 + uid: 19 var_id { - bitfield1: 20 + bitfield1: 19 } } } @@ -1192,7 +1119,7 @@ body { grouped_df { sp_relational_grouped_dataframe_ref { id { - bitfield1: 20 + bitfield1: 19 } } } @@ -1243,9 +1170,9 @@ body { } symbol { } - uid: 21 + uid: 20 var_id { - bitfield1: 21 + bitfield1: 20 } } } @@ -1269,7 +1196,7 @@ body { grouped_df { sp_relational_grouped_dataframe_ref { id { - bitfield1: 21 + bitfield1: 20 } } } @@ -1281,9 +1208,9 @@ body { } symbol { } - uid: 22 + uid: 21 var_id { - bitfield1: 22 + bitfield1: 21 } } } @@ -1292,23 +1219,23 @@ body { expr { sp_dataframe_show { id { - bitfield1: 22 + bitfield1: 21 } } } symbol { } - uid: 23 + uid: 22 var_id { - bitfield1: 23 + bitfield1: 22 } } } body { eval { - uid: 24 + uid: 23 var_id { - bitfield1: 23 + bitfield1: 22 } } } @@ -1339,7 +1266,7 @@ body { df { sp_dataframe_ref { id { - bitfield1: 19 + bitfield1: 18 } } } @@ -1351,9 +1278,9 @@ body { } symbol { } - uid: 25 + uid: 24 var_id { - bitfield1: 25 + bitfield1: 24 } } } @@ -1364,7 +1291,7 @@ body { grouped_df { sp_relational_grouped_dataframe_ref { id { - bitfield1: 25 + bitfield1: 24 } } } @@ -1385,9 +1312,9 @@ body { } symbol { } - uid: 26 + uid: 25 var_id { - bitfield1: 26 + bitfield1: 25 } } } @@ -1411,7 +1338,7 @@ body { grouped_df { sp_relational_grouped_dataframe_ref { id { - bitfield1: 26 + bitfield1: 25 } } } @@ -1423,9 +1350,9 @@ body { } symbol { } - uid: 27 + uid: 26 var_id { - bitfield1: 27 + bitfield1: 26 } } } @@ -1463,7 +1390,7 @@ body { df { sp_dataframe_ref { id { - bitfield1: 27 + bitfield1: 26 } } } @@ -1475,9 +1402,9 @@ body { } symbol { } - uid: 28 + uid: 27 var_id { - bitfield1: 28 + bitfield1: 27 } } } @@ -1486,23 +1413,23 @@ body { expr { sp_dataframe_show { id { - bitfield1: 28 + bitfield1: 27 } } } symbol { } - uid: 29 + uid: 28 var_id { - bitfield1: 29 + bitfield1: 28 } } } body { eval { - uid: 30 + uid: 29 var_id { - bitfield1: 29 + bitfield1: 28 } } } diff --git a/tests/ast/data/udtf.test b/tests/ast/data/udtf.test index f5ff5de15c..efa1e0a96d 100644 --- a/tests/ast/data/udtf.test +++ b/tests/ast/data/udtf.test @@ -81,7 +81,7 @@ twox_udtf = udtf( input_types=[IntegerType()], ) -df.select(twox_udtf("a")).collect() +df.select(session.table_function(twox_udtf("a"))[0]).collect() class TwoXSixXUDTF: def process(self, n: int): @@ -95,9 +95,9 @@ twoxsix_udtf = udtf( input_types=[IntegerType()], ) -df.select(df.a, twoxsix_udtf(df.a)).collect() +df.select(df.a, session.table_function(twoxsix_udtf(df.a))[0]).collect() -df.select("a", twoxsix_udtf("a").alias("double", "six_x"), "c").collect() +df.select("a", session.table_function(twoxsix_udtf("a"))[0].alias("six_x"), "c").collect() ## EXPECTED UNPARSER OUTPUT @@ -107,7 +107,7 @@ session.table_function(prime_udtf(lit(20))).collect() df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) -df.with_column("total", udtf("sum_udtf", output_schema=StructType([StructField("number", LongType(), nullable=True)], structured=False), input_types=[LongType(), LongType()], copy_grants=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".SNOWPARK_TEMP_TABLE_FUNCTION_xxx")(df["a"], df["b"])).sort(df["a"]).show() +df.with_column("total", udtf("sum_udtf", output_schema=StructType([StructField("number", LongType(), nullable=True)], structured=False), input_types=[LongType(), LongType()], copy_grants=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".SNOWPARK_TEMP_TABLE_FUNCTION_xxx")(df["a"], df["b"])).sort(df["a"], ascending=None).show() generator_udtf = udtf("GeneratorUDTF", output_schema=StructType([StructField("number", IntegerType(), nullable=True)], structured=False), input_types=[IntegerType()], copy_grants=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".SNOWPARK_TEMP_TABLE_FUNCTION_xxx") @@ -121,13 +121,13 @@ df = df.to_df(["a", "b", "c"]) twox_udtf = udtf("TwoXUDTF", output_schema=StructType([StructField("two_x", IntegerType(), nullable=True)], structured=False), input_types=[IntegerType()], copy_grants=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".SNOWPARK_TEMP_TABLE_FUNCTION_xxx") -df.select((twox_udtf("a"))).collect() +df.select(session.table_function(twox_udtf("a"))["TWO_X"]).collect() twoxsix_udtf = udtf("TwoXSixXUDTF", output_schema=StructType([StructField("two_x", IntegerType(), nullable=True), StructField("six_x", IntegerType(), nullable=True)], structured=False), input_types=[IntegerType()], copy_grants=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".SNOWPARK_TEMP_TABLE_FUNCTION_xxx") -df.select(df["a"], (twoxsix_udtf(df["a"]))).collect() +df.select(df["a"], session.table_function(twoxsix_udtf(df["a"]))["TWO_X"]).collect() -df.select(col("a"), (twoxsix_udtf("a").alias("double", "six_x")), col("c")).collect() +df.select(col("a"), session.table_function(twoxsix_udtf("a"))["TWO_X"].alias("six_x"), col("c")).collect() ## EXPECTED ENCODED AST @@ -166,13 +166,15 @@ body { return_type { sp_struct_type { fields { - column_identifier { - name: "number" - } - data_type { - sp_integer_type: true + list { + column_identifier { + name: "number" + } + data_type { + sp_integer_type: true + } + nullable: true } - nullable: true } } } @@ -350,13 +352,15 @@ body { return_type { sp_struct_type { fields { - column_identifier { - name: "number" - } - data_type { - sp_long_type: true + list { + column_identifier { + name: "number" + } + data_type { + sp_long_type: true + } + nullable: true } - nullable: true } } } @@ -535,6 +539,14 @@ body { assign { expr { sp_dataframe_sort { + ascending { + null_val { + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 56 + } + } + } cols { sp_dataframe_col { col_name: "a" @@ -634,13 +646,15 @@ body { return_type { sp_struct_type { fields { - column_identifier { - name: "number" - } - data_type { - sp_integer_type: true + list { + column_identifier { + name: "number" + } + data_type { + sp_integer_type: true + } + nullable: true } - nullable: true } } } @@ -858,22 +872,24 @@ body { return_type { sp_struct_type { fields { - column_identifier { - name: "a" - } - data_type { - sp_long_type: true - } - nullable: true - } - fields { - column_identifier { - name: "b" + list { + column_identifier { + name: "a" + } + data_type { + sp_long_type: true + } + nullable: true } - data_type { - sp_long_type: true + list { + column_identifier { + name: "b" + } + data_type { + sp_long_type: true + } + nullable: true } - nullable: true } } } @@ -1101,13 +1117,15 @@ body { return_type { sp_struct_type { fields { - column_identifier { - name: "two_x" - } - data_type { - sp_integer_type: true + list { + column_identifier { + name: "two_x" + } + data_type { + sp_integer_type: true + } + nullable: true } - nullable: true } } } @@ -1166,8 +1184,8 @@ body { body { assign { expr { - sp_dataframe_select__columns { - cols { + sp_session_table_function { + fn { apply_expr { fn { indirect_table_fn_id_ref { @@ -1182,6 +1200,40 @@ body { } } } + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 106 + } + } + } + symbol { + } + uid: 22 + var_id { + bitfield1: 22 + } + } +} +body { + assign { + expr { + sp_dataframe_select__columns { + cols { + sp_dataframe_col { + col_name: "TWO_X" + df { + sp_dataframe_ref { + id { + bitfield1: 22 + } + } + } + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 106 + } + } + } df { sp_dataframe_ref { id { @@ -1198,9 +1250,9 @@ body { } symbol { } - uid: 22 + uid: 23 var_id { - bitfield1: 22 + bitfield1: 23 } } } @@ -1211,7 +1263,7 @@ body { block: true case_sensitive: true id { - bitfield1: 22 + bitfield1: 23 } src { file: "SRC_POSITION_TEST_MODE" @@ -1221,17 +1273,17 @@ body { } symbol { } - uid: 23 + uid: 24 var_id { - bitfield1: 23 + bitfield1: 24 } } } body { eval { - uid: 24 + uid: 25 var_id { - bitfield1: 23 + bitfield1: 24 } } } @@ -1271,22 +1323,24 @@ body { return_type { sp_struct_type { fields { - column_identifier { - name: "two_x" - } - data_type { - sp_integer_type: true - } - nullable: true - } - fields { - column_identifier { - name: "six_x" + list { + column_identifier { + name: "two_x" + } + data_type { + sp_integer_type: true + } + nullable: true } - data_type { - sp_integer_type: true + list { + column_identifier { + name: "six_x" + } + data_type { + sp_integer_type: true + } + nullable: true } - nullable: true } } } @@ -1302,9 +1356,9 @@ body { symbol { value: "twoxsix_udtf" } - uid: 25 + uid: 26 var_id { - bitfield1: 25 + bitfield1: 26 } } } @@ -1315,7 +1369,7 @@ body { fn { sp_fn_ref { id { - bitfield1: 25 + bitfield1: 26 } } } @@ -1343,9 +1397,42 @@ body { } symbol { } - uid: 26 + uid: 27 var_id { - bitfield1: 26 + bitfield1: 27 + } + } +} +body { + assign { + expr { + sp_session_table_function { + fn { + apply_expr { + fn { + indirect_table_fn_id_ref { + id { + bitfield1: 27 + } + } + } + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 120 + } + } + } + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 120 + } + } + } + symbol { + } + uid: 28 + var_id { + bitfield1: 28 } } } @@ -1370,11 +1457,12 @@ body { } } cols { - apply_expr { - fn { - indirect_table_fn_id_ref { + sp_dataframe_col { + col_name: "TWO_X" + df { + sp_dataframe_ref { id { - bitfield1: 26 + bitfield1: 28 } } } @@ -1400,9 +1488,9 @@ body { } symbol { } - uid: 27 + uid: 29 var_id { - bitfield1: 27 + bitfield1: 29 } } } @@ -1413,7 +1501,7 @@ body { block: true case_sensitive: true id { - bitfield1: 27 + bitfield1: 29 } src { file: "SRC_POSITION_TEST_MODE" @@ -1423,63 +1511,67 @@ body { } symbol { } - uid: 28 + uid: 30 var_id { - bitfield1: 28 + bitfield1: 30 } } } body { eval { - uid: 29 + uid: 31 var_id { - bitfield1: 28 + bitfield1: 30 } } } body { assign { expr { - sp_table_fn_call_alias { - aliases { - args { - string_val { - src { - file: "SRC_POSITION_TEST_MODE" - start_line: 122 - } - v: "double" + apply_expr { + fn { + sp_fn_ref { + id { + bitfield1: 26 } } - args { - string_val { - src { - file: "SRC_POSITION_TEST_MODE" - start_line: 122 - } - v: "six_x" + } + pos_args { + string_val { + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 122 } + v: "a" } - variadic: true } - lhs { + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 122 + } + } + } + symbol { + } + uid: 32 + var_id { + bitfield1: 32 + } + } +} +body { + assign { + expr { + sp_session_table_function { + fn { apply_expr { fn { - sp_fn_ref { + indirect_table_fn_id_ref { id { - bitfield1: 25 + bitfield1: 32 } } } - pos_args { - string_val { - src { - file: "SRC_POSITION_TEST_MODE" - start_line: 122 - } - v: "a" - } - } src { file: "SRC_POSITION_TEST_MODE" start_line: 122 @@ -1494,9 +1586,9 @@ body { } symbol { } - uid: 30 + uid: 33 var_id { - bitfield1: 30 + bitfield1: 33 } } } @@ -1533,14 +1625,27 @@ body { } } cols { - apply_expr { - fn { - indirect_table_fn_id_ref { - id { - bitfield1: 30 + sp_column_alias { + col { + sp_dataframe_col { + col_name: "TWO_X" + df { + sp_dataframe_ref { + id { + bitfield1: 33 + } + } + } + src { + file: "SRC_POSITION_TEST_MODE" + start_line: 122 } } } + fn { + sp_column_alias_fn_alias: true + } + name: "six_x" src { file: "SRC_POSITION_TEST_MODE" start_line: 122 @@ -1591,9 +1696,9 @@ body { } symbol { } - uid: 31 + uid: 34 var_id { - bitfield1: 31 + bitfield1: 34 } } } @@ -1604,7 +1709,7 @@ body { block: true case_sensitive: true id { - bitfield1: 31 + bitfield1: 34 } src { file: "SRC_POSITION_TEST_MODE" @@ -1614,17 +1719,17 @@ body { } symbol { } - uid: 32 + uid: 35 var_id { - bitfield1: 32 + bitfield1: 35 } } } body { eval { - uid: 33 + uid: 36 var_id { - bitfield1: 32 + bitfield1: 35 } } } diff --git a/tests/ast/decoder.py b/tests/ast/decoder.py index 8f28acdf4b..6b98a68247 100644 --- a/tests/ast/decoder.py +++ b/tests/ast/decoder.py @@ -18,7 +18,7 @@ from snowflake.snowpark.relational_grouped_dataframe import GroupingSets from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row import snowflake.snowpark.functions -from snowflake.snowpark.functions import udf, when, sproc, call_table_function +from snowflake.snowpark.functions import udf, udtf, when, sproc, call_table_function from snowflake.snowpark.types import ( DataType, ArrayType, @@ -110,6 +110,42 @@ def get_dataframe_analytics_function_column_formatter( else: return DataFrameAnalyticsFunctions._default_col_formatter + def decode_callable_expr( + self, callable_expr: proto.SpCallable + ) -> Tuple[Callable, str]: + """ + Decode a callable expression to get the callable. + + Parameters + ---------- + callable_expr : proto.SpCallable + The callable expression to decode. + + Returns + ------- + Tuple[Callable, str] + The decoded callable and its associated name. + """ + id = callable_expr.id + name = callable_expr.name + object_name = ( + self.decode_name_expr(callable_expr.object_name) + if callable_expr.HasField("object_name") + else None + ) + try: + handler = self.session._udtf_registration.get_udtf(object_name).handler + except KeyError: + + def __temp_handler_func(): + pass + + __temp_handler_func.__name__ = ( + name # Set the name of the function to whatever it was originally. + ) + handler, object_name = __temp_handler_func, name + return handler, object_name + def decode_col_exprs(self, expr: proto.Expr) -> List[Column]: """ Decode a protobuf object to a list of column expressions. @@ -188,7 +224,7 @@ def decode_name_expr(self, table_name: proto.SpName) -> Union[str, List]: elif table_name.name.HasField("sp_name_structured"): return [name for name in table_name.name.sp_name_structured.name] else: - raise ValueError("Table name not found in proto.SpTableName") + raise ValueError("Table name not found in proto.SpName") def decode_fn_ref_expr(self, fn_ref_expr: proto.FnRefExpr) -> str: """ @@ -308,16 +344,9 @@ def decode_dataframe_schema_expr( else: return None case "sp_dataframe_schema__struct": - struct_field_list = [] - for field in df_schema_expr.sp_dataframe_schema__struct.v.fields.list: - column_identifier = field.column_identifier.name - datatype = self.decode_data_type_expr(field.data_type) - nullable = field.nullable - struct_field_list.append( - StructField(column_identifier, datatype, nullable) - ) - structured = df_schema_expr.sp_dataframe_schema__struct.v.structured - return StructType(struct_field_list, structured) + return self.decode_struct_type_expr( + df_schema_expr.sp_dataframe_schema__struct.v + ) case _: raise ValueError( "Unknown dataframe schema type: %s" @@ -388,10 +417,9 @@ def decode_data_type_expr( if isinstance( data_type_expr.sp_pandas_data_frame_type.col_types, Iterable ): - col_types = [ - col_name - for col_name in data_type_expr.sp_pandas_data_frame_type.col_types - ] + col_types = [] + for col_type in data_type_expr.sp_pandas_data_frame_type.col_types: + col_types.append(self.decode_data_type_expr(col_type)) else: col_types = [data_type_expr.sp_pandas_data_frame_type.col_types] if isinstance( @@ -435,18 +463,12 @@ def decode_data_type_expr( return StructField(column_identifier, data_type, nullable) case "sp_struct_type": # The fields can be a list of Expr, a single Expr, or None. + fields = [] if hasattr(data_type_expr.sp_struct_type, "fields"): - if isinstance(data_type_expr.sp_struct_type.fields, Iterable): - fields = [ - self.decode_data_type_expr(field) - for field in data_type_expr.sp_struct_type.fields - ] - else: - fields = [ - self.decode_data_type_expr( - data_type_expr.sp_struct_type.fields - ) - ] + for field in data_type_expr.sp_struct_type.fields.list: + column_identifier = field.column_identifier.name + data_type = self.decode_data_type_expr(field.data_type) + fields.append(StructField(column_identifier, data_type)) else: fields = None structured = data_type_expr.sp_struct_type.structured @@ -529,6 +551,56 @@ def decode_join_type(self, join_type: proto.SpJoinType) -> str: "Unknown join type: %s" % join_type.WhichOneof("variant") ) + def decode_pivot_value_expr(self, pivot_value_expr: proto.SpPivotValue) -> Any: + """ + Decode expr to get the pivot value. + + Parameters + ---------- + pivot_value_expr : proto.SpPivotValues + The expression to decode. + + Returns + ------- + Any + The decoded pivot value. + """ + match pivot_value_expr.WhichOneof("sealed_value"): + case "sp_pivot_value__dataframe": + return self.decode_expr(pivot_value_expr.sp_pivot_value__dataframe.v) + case "sp_pivot_value__expr": + return self.decode_expr(pivot_value_expr.sp_pivot_value__expr.v) + case _: + raise ValueError( + "Unknown pivot value: %s" + % pivot_value_expr.WhichOneof("sealed_value") + ) + + def decode_struct_type_expr( + self, sp_struct_type_expr: proto.SpStructType + ) -> StructType: + """ + Decode a struct type expression to get the struct type. + + Parameters + ---------- + struct_type_expr : proto.SpStructType + The expression to decode. + + Returns + ------- + StructType + The decoded object. + """ + struct_field_list = [] + for field in sp_struct_type_expr.fields.list: + column_identifier = field.column_identifier.name + datatype = self.decode_data_type_expr(field.data_type) + nullable = field.nullable + struct_field_list.append(StructField(column_identifier, datatype, nullable)) + structured = sp_struct_type_expr.structured + return StructType(struct_field_list, structured) + def decode_timezone_expr(self, tz_expr: proto.PythonTimeZone) -> Any: """ Decode a Python timezone expression to get the timezone. @@ -542,6 +614,35 @@ def decode_timezone_expr(self, tz_expr: proto.PythonTimeZone) -> Any: offset_seconds = tz_expr.offset_seconds return timezone(offset=timedelta(seconds=offset_seconds), name=tz_name) + def decode_udtf_schema( + self, udtf_schema: proto.UdtfSchema + ) -> Union[List, DataType]: + """ + Decode a UDTF schema expression to get the schema. + + Parameters + ---------- + udtf_schema : proto.UdtfSchema + The expression to decode. + + Returns + ------- + List or DataType + The decoded schema. + """ + match udtf_schema.WhichOneof("sealed_value"): + case "udtf_schema__names": + return [s for s in udtf_schema.udtf_schema__names.schema] + case "udtf_schema__type": + return self.decode_data_type_expr( + udtf_schema.udtf_schema__type.return_type + ) + case _: + raise ValueError( + "Unknown UDTF schema type: %s" + % udtf_schema.WhichOneof("sealed_value") + ) + def decode_window_spec_expr(self, window_spec_expr: proto.SpWindowSpecExpr) -> Any: """ Decode a window specification expression. @@ -1447,20 +1548,7 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: default_on_null = self.decode_expr( expr.sp_dataframe_pivot.default_on_null ) - match expr.sp_dataframe_pivot.values.WhichOneof("sealed_value"): - case "sp_pivot_value__dataframe": - values = self.decode_expr( - expr.sp_dataframe_pivot.values.sp_pivot_value__dataframe.v - ) - case "sp_pivot_value__expr": - values = self.decode_expr( - expr.sp_dataframe_pivot.values.sp_pivot_value__expr.v - ) - case _: - raise ValueError( - "Unknown pivot value: %s" - % expr.sp_dataframe_pivot.values.WhichOneof("sealed_value") - ) + values = self.decode_pivot_value_expr(expr.sp_dataframe_pivot.values) return df.pivot(pivot_col, values, default_on_null) case "sp_dataframe_random_split": @@ -1527,9 +1615,7 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: return df.select_expr(exprs) case "sp_dataframe_show": - df = self.decode_expr( - self.symbol_table[expr.sp_dataframe_show.id.bitfield1][1] - ) + df = self.symbol_table[expr.sp_dataframe_show.id.bitfield1][1] return df.show() case "sp_dataframe_sort": @@ -1667,8 +1753,8 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: ] name_column = expr.sp_dataframe_unpivot.name_column value_column = expr.sp_dataframe_unpivot.value_column - # TODO SNOW-1866100: add logic for `include_nulls`. - return df.unpivot(value_column, name_column, column_list) + include_nulls = expr.sp_dataframe_unpivot.include_nulls + return df.unpivot(value_column, name_column, column_list, include_nulls) case "sp_dataframe_with_column": df = self.decode_expr(expr.sp_dataframe_with_column.df) @@ -1711,14 +1797,19 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: return self.session.range(start, end, step) case "sp_relational_grouped_dataframe_apply_in_pandas": - # TODO: SNOW-1830603 Flesh out this logic when implementing UDTFs. Need to create a dict to maintain - # all functions registered (here, `func`). Implement `decode_callable_expr`. - # func = self.decode_callable_expr(expr.sp_relational_grouped_dataframe_apply_in_pandas.func) - # grouped_df = self.decode_expr(expr.sp_relational_grouped_dataframe_apply_in_pandas.grouped_df) - # kwargs = self.decode_dsl_map_expr(expr.sp_relational_grouped_dataframe_apply_in_pandas.kwargs) - # output_schema = self.decode_expr(expr.sp_relational_grouped_dataframe_apply_in_pandas.output_schema) - # return grouped_df.apply_in_pandas(func, output_schema, **kwargs) - pass + func, _ = self.decode_callable_expr( + expr.sp_relational_grouped_dataframe_apply_in_pandas.func + ) + grouped_df = self.decode_expr( + expr.sp_relational_grouped_dataframe_apply_in_pandas.grouped_df + ) + kwargs = self.decode_dsl_map_expr( + expr.sp_relational_grouped_dataframe_apply_in_pandas.kwargs + ) + output_schema = self.decode_struct_type_expr( + expr.sp_relational_grouped_dataframe_apply_in_pandas.output_schema + ) + return grouped_df.apply_in_pandas(func, output_schema, **kwargs) case "sp_relational_grouped_dataframe_builtin": grouped_df = self.decode_expr( @@ -1739,6 +1830,31 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: else: return getattr(grouped_df, agg_name)(cols) + case "sp_relational_grouped_dataframe_pivot": + default_on_null = ( + self.decode_expr( + expr.sp_relational_grouped_dataframe_pivot.default_on_null + ) + if expr.sp_relational_grouped_dataframe_pivot.HasField( + "default_on_null" + ) + else None + ) + grouped_df = self.decode_expr( + expr.sp_relational_grouped_dataframe_pivot.grouped_df + ) + pivot_col = self.decode_expr( + expr.sp_relational_grouped_dataframe_pivot.pivot_col + ) + values = ( + self.decode_pivot_value_expr( + expr.sp_relational_grouped_dataframe_pivot.values + ) + if expr.sp_relational_grouped_dataframe_pivot.HasField("values") + else None + ) + return grouped_df.pivot(pivot_col, values, default_on_null) + case "sp_relational_grouped_dataframe_ref": return self.symbol_table[ expr.sp_relational_grouped_dataframe_ref.id.bitfield1 @@ -1790,8 +1906,64 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: ) case "udtf": - # TODO: SNOW-1830603 Implement UDTF decoding. - pass + comment = ( + expr.udtf.comment.value if expr.udtf.HasField("comment") else None + ) + external_access_integrations = [ + eai for eai in expr.udtf.external_access_integrations + ] + handler, handler_name = self.decode_callable_expr(expr.udtf.handler) + if_not_exists = expr.udtf.if_not_exists + immutable = expr.udtf.immutable + imports = [ + self.decode_name_expr(import_) for import_ in expr.udtf.imports + ] + input_types = [ + self.decode_data_type_expr(input_type) + for input_type in expr.udtf.input_types.list + ] + is_permanent = expr.udtf.is_permanent + kwargs = self.decode_dsl_map_expr(expr.udtf.kwargs) + if "copy_grants" in kwargs: + kwargs.pop("copy_grants") + name = ( + self.decode_name_expr(expr.udtf.name) + if expr.udtf.HasField("name") + else None + ) + output_schema = self.decode_udtf_schema(expr.udtf.output_schema) + packages = [package for package in expr.udtf.packages] + parallel = expr.udtf.parallel + replace = expr.udtf.replace + secrets = self.decode_dsl_map_expr(expr.udtf.secrets) + secure = expr.udtf.secure + stage_location = expr.udtf.stage_location + statement_params = self.decode_dsl_map_expr(expr.udtf.statement_params) + strict = expr.udtf.strict + # Run udtf to create the required AST but return the first registered version of the UDTF. + _ = udtf( + handler, + output_schema=output_schema, + input_types=input_types, + name=name, + is_permanent=is_permanent, + stage_location=stage_location, + imports=imports, + packages=packages, + replace=replace, + if_not_exists=if_not_exists, + session=self.session, + parallel=parallel, + statement_params=statement_params, + strict=strict, + secure=secure, + external_access_integrations=external_access_integrations, + secrets=secrets, + immutable=immutable, + comment=comment, + **kwargs, + ) + return self.session._udtf_registration.get_udtf(handler_name) case "sp_dataframe_cross_join": lhs = self.decode_expr(expr.sp_dataframe_cross_join.lhs) diff --git a/tests/ast/test_ast_driver.py b/tests/ast/test_ast_driver.py index 95b8939e68..d0afe89e51 100644 --- a/tests/ast/test_ast_driver.py +++ b/tests/ast/test_ast_driver.py @@ -240,7 +240,9 @@ def compare_base64_results( actual_message = actual_message.SerializeToString(deterministic=True) expected_message = expected_message.SerializeToString(deterministic=True) - assert actual_message == expected_message + assert normalize_temp_names(actual_message) == normalize_temp_names( + expected_message + ) @pytest.mark.parametrize("test_case", load_test_cases(), ids=idfn) @@ -290,6 +292,7 @@ def test_ast(session, tables, test_case): # version of the Snowpark code. decoder = Decoder(session) session._ast_batch.reset_id_gen() # Reset the entity ID generator. + session._ast_batch.reset_callables() # Reset the callables. session._ast_batch.flush() # Clear the AST. global_counter.reset()