From 856cc9917e07cf23c183dd43c369b7f23eab1f5b Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Fri, 24 Jan 2025 16:18:40 -0800 Subject: [PATCH 1/2] update tests, add decoder functionality for dataframe.write.test, dataframe.count2.test, fixed col_in decoder --- tests/ast/data/DataFrame.count2.test | 34 +-- .../data/Dataframe.to_snowpark_pandas.test | 15 - tests/ast/data/Table.merge.test | 32 ++- tests/ast/data/case_when.test | 132 ++++++--- tests/ast/data/col_cast.test | 2 +- tests/ast/data/col_cast_coll.test | 86 +++--- tests/ast/data/col_in_.test | 24 +- tests/ast/data/col_try_cast.test | 2 +- tests/ast/data/session.read.test | 48 ++-- tests/ast/data/sproc.test | 34 +-- tests/ast/decoder.py | 266 +++++++++++++++++- 11 files changed, 480 insertions(+), 195 deletions(-) diff --git a/tests/ast/data/DataFrame.count2.test b/tests/ast/data/DataFrame.count2.test index cbba823aec..8e7944efe0 100644 --- a/tests/ast/data/DataFrame.count2.test +++ b/tests/ast/data/DataFrame.count2.test @@ -37,26 +37,28 @@ body { sp_dataframe_schema__struct { v { fields { - column_identifier { - name: "\"A\"" - } - data_type { - sp_string_type { - length { - value: 16777216 + list { + column_identifier { + name: "\"A\"" + } + data_type { + sp_string_type { + length { + value: 16777216 + } } } + nullable: true } - nullable: true - } - fields { - column_identifier { - name: "\"B\"" - } - data_type { - sp_long_type: true + list { + column_identifier { + name: "\"B\"" + } + data_type { + sp_long_type: true + } + nullable: true } - nullable: true } } } diff --git a/tests/ast/data/Dataframe.to_snowpark_pandas.test b/tests/ast/data/Dataframe.to_snowpark_pandas.test index 4d5db707a9..f79f444311 100644 --- a/tests/ast/data/Dataframe.to_snowpark_pandas.test +++ b/tests/ast/data/Dataframe.to_snowpark_pandas.test @@ -37,10 +37,7 @@ body { } } src { - end_column: 41 - end_line: 25 file: "SRC_POSITION_TEST_MODE" - start_column: 13 start_line: 25 } variant { @@ -69,10 +66,7 @@ body { } } src { - end_column: 52 - end_line: 27 file: "SRC_POSITION_TEST_MODE" - start_column: 29 start_line: 27 } } @@ -101,10 +95,7 @@ body { list: "A" } src { - end_column: 65 - end_line: 29 file: "SRC_POSITION_TEST_MODE" - start_column: 29 start_line: 29 } } @@ -134,10 +125,7 @@ body { } } src { - end_column: 70 - end_line: 31 file: "SRC_POSITION_TEST_MODE" - start_column: 29 start_line: 31 } } @@ -170,10 +158,7 @@ body { list: "A" } src { - end_column: 87 - end_line: 33 file: "SRC_POSITION_TEST_MODE" - start_column: 29 start_line: 33 } } diff --git a/tests/ast/data/Table.merge.test b/tests/ast/data/Table.merge.test index 72e1c8be45..65af8647c5 100644 --- a/tests/ast/data/Table.merge.test +++ b/tests/ast/data/Table.merge.test @@ -165,25 +165,27 @@ body { sp_dataframe_schema__struct { v { fields { - column_identifier { - name: "num" - } - data_type { - sp_integer_type: true - } - nullable: true - } - fields { - column_identifier { - name: "str" + list { + column_identifier { + name: "num" + } + data_type { + sp_integer_type: true + } + nullable: true } - data_type { - sp_string_type { - length { + list { + column_identifier { + name: "str" + } + data_type { + sp_string_type { + length { + } } } + nullable: true } - nullable: true } } } diff --git a/tests/ast/data/case_when.test b/tests/ast/data/case_when.test index 146495780e..afddeb10ff 100644 --- a/tests/ast/data/case_when.test +++ b/tests/ast/data/case_when.test @@ -57,8 +57,10 @@ body { expr { sp_table { name { - sp_table_name_flat { - name: "table1" + name { + sp_name_flat { + name: "table1" + } } } src { @@ -91,8 +93,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -169,8 +173,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -227,8 +233,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -335,8 +343,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -389,8 +399,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -421,8 +433,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -486,8 +500,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -585,8 +601,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -621,8 +639,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "lit" + name { + sp_name_flat { + name: "lit" + } } } } @@ -651,8 +671,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -696,8 +718,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "lit" + name { + sp_name_flat { + name: "lit" + } } } } @@ -728,8 +752,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "lit" + name { + sp_name_flat { + name: "lit" + } } } } @@ -793,8 +819,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -866,8 +894,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -959,8 +989,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -985,8 +1017,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -1084,8 +1118,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -1157,8 +1193,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -1230,8 +1268,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -1303,8 +1343,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -1376,8 +1418,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } diff --git a/tests/ast/data/col_cast.test b/tests/ast/data/col_cast.test index b2d8db5ee4..efb6429c41 100644 --- a/tests/ast/data/col_cast.test +++ b/tests/ast/data/col_cast.test @@ -93,7 +93,7 @@ df = df.select(col("A").cast(MapType(StringType(), StringType(), structured=Fals df = df.select(col("A").cast(VectorType(FloatType(), 42))) -df = df.select(col("A").cast(StructType([], structured=False))) +df = df.select(col("A").cast(StructType(structured=False))) df = df.select(col("A").cast(VariantType())) diff --git a/tests/ast/data/col_cast_coll.test b/tests/ast/data/col_cast_coll.test index a4e3baad7f..9aa77aa993 100644 --- a/tests/ast/data/col_cast_coll.test +++ b/tests/ast/data/col_cast_coll.test @@ -611,61 +611,63 @@ body { to { sp_struct_type { fields { - column_identifier { - name: "test1" - } - data_type { - sp_string_type { - length { + list { + column_identifier { + name: "test1" + } + data_type { + sp_string_type { + length { + } } } } - } - fields { - column_identifier { - name: "test2" - } - data_type { - sp_integer_type: true - } - nullable: true - } - fields { - column_identifier { - name: "test3" + list { + column_identifier { + name: "test2" + } + data_type { + sp_integer_type: true + } + nullable: true } - data_type { - sp_array_type { - ty { - sp_long_type: true + list { + column_identifier { + name: "test3" + } + data_type { + sp_array_type { + ty { + sp_long_type: true + } } } + nullable: true } - nullable: true - } - fields { - column_identifier { - name: "test4" - } - data_type { - sp_map_type { - key_ty { - sp_decimal_type { - precision: 42 - scale: 23 + list { + column_identifier { + name: "test4" + } + data_type { + sp_map_type { + key_ty { + sp_decimal_type { + precision: 42 + scale: 23 + } } - } - value_ty { - sp_vector_type { - dimension: 64 - ty { - sp_float_type: true + value_ty { + sp_vector_type { + dimension: 64 + ty { + sp_float_type: true + } } } } } + nullable: true } - nullable: true } structured: true } diff --git a/tests/ast/data/col_in_.test b/tests/ast/data/col_in_.test index ff21d16f1b..6e2a219361 100644 --- a/tests/ast/data/col_in_.test +++ b/tests/ast/data/col_in_.test @@ -25,8 +25,10 @@ body { expr { sp_table { name { - sp_table_name_flat { - name: "table1" + name { + sp_name_flat { + name: "table1" + } } } src { @@ -58,8 +60,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -142,8 +146,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } @@ -206,8 +212,10 @@ body { fn { builtin_fn { name { - fn_name_flat { - name: "col" + name { + sp_name_flat { + name: "col" + } } } } diff --git a/tests/ast/data/col_try_cast.test b/tests/ast/data/col_try_cast.test index 38325f9b92..0ffd5185fc 100644 --- a/tests/ast/data/col_try_cast.test +++ b/tests/ast/data/col_try_cast.test @@ -93,7 +93,7 @@ df = df.select(col("A").try_cast(MapType(StringType(), StringType(), structured= df = df.select(col("A").try_cast(VectorType(FloatType(), 42))) -df = df.select(col("A").try_cast(StructType([], structured=False))) +df = df.select(col("A").try_cast(StructType(structured=False))) df = df.select(col("A").try_cast(VariantType())) diff --git a/tests/ast/data/session.read.test b/tests/ast/data/session.read.test index c618d82fff..caf39174a7 100644 --- a/tests/ast/data/session.read.test +++ b/tests/ast/data/session.read.test @@ -399,34 +399,36 @@ body { } schema { fields { - column_identifier { - name: "a" - } - data_type { - sp_integer_type: true - } - nullable: true - } - fields { - column_identifier { - name: "b" + list { + column_identifier { + name: "a" + } + data_type { + sp_integer_type: true + } + nullable: true } - data_type { - sp_string_type { - length { + list { + column_identifier { + name: "b" + } + data_type { + sp_string_type { + length { + } } } + nullable: true } - nullable: true - } - fields { - column_identifier { - name: "c" - } - data_type { - sp_float_type: true + list { + column_identifier { + name: "c" + } + data_type { + sp_float_type: true + } + nullable: true } - nullable: true } } src { diff --git a/tests/ast/data/sproc.test b/tests/ast/data/sproc.test index cab0b273bb..b5c5e02ee7 100644 --- a/tests/ast/data/sproc.test +++ b/tests/ast/data/sproc.test @@ -173,11 +173,11 @@ res33 = sproc("select_sp", return_type=StructType([StructField("A", IntegerType( session.sql("SELECT 1 as A, 2 as B").show() -res37 = sproc("select_sp", return_type=StructType([], structured=False), input_types=[IntegerType(), IntegerType()], source_code_display=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".\"SNOWPARK_TEMP_PROCEDURE_xxx\"")(1, 2) +res37 = sproc("select_sp", return_type=StructType(structured=False), input_types=[IntegerType(), IntegerType()], source_code_display=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".\"SNOWPARK_TEMP_PROCEDURE_xxx\"")(1, 2) session.sql("SELECT 1 as A, 2 as B").show() -res41 = sproc("select_sp", return_type=StructType([], structured=False), input_types=[LongType(), LongType()], _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".\"SNOWPARK_TEMP_PROCEDURE_xxx\"")(1, 2) +res41 = sproc("select_sp", return_type=StructType(structured=False), input_types=[LongType(), LongType()], _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".\"SNOWPARK_TEMP_PROCEDURE_xxx\"")(1, 2) session.sql("SELECT 1 as A, 2 as B").show() @@ -1603,22 +1603,24 @@ body { return_type { sp_struct_type { fields { - column_identifier { - name: "A" - } - data_type { - sp_integer_type: true - } - nullable: true - } - fields { - column_identifier { - name: "B" + list { + column_identifier { + name: "A" + } + data_type { + sp_integer_type: true + } + nullable: true } - data_type { - sp_integer_type: true + list { + column_identifier { + name: "B" + } + data_type { + sp_integer_type: true + } + nullable: true } - nullable: true } } } diff --git a/tests/ast/decoder.py b/tests/ast/decoder.py index 9c61c0a965..e42f2b582e 100644 --- a/tests/ast/decoder.py +++ b/tests/ast/decoder.py @@ -10,6 +10,7 @@ from pandas import DataFrame as PandasDataFrame +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SaveMode from snowflake.snowpark.window import WindowSpec, Window, WindowRelativePosition import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto @@ -319,8 +320,8 @@ def decode_dataframe_data_expr( else: return [] case "sp_dataframe_data__pandas": - # We don't know what pandas DataFrame was passed in, return an empty one. - return PandasDataFrame() + # We don't know what pandas DataFrame was passed in, return a non-empty one. + return PandasDataFrame({"A": ["1", "2"], "B": [4, 5]}) # case "sp_dataframe_data__tuple": # pass case _: @@ -590,6 +591,36 @@ def decode_pivot_value_expr(self, pivot_value_expr: proto.SpPivotValue) -> Any: % pivot_value_expr.WhichOneof("sealed_value") ) + def decode_save_mode(self, save_mode: proto.SpSaveMode) -> str: + """ + Decode a save mode expression to get the save mode. + + Parameters + ---------- + save_mode : proto.SpSaveMode + The expression to decode. + + Returns + ------- + str + The decoded save mode. + """ + match save_mode.WhichOneof("variant"): + case "sp_save_mode_append": + return "append" + case "sp_save_mode_error_if_exists": + return "errorifexists" + case "sp_save_mode_ignore": + return "ignore" + case "sp_save_mode_overwrite": + return "overwrite" + case "sp_save_mode_truncate": + return "truncate" + case _: + raise ValueError( + "Unknown save mode: %s" % save_mode.WhichOneof("variant") + ) + def decode_struct_type_expr( self, sp_struct_type_expr: proto.SpStructType ) -> StructType: @@ -1047,12 +1078,12 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: rhs = self.decode_expr(expr.sp_column_equal_null.rhs) return lhs.equal_null(rhs) - case "sp_column_in__seq": - col = self.decode_expr(expr.sp_column_in__seq.col) - if isinstance(expr.sp_column_in__seq.values, Iterable): + case "sp_column_in": + col = self.decode_expr(expr.sp_column_in.col) + if isinstance(expr.sp_column_in.values, Iterable): # The values should be passed in as positional arguments and not as a list. return col.in_( - self.decode_expr(v) for v in expr.sp_column_in__seq.values + *[self.decode_expr(v) for v in expr.sp_column_in.values] ) else: # The list case should be taken care of in this branch. @@ -1089,13 +1120,17 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: case "sp_column_string_regexp": col = self.decode_expr(expr.sp_column_string_regexp.col) pattern = self.decode_expr(expr.sp_column_string_regexp.pattern) - parameters = self.decode_expr(expr.sp_column_string_regexp.parameters) + parameters = ( + self.decode_expr(expr.sp_column_string_regexp.parameters) + if expr.sp_column_string_regexp.HasField("parameters") + else None + ) return col.regexp(pattern, parameters) case "sp_column_string_starts_with": col = self.decode_expr(expr.sp_column_string_starts_with.col) prefix = self.decode_expr(expr.sp_column_string_starts_with.prefix) - return col.starts_with(prefix) + return col.startswith(prefix) case "sp_column_string_substr": col = self.decode_expr(expr.sp_column_string_substr.col) @@ -1106,7 +1141,7 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: case "sp_column_string_ends_with": col = self.decode_expr(expr.sp_column_string_ends_with.col) suffix = self.decode_expr(expr.sp_column_string_ends_with.suffix) - return col.ends_with(suffix) + return col.endswith(suffix) case "sp_column_string_collate": col = self.decode_expr(expr.sp_column_string_collate.col) @@ -1195,10 +1230,9 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: # DATAFRAME FUNCTIONS case "sp_create_dataframe": data = self.decode_dataframe_data_expr(expr.sp_create_dataframe.data) - d = MessageToDict(expr.sp_create_dataframe) schema = ( self.decode_dataframe_schema_expr(expr.sp_create_dataframe.schema) - if "schema" in d + if expr.sp_create_dataframe.HasField("schema") else None ) df = self.session.create_dataframe(data=data, schema=schema) @@ -1342,10 +1376,10 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: ) case "sp_dataframe_count": - df = self.decode_expr(expr.sp_dataframe_first.df) + df = self.symbol_table[expr.sp_dataframe_count.id.bitfield1][1] d = MessageToDict(expr.sp_dataframe_count) statement_params = self.get_statement_params(d) - block = d["block"] + block = d.get("block", False) return df.count( statement_params=statement_params, block=block, @@ -2311,7 +2345,206 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: case "sp_dataframe_write": df = self.decode_expr(expr.sp_dataframe_write.df) - return df.write + res = df.write + if expr.sp_dataframe_write.HasField("partition_by"): + partition_by = self.decode_expr( + expr.sp_dataframe_write.partition_by + ) + res = res.partition_by(partition_by) + if expr.sp_dataframe_write.HasField("save_mode"): + save_mode = self.decode_save_mode(expr.sp_dataframe_write.save_mode) + res = res.mode(save_mode) + options = self.decode_dsl_map_expr(expr.sp_dataframe_write.options) + if options: + res = res.options(**options) + return res + + case "sp_write_copy_into_location": + df = self.symbol_table[expr.sp_write_copy_into_location.id.bitfield1][1] + block = expr.sp_write_copy_into_location.block + copy_options = self.decode_dsl_map_expr( + expr.sp_write_copy_into_location.copy_options + ) + d = MessageToDict(expr.sp_write_copy_into_location) + file_format_name = d.get("fileFormatName", None) + file_format_type = d.get("fileFormatType", None) + format_type_options = ( + self.decode_dsl_map_expr( + expr.sp_write_copy_into_location.format_type_options + ) + if "formatTypeOptions" in d + else None + ) + header = expr.sp_write_copy_into_location.header + location = expr.sp_write_copy_into_location.location + partition_by = ( + self.decode_expr(expr.sp_write_copy_into_location.partition_by) + if expr.sp_write_copy_into_location.HasField("partition_by") + else None + ) + statement_params = self.decode_dsl_map_expr( + expr.sp_write_copy_into_location.statement_params + ) + return df.copy_into_location( + location, + partition_by=partition_by, + file_format_name=file_format_name, + file_format_type=file_format_type, + format_type_options=format_type_options, + header=header, + statement_params=statement_params, + block=block, + **copy_options, + ) + + case "sp_write_csv": + df = self.symbol_table[expr.sp_write_csv.id.bitfield1][1] + block = expr.sp_write_csv.block + copy_options = self.decode_dsl_map_expr(expr.sp_write_csv.copy_options) + format_type_options = self.decode_dsl_map_expr( + expr.sp_write_csv.format_type_options + ) + header = expr.sp_write_csv.header + location = expr.sp_write_csv.location + partition_by = ( + self.decode_expr(expr.sp_write_csv.partition_by) + if expr.sp_write_csv.HasField("partition_by") + else None + ) + statement_params = self.decode_dsl_map_expr( + expr.sp_write_csv.statement_params + ) + return df.csv( + location, + partition_by=partition_by, + format_type_options=format_type_options, + header=header, + statement_params=statement_params, + block=block, + **copy_options, + ) + + case "sp_write_json": + df = self.symbol_table[expr.sp_write_json.id.bitfield1][1] + block = expr.sp_write_json.block + copy_options = self.decode_dsl_map_expr(expr.sp_write_json.copy_options) + format_type_options = self.decode_dsl_map_expr( + expr.sp_write_json.format_type_options + ) + header = expr.sp_write_json.header + location = expr.sp_write_json.location + partition_by = ( + self.decode_expr(expr.sp_write_json.partition_by) + if expr.sp_write_json.HasField("partition_by") + else None + ) + statement_params = self.decode_dsl_map_expr( + expr.sp_write_json.statement_params + ) + return df.json( + location, + partition_by=partition_by, + format_type_options=format_type_options, + header=header, + statement_params=statement_params, + block=block, + **copy_options, + ) + + case "sp_write_parquet": + df = self.symbol_table[expr.sp_write_parquet.id.bitfield1][1] + block = expr.sp_write_parquet.block + copy_options = self.decode_dsl_map_expr( + expr.sp_write_parquet.copy_options + ) + format_type_options = self.decode_dsl_map_expr( + expr.sp_write_parquet.format_type_options + ) + header = expr.sp_write_parquet.header + location = expr.sp_write_parquet.location + partition_by = ( + self.decode_expr(expr.sp_write_parquet.partition_by) + if expr.sp_write_parquet.HasField("partition_by") + else None + ) + statement_params = self.decode_dsl_map_expr( + expr.sp_write_parquet.statement_params + ) + return df.parquet( + location, + partition_by=partition_by, + format_type_options=format_type_options, + header=header, + statement_params=statement_params, + block=block, + **copy_options, + ) + + case "sp_write_table": + df = self.symbol_table[expr.sp_write_table.id.bitfield1][1] + block = expr.sp_write_table.block + change_tracking = ( + expr.sp_write_table.change_tracking.value + if expr.sp_write_table.HasField("change_tracking") + else None + ) + clustering_keys = [ + self.decode_expr(ck) + for ck in expr.sp_write_table.clustering_keys.list + ] + column_order = expr.sp_write_table.column_order + comment = ( + expr.sp_write_table.comment.value + if expr.sp_write_table.HasField("comment") + else None + ) + copy_grants = expr.sp_write_table.copy_grants + create_temp_table = expr.sp_write_table.create_temp_table + data_retention_time = ( + expr.sp_write_table.data_retention_time.value + if expr.sp_write_table.HasField("data_retention_time") + else None + ) + enable_schema_evolution = ( + expr.sp_write_table.enable_schema_evolution.value + if expr.sp_write_table.HasField("enable_schema_evolution") + else None + ) + iceberg_config = self.decode_dsl_map_expr( + expr.sp_write_table.iceberg_config + ) + max_data_extension_time = ( + expr.sp_write_table.max_data_extension_time.value + if expr.sp_write_table.HasField("max_data_extension_time") + else None + ) + mode = ( + self.decode_save_mode(expr.sp_write_table.mode) + if expr.sp_write_table.HasField("mode") + else None + ) + statement_params = self.decode_dsl_map_expr( + expr.sp_write_table.statement_params + ) + table_name = self.decode_name_expr(expr.sp_write_table.table_name) + table_type = expr.sp_write_table.table_type + df.save_as_table( + table_name, + mode=mode, + column_order=column_order, + create_temp_table=create_temp_table, + table_type=table_type, + clustering_keys=clustering_keys, + statement_params=statement_params, + block=block, + comment=comment, + enable_schema_evolution=enable_schema_evolution, + data_retention_time=data_retention_time, + max_data_extension_time=max_data_extension_time, + change_tracking=change_tracking, + copy_grants=copy_grants, + iceberg_config=iceberg_config, + ) case "stored_procedure": input_types = [ @@ -2427,6 +2660,11 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: else: return Row(*values) + case "sp_sql": + params = [self.decode_expr(param) for param in expr.sp_sql.params] + query = expr.sp_sql.query + return self.session.sql(query, params) + case _: raise NotImplementedError( "Expression type not implemented yet: %s" From 7306b73e8af3dc89f065aae7b409c655cd062209 Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Fri, 24 Jan 2025 16:59:08 -0800 Subject: [PATCH 2/2] dataframe reader stuff --- tests/ast/decoder.py | 91 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 88 insertions(+), 3 deletions(-) diff --git a/tests/ast/decoder.py b/tests/ast/decoder.py index 31d68f8900..57c4d3ae6e 100644 --- a/tests/ast/decoder.py +++ b/tests/ast/decoder.py @@ -188,6 +188,59 @@ def decode_col_exprs(self, expr: proto.Expr) -> List[Column]: col_list = [self.decode_expr(arg) for arg in expr] return col_list + def decode_dataframe_reader_expr(self, df_reader_expr: proto.SpDataframeReader): + """ + Decode a dataframe reader expression to get the dataframe. + + Parameters + ---------- + df_reader_expr : proto.SpDataframeReader + The expression to decode. + + """ + match df_reader_expr.WhichOneof("variant"): + case "sp_dataframe_reader_init": + return self.session.read + case "sp_dataframe_reader_option": + reader = self.decode_dataframe_reader_expr( + df_reader_expr.sp_dataframe_reader_option.reader + ) + key = df_reader_expr.sp_dataframe_reader_option.key + value = self.decode_expr( + df_reader_expr.sp_dataframe_reader_option.value + ) + return reader.option(key, value) + case "sp_dataframe_reader_options": + reader = self.decode_dataframe_reader_expr( + df_reader_expr.sp_dataframe_reader_options.reader + ) + configs = self.decode_dsl_map_expr( + df_reader_expr.sp_dataframe_reader_options.configs + ) + return reader.options(configs) + case "sp_dataframe_reader_schema": + reader = self.decode_dataframe_reader_expr( + df_reader_expr.sp_dataframe_reader_schema.reader + ) + schema = self.decode_struct_type_expr( + df_reader_expr.sp_dataframe_reader_schema.schema + ) + return reader.schema(schema) + case "sp_dataframe_reader_with_metadata": + reader = self.decode_dataframe_reader_expr( + df_reader_expr.sp_dataframe_reader_with_metadata.reader + ) + metadata_columns = [ + self.decode_expr(arg) + for arg in df_reader_expr.sp_dataframe_reader_with_metadata.metadata_columns.args + ] + return reader.with_metadata(*metadata_columns) + case _: + raise ValueError( + "Unknown dataframe reader type: %s" + % df_reader_expr.WhichOneof("variant") + ) + def decode_dsl_map_expr(self, map_expr: Iterable) -> dict: """ Given a map expression, return the result as a Python dictionary. @@ -465,9 +518,8 @@ def decode_data_type_expr( return ShortType() case "sp_string_type": length = ( - data_type_expr.sp_string_type.length + data_type_expr.sp_string_type.length.value if data_type_expr.sp_string_type.HasField("length") - and isinstance(data_type_expr.sp_string_type.length, int) else None ) return StringType(length) @@ -487,7 +539,10 @@ def decode_data_type_expr( for field in data_type_expr.sp_struct_type.fields.list: column_identifier = field.column_identifier.name data_type = self.decode_data_type_expr(field.data_type) - fields.append(StructField(column_identifier, data_type)) + nullable = field.nullable + fields.append( + StructField(column_identifier, data_type, nullable) + ) else: fields = None structured = data_type_expr.sp_struct_type.structured @@ -2512,6 +2567,36 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: iceberg_config=iceberg_config, ) + case "sp_read_avro": + path = expr.sp_read_avro.path + reader = self.decode_dataframe_reader_expr(expr.sp_read_avro.reader) + return reader.avro(path) + + case "sp_read_csv": + path = expr.sp_read_csv.path + reader = self.decode_dataframe_reader_expr(expr.sp_read_csv.reader) + return reader.csv(path) + + case "sp_read_json": + path = expr.sp_read_json.path + reader = self.decode_dataframe_reader_expr(expr.sp_read_json.reader) + return reader.json(path) + + case "sp_read_orc": + path = expr.sp_read_orc.path + reader = self.decode_dataframe_reader_expr(expr.sp_read_orc.reader) + return reader.orc(path) + + case "sp_read_parquet": + path = expr.sp_read_parquet.path + reader = self.decode_dataframe_reader_expr(expr.sp_read_parquet.reader) + return reader.parquet(path) + + case "sp_read_xml": + path = expr.sp_read_xml.path + reader = self.decode_dataframe_reader_expr(expr.sp_read_xml.reader) + return reader.xml(path) + case "sp_dataframe_write": df = self.decode_expr(expr.sp_dataframe_write.df) res = df.write