Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/snowflake/snowpark/_internal/ast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,10 +1490,14 @@ def clear_symbols_and_udfs(message: proto.Request) -> None:
for stmt in message.body:
if hasattr(stmt, "assign"):
stmt.assign.ClearField("symbol")
if str(stmt.assign.expr.udf):
stmt.assign.expr.ClearField("udf")
if str(stmt.assign.expr.stored_procedure):
stmt.assign.expr.ClearField("stored_procedure")
if str(stmt.assign.expr.udf):
stmt.assign.expr.ClearField("udf")
if str(stmt.assign.expr.stored_procedure):
stmt.assign.ClearField("expr")
if str(stmt.assign.expr.apply_expr):
if str(stmt.assign.expr.apply_expr.fn):
if str(stmt.assign.expr.apply_expr.fn.stored_procedure):
stmt.assign.ClearField("expr")


def base64_str_to_request(base64_str: str) -> proto.Request:
Expand Down
6 changes: 1 addition & 5 deletions src/snowflake/snowpark/_internal/proto/ast.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ message List_SpDataType {
repeated SpDataType list = 1;
}

message List_SpStructField {
repeated SpStructField list = 1;
}

message List_String {
repeated string list = 1;
}
Expand Down Expand Up @@ -178,7 +174,7 @@ message SpStructField {

// sp-type.ir:46
message SpStructType {
List_SpStructField fields = 1;
repeated SpStructField fields = 1;
bool structured = 2;
}

Expand Down
174 changes: 143 additions & 31 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,26 +166,26 @@ def convert_name_to_list(self, name: any) -> List:
return [name]
return [qualified_name for qualified_name in name]

def decode_name_expr(self, table_name: proto.SpName) -> str:
def decode_name_expr(self, name: proto.SpName) -> str:
"""
Decode a table name expression to get the table name.
Decode a name expression to get the name.

Parameters
----------
table_name : proto.SpTableName
name : proto.SpTableName
The table name to decode.

Returns
-------
str
The decoded table name.
"""
if table_name.name.HasField("sp_name_flat"):
return table_name.name.sp_name_flat.name
elif table_name.name.HasField("sp_name_structured"):
return table_name.name.sp_name_structured.name
if name.name.HasField("sp_name_flat"):
return name.name.sp_name_flat.name
elif name.name.HasField("sp_name_structured"):
return name.name.sp_name_structured.name
else:
raise ValueError("Table name not found in proto.SpTableName")
return None

def decode_fn_ref_expr(self, fn_ref_expr: proto.FnRefExpr) -> str:
"""
Expand Down Expand Up @@ -325,6 +325,12 @@ def decode_data_type_expr(
DataType, StructField, or ColumnIdentifier
The decoded data type.
"""
if hasattr(data_type_expr, "data_type"):
column_identifier = data_type_expr.column_identifier.name
data_type = self.decode_data_type_expr(data_type_expr.data_type)
nullable = data_type_expr.nullable
return StructField(column_identifier, data_type, nullable)

match data_type_expr.WhichOneof("variant"):
case "sp_array_type":
structured = data_type_expr.sp_array_type.structured
Expand Down Expand Up @@ -605,12 +611,20 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
return call_table_function(
fn_name, *pos_args, **named_args
)
case "stored_procedure":
self.session._ast_batch.assign()
return None
# return self.session.call(fn_name, *pos_args, **named_args, _emit_ast=False)

case _:
raise ValueError(
"Unknown function reference type: %s"
% expr.apply_expr.fn.WhichOneof("variant")
)

if isinstance(fn, snowflake.snowpark.stored_procedure.StoredProcedure):
return None

result = fn(*pos_args, **named_args)
if hasattr(expr, "var_id"):
self.symbol_table[expr.var_id.bitfield1] = (
Expand Down Expand Up @@ -1112,7 +1126,7 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
)

case "sp_dataframe_count":
df = self.decode_expr(expr.sp_dataframe_first.df)
df = self.symbol_table[expr.sp_dataframe_count.id.bitfield1][1]
d = MessageToDict(expr.sp_dataframe_count)
statement_params = self.get_statement_params(d)
block = d["block"]
Expand Down Expand Up @@ -1412,9 +1426,13 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
return df.select_expr(exprs)

case "sp_dataframe_show":
df = self.decode_expr(
self.symbol_table[expr.sp_dataframe_show.id.bitfield1][1]
)
df_show = self.symbol_table[expr.sp_dataframe_show.id.bitfield1][1]
if isinstance(df_show, snowflake.snowpark.dataframe.DataFrame):
df = df_show
else:
df = self.decode_expr(
self.symbol_table[expr.sp_dataframe_show.id.bitfield1][1]
)
return df.show()

case "sp_dataframe_sort":
Expand Down Expand Up @@ -1633,7 +1651,10 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
case "sp_table":
assert expr.sp_table.HasField("name")
table_name = self.decode_name_expr(expr.sp_table.name)
return self.session.table(table_name)
is_temp_table_for_cleanup = expr.sp_table.is_temp_table_for_cleanup
return self.session.table(
table_name, is_temp_table_for_cleanup=is_temp_table_for_cleanup
)

case "sp_to_snowpark_pandas":
df = self.decode_expr(expr.sp_to_snowpark_pandas.df)
Expand Down Expand Up @@ -1938,27 +1959,31 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
return df.write

case "stored_procedure":
input_types = [
self.decode_data_type_expr(input_type)
for input_type in expr.stored_procedure.input_types.list
]
execute_as = expr.stored_procedure.execute_as
comment = expr.stored_procedure.comment.value
self.session._ast_batch.assign()
registered_object_name = self.decode_name_expr(
expr.stored_procedure.func.object_name
)
return_type = self.decode_data_type_expr(
expr.stored_procedure.return_type
)
ret_sproc = sproc(
lambda *args: None,
return_type=return_type,
input_types=input_types,
execute_as=execute_as,
comment=comment,
_registered_object_name=registered_object_name,
)
return ret_sproc
return self.session.sproc._registry[registered_object_name]
"""input_types = [
self.decode_data_type_expr(input_type)
for input_type in expr.stored_procedure.input_types.list
]
execute_as = expr.stored_procedure.execute_as
comment = expr.stored_procedure.comment.value
return_type = self.decode_data_type_expr(
expr.stored_procedure.return_type
)
name = self.decode_name_expr(expr.stored_procedure.name)
ret_sproc = sproc(
self.session.sproc._registry[registered_object_name],
name=name,
return_type=return_type,
input_types=input_types,
execute_as=execute_as,
comment=comment,
_registered_object_name=registered_object_name,
)
return ret_sproc"""

case "sp_flatten":
input = self.decode_expr(expr.sp_flatten.input)
Expand Down Expand Up @@ -2001,7 +2026,94 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
columns, rowcount=row_count, timelimit=time_limit_seconds
)

