Skip to content

Commit 63bc2dd

Browse files
committed
moar changes
1 parent 0a8acae commit 63bc2dd

File tree

3 files changed

+40
-36
lines changed

3 files changed

+40
-36
lines changed

src/snowflake/snowpark/_internal/ast/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,10 +1490,14 @@ def clear_symbols_and_udfs(message: proto.Request) -> None:
14901490
for stmt in message.body:
14911491
if hasattr(stmt, "assign"):
14921492
stmt.assign.ClearField("symbol")
1493-
if str(stmt.assign.expr.udf):
1494-
stmt.assign.expr.ClearField("udf")
1495-
if str(stmt.assign.expr.stored_procedure):
1496-
stmt.assign.expr.ClearField("stored_procedure")
1493+
if str(stmt.assign.expr.udf):
1494+
stmt.assign.expr.ClearField("udf")
1495+
if str(stmt.assign.expr.stored_procedure):
1496+
stmt.assign.ClearField("expr")
1497+
if str(stmt.assign.expr.apply_expr):
1498+
if str(stmt.assign.expr.apply_expr.fn):
1499+
if str(stmt.assign.expr.apply_expr.fn.stored_procedure):
1500+
stmt.assign.ClearField("expr")
14971501

14981502

14991503
def base64_str_to_request(base64_str: str) -> proto.Request:

src/snowflake/snowpark/_internal/proto/ast.proto

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ message List_SpDataType {
2222
repeated SpDataType list = 1;
2323
}
2424

25-
message List_SpStructField {
26-
repeated SpStructField list = 1;
27-
}
28-
2925
message List_String {
3026
repeated string list = 1;
3127
}
@@ -178,7 +174,7 @@ message SpStructField {
178174

179175
// sp-type.ir:46
180176
message SpStructType {
181-
List_SpStructField fields = 1;
177+
repeated SpStructField fields = 1;
182178
bool structured = 2;
183179
}
184180

tests/ast/decoder.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,6 @@ def get_statement_params(self, d: Dict):
551551
return statement_params
552552

553553
def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
554-
if not hasattr(expr, "WhichOneof"):
555-
breakpoint()
556554
match expr.WhichOneof("variant"):
557555
# COLUMN BINARY OPERATIONS
558556
case "add":
@@ -614,14 +612,19 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
614612
fn_name, *pos_args, **named_args
615613
)
616614
case "stored_procedure":
617-
return self.session.call(fn_name, *pos_args, **named_args)
615+
self.session._ast_batch.assign()
616+
return None
617+
# return self.session.call(fn_name, *pos_args, **named_args, _emit_ast=False)
618618

619619
case _:
620620
raise ValueError(
621621
"Unknown function reference type: %s"
622622
% expr.apply_expr.fn.WhichOneof("variant")
623623
)
624624

625+
if isinstance(fn, snowflake.snowpark.stored_procedure.StoredProcedure):
626+
return None
627+
625628
result = fn(*pos_args, **named_args)
626629
if hasattr(expr, "var_id"):
627630
self.symbol_table[expr.var_id.bitfield1] = (
@@ -1646,7 +1649,6 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
16461649
return self.decode_expr(expr.sp_session_table_function.fn, **kwargs)
16471650

16481651
case "sp_table":
1649-
breakpoint()
16501652
assert expr.sp_table.HasField("name")
16511653
table_name = self.decode_name_expr(expr.sp_table.name)
16521654
is_temp_table_for_cleanup = expr.sp_table.is_temp_table_for_cleanup
@@ -1957,29 +1959,31 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
19571959
return df.write
19581960

19591961
case "stored_procedure":
1960-
input_types = [
1961-
self.decode_data_type_expr(input_type)
1962-
for input_type in expr.stored_procedure.input_types.list
1963-
]
1964-
execute_as = expr.stored_procedure.execute_as
1965-
comment = expr.stored_procedure.comment.value
1962+
self.session._ast_batch.assign()
19661963
registered_object_name = self.decode_name_expr(
19671964
expr.stored_procedure.func.object_name
19681965
)
1969-
return_type = self.decode_data_type_expr(
1970-
expr.stored_procedure.return_type
1971-
)
1972-
name = self.decode_name_expr(expr.stored_procedure.name)
1973-
ret_sproc = sproc(
1974-
self.session.sproc._registry[registered_object_name],
1975-
name=name,
1976-
return_type=return_type,
1977-
input_types=input_types,
1978-
execute_as=execute_as,
1979-
comment=comment,
1980-
_registered_object_name=registered_object_name,
1981-
)
1982-
return ret_sproc
1966+
return self.session.sproc._registry[registered_object_name]
1967+
"""input_types = [
1968+
self.decode_data_type_expr(input_type)
1969+
for input_type in expr.stored_procedure.input_types.list
1970+
]
1971+
execute_as = expr.stored_procedure.execute_as
1972+
comment = expr.stored_procedure.comment.value
1973+
return_type = self.decode_data_type_expr(
1974+
expr.stored_procedure.return_type
1975+
)
1976+
name = self.decode_name_expr(expr.stored_procedure.name)
1977+
ret_sproc = sproc(
1978+
self.session.sproc._registry[registered_object_name],
1979+
name=name,
1980+
return_type=return_type,
1981+
input_types=input_types,
1982+
execute_as=execute_as,
1983+
comment=comment,
1984+
_registered_object_name=registered_object_name,
1985+
)
1986+
return ret_sproc"""
19831987

19841988
case "sp_flatten":
19851989
input = self.decode_expr(expr.sp_flatten.input)
@@ -2023,7 +2027,7 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
20232027
)
20242028

20252029
case "sp_sql":
2026-
params = [self.decode_expr(param) for parm in expr.sp_sql.params]
2030+
params = [self.decode_expr(param) for param in expr.sp_sql.params]
20272031
query = expr.sp_sql.query
20282032
return self.session.sql(query=query, params=params)
20292033

@@ -2045,8 +2049,8 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
20452049
case "sp_save_mode_error_if_exists":
20462050
mode = "error_if_exists"
20472051

2048-
case "sp_save_mode_truncate":
2049-
mode = "truncate"
2052+
case "sp_save_mode_ignore":
2053+
mode = "ignore"
20502054

20512055
case "_":
20522056
mode = None

0 commit comments

Comments
 (0)