Skip to content

Commit 7a22fd7

Browse files
Merge branch 'vbudati/SNOW-1794510-merge-decoder' into vbudati/SNOW-1794510-support-missing-apis
# Conflicts: # tests/ast/decoder.py
2 parents 856cc99 + 0391c5c commit 7a22fd7

File tree

1 file changed

+177
-8
lines changed

1 file changed

+177
-8
lines changed

tests/ast/decoder.py

Lines changed: 177 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010

1111
from pandas import DataFrame as PandasDataFrame
1212

13+
from snowflake.snowpark.table import WhenMatchedClause, WhenNotMatchedClause
1314
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SaveMode
1415
from snowflake.snowpark.window import WindowSpec, Window, WindowRelativePosition
1516
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
1617

1718
from google.protobuf.json_format import MessageToDict
1819

1920
from snowflake.snowpark.relational_grouped_dataframe import GroupingSets
20-
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row
21+
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row, Table
2122
import snowflake.snowpark.functions
2223
from snowflake.snowpark.functions import (
2324
udaf,
@@ -26,6 +27,8 @@
2627
when,
2728
sproc,
2829
call_table_function,
30+
when_matched,
31+
when_not_matched,
2932
)
3033
from snowflake.snowpark.types import (
3134
DataType,
@@ -432,9 +435,10 @@ def decode_data_type_expr(
432435
if isinstance(
433436
data_type_expr.sp_pandas_data_frame_type.col_types, Iterable
434437
):
435-
col_types = []
436-
for col_type in data_type_expr.sp_pandas_data_frame_type.col_types:
437-
col_types.append(self.decode_data_type_expr(col_type))
438+
col_types = [
439+
self.decode_data_type_expr(col_type)
440+
for col_type in data_type_expr.sp_pandas_data_frame_type.col_types
441+
]
438442
else:
439443
col_types = [data_type_expr.sp_pandas_data_frame_type.col_types]
440444
if isinstance(
@@ -566,6 +570,74 @@ def decode_join_type(self, join_type: proto.SpJoinType) -> str:
566570
"Unknown join type: %s" % join_type.WhichOneof("variant")
567571
)
568572

573+
def decode_matched_clause(
574+
self, matched_clause: proto.SpMatchedClause
575+
) -> Union[WhenMatchedClause, WhenNotMatchedClause]:
576+
"""
577+
Decode a matched clause expression to get the clause.
578+
579+
Parameters
580+
----------
581+
matched_clause : proto.SpMatchedClause
582+
The expression to decode.
583+
584+
Returns
585+
-------
586+
WhenMatchedClause or WhenNotMatchedClause
587+
The decoded clause.
588+
"""
589+
match matched_clause.WhichOneof("variant"):
590+
case "sp_merge_delete_when_matched_clause":
591+
condition = (
592+
self.decode_expr(
593+
matched_clause.sp_merge_delete_when_matched_clause.condition
594+
)
595+
if matched_clause.sp_merge_delete_when_matched_clause.HasField(
596+
"condition"
597+
)
598+
else None
599+
)
600+
return when_matched(condition).delete()
601+
case "sp_merge_insert_when_not_matched_clause":
602+
condition = (
603+
self.decode_expr(
604+
matched_clause.sp_merge_insert_when_not_matched_clause.condition
605+
)
606+
if matched_clause.sp_merge_insert_when_not_matched_clause.HasField(
607+
"condition"
608+
)
609+
else None
610+
)
611+
insert_keys = [
612+
self.decode_expr(key)
613+
for key in matched_clause.sp_merge_insert_when_not_matched_clause.insert_keys.list
614+
]
615+
insert_values = [
616+
self.decode_expr(value)
617+
for value in matched_clause.sp_merge_insert_when_not_matched_clause.insert_values.list
618+
]
619+
return when_not_matched(condition).insert(
620+
dict(zip(insert_keys, insert_values))
621+
)
622+
case "sp_merge_update_when_matched_clause":
623+
condition = (
624+
self.decode_expr(
625+
matched_clause.sp_merge_update_when_matched_clause.condition
626+
)
627+
if matched_clause.sp_merge_update_when_matched_clause.HasField(
628+
"condition"
629+
)
630+
else None
631+
)
632+
update_assignments = self.decode_dsl_map_expr(
633+
matched_clause.sp_merge_update_when_matched_clause.update_assignments.list
634+
)
635+
return when_matched(condition).update(update_assignments)
636+
case _:
637+
raise ValueError(
638+
"Unknown matched clause: %s" % matched_clause.WhichOneof("variant")
639+
)
640+
569641
def decode_pivot_value_expr(self, pivot_value_expr: proto.SpPivotValue) -> Any:
570642
"""
571643
Decode expr to get the pivot value.
@@ -1918,11 +1990,108 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
19181990
case "sp_table":
19191991
assert expr.sp_table.HasField("name")
19201992
table_name = self.decode_name_expr(expr.sp_table.name)
1921-
return self.session.table(
1922-
table_name,
1923-
is_temp_table_for_cleanup=expr.sp_table.is_temp_table_for_cleanup
1993+
is_temp_table = (
1994+
expr.sp_table.is_temp_table_for_cleanup
19241995
if hasattr(expr.sp_table, "is_temp_table_for_cleanup")
1925-
else False,
1996+
else False
1997+
)
1998+
match expr.sp_table.variant.WhichOneof("variant"):
1999+
case "sp_session_table":
2000+
return self.session.table(
2001+
table_name,
2002+
is_temp_table_for_cleanup=is_temp_table,
2003+
)
2004+
case "sp_table_init":
2005+
return Table(table_name, self.session, is_temp_table)
2006+
case _:
2007+
raise ValueError(
2008+
"Unknown table type: %s"
2009+
% expr.sp_table.WhichOneof("variant")
2010+
)
2011+
2012+
case "sp_table_delete":
2013+
table = self.symbol_table[expr.sp_table_delete.id.bitfield1][1]
2014+
block = expr.sp_table_delete.block
2015+
condition = (
2016+
self.decode_expr(expr.sp_table_delete.condition)
2017+
if expr.sp_table_delete.HasField("condition")
2018+
else None
2019+
)
2020+
source = (
2021+
self.decode_expr(expr.sp_table_delete.source)
2022+
if expr.sp_table_delete.HasField("source")
2023+
else None
2024+
)
2025+
statement_params = self.get_statement_params(
2026+
MessageToDict(expr.sp_table_delete)
2027+
)
2028+
return table.delete(
2029+
condition=condition,
2030+
source=source,
2031+
statement_params=statement_params,
2032+
block=block,
2033+
)
2034+
2035+
case "sp_table_drop_table":
2036+
table = self.symbol_table[expr.sp_table_drop_table.id.bitfield1][1]
2037+
return table.drop_table()
2038+
2039+
case "sp_table_merge":
2040+
table = self.symbol_table[expr.sp_table_merge.id.bitfield1][1]
2041+
block = expr.sp_table_merge.block
2042+
clauses = [
2043+
self.decode_matched_clause(clause)
2044+
for clause in expr.sp_table_merge.clauses
2045+
]
2046+
join_expr = self.decode_expr(expr.sp_table_merge.join_expr)
2047+
source = self.decode_expr(expr.sp_table_merge.source)
2048+
statement_params = self.get_statement_params(
2049+
MessageToDict(expr.sp_table_merge)
2050+
)
2051+
return table.merge(
2052+
source=source,
2053+
join_expr=join_expr,
2054+
clauses=clauses,
2055+
statement_params=statement_params,
2056+
block=block,
2057+
)
2058+
2059+
case "sp_table_sample":
2060+
df = self.decode_expr(expr.sp_table_sample.df)
2061+
num = expr.sp_table_sample.num.value
2062+
probability_fraction = expr.sp_table_sample.probability_fraction.value
2063+
sampling_method = expr.sp_table_sample.sampling_method.value
2064+
seed = expr.sp_table_sample.seed.value
2065+
return df.sample(
2066+
frac=probability_fraction,
2067+
n=num,
2068+
seed=seed,
2069+
sampling_method=sampling_method,
2070+
)
2071+
2072+
case "sp_table_update":
2073+
table = self.symbol_table[expr.sp_table_update.id.bitfield1][1]
2074+
assignments = self.decode_dsl_map_expr(expr.sp_table_update.assignments)
2075+
block = expr.sp_table_update.block
2076+
condition = (
2077+
self.decode_expr(expr.sp_table_update.condition)
2078+
if expr.sp_table_update.HasField("condition")
2079+
else None
2080+
)
2081+
source = (
2082+
self.decode_expr(expr.sp_table_update.source)
2083+
if expr.sp_table_update.HasField("source")
2084+
else None
2085+
)
2086+
statement_params = self.get_statement_params(
2087+
MessageToDict(expr.sp_table_update)
2088+
)
2089+
return table.update(
2090+
assignments,
2091+
condition,
2092+
source,
2093+
statement_params=statement_params,
2094+
block=block,
19262095
)
19272096

19282097
case "sp_to_snowpark_pandas":

0 commit comments

Comments
 (0)