case "sp_sql":
params = [self.decode_expr(param) for param in expr.sp_sql.params]
query = expr.sp_sql.query
return self.session.sql(query=query, params=params)

case "sp_write_table":
table_name = self.decode_name_expr(expr.sp_write_table.table_name)

mode = None

match expr.sp_write_table.mode.WhichOneof("variant"):
case "sp_save_mode_overwrite":
mode = "overwrite"

case "sp_save_mode_append":
mode = "append"

case "sp_save_mode_truncate":
mode = "truncate"

case "sp_save_mode_error_if_exists":
mode = "error_if_exists"

case "sp_save_mode_ignore":
mode = "ignore"

case "_":
mode = None

table_type = expr.sp_write_table.table_type
column_order = expr.sp_write_table.column_order
create_temp_table = expr.sp_write_table.create_temp_table
clustering_keys = [
self.decode_expr(col)
for col in expr.sp_write_table.clustering_keys.list
]
d = MessageToDict(expr.sp_write_table)
statement_params = self.get_statement_params(d)
block = expr.sp_write_table.block
comment = None
if "comment" in d:
comment = expr.sp_write_table.comment.value
enable_schema_evolution = None
if "enable_schema_evolution" in d:
enable_schema_evolution = (
expr.sp_write_table.enable_schema_evolution.value
)
data_retention_time = None
if "data_retention_time" in d:
data_retention_time = expr.sp_write_table.data_retention_time.value
max_data_extension_time = None
if "max_data_extension_time" in d:
max_data_extension_time = (
expr.sp_write_table.max_data_extension_time.value
)
change_tracking = None
if "change_tracking" in d:
change_tracking = expr.sp_write_table.change_tracking.value
copy_grants = expr.sp_write_table.copy_grants
iceberg_config = None
if hasattr(expr.sp_write_table, "iceberg_config"):
iceberg_config = {
expr.sp_write_table.iceberg_config[i]
._1: expr.sp_write_table.iceberg_config[i]
._2
for i in range(len(expr.sp_write_table.iceberg_config))
}
df_writer = self.symbol_table[expr.sp_write_table.id.bitfield1][1]
return df_writer.save_as_table(
table_name=table_name,
mode=mode,
column_order=column_order,
create_temp_table=create_temp_table,
table_type=table_type,
clustering_keys=clustering_keys,
statement_params=statement_params,
block=block,
comment=comment,
enable_schema_evolution=enable_schema_evolution,
data_retention_time=data_retention_time,
max_data_extension_time=max_data_extension_time,
change_tracking=change_tracking,
copy_grants=copy_grants,
iceberg_config=iceberg_config,
)

case _:
breakpoint()
raise NotImplementedError(
"Expression type not implemented yet: %s"
% expr.WhichOneof("variant")
Expand Down
Loading