Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions tests/ast/data/Table.merge.test
Original file line number Diff line number Diff line change
Expand Up @@ -165,25 +165,27 @@ body {
sp_dataframe_schema__struct {
v {
fields {
column_identifier {
name: "num"
}
data_type {
sp_integer_type: true
}
nullable: true
}
fields {
column_identifier {
name: "str"
list {
column_identifier {
name: "num"
}
data_type {
sp_integer_type: true
}
nullable: true
}
data_type {
sp_string_type {
length {
list {
column_identifier {
name: "str"
}
data_type {
sp_string_type {
length {
}
}
}
nullable: true
}
nullable: true
}
}
}
Expand Down
187 changes: 178 additions & 9 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

from pandas import DataFrame as PandasDataFrame

from snowflake.snowpark.window import WindowSpec, Window, WindowRelativePosition
from snowflake.snowpark.table import WhenMatchedClause, WhenNotMatchedClause
from snowflake.snowpark.window import Window, WindowRelativePosition
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

from google.protobuf.json_format import MessageToDict

from snowflake.snowpark.relational_grouped_dataframe import GroupingSets
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row, Table
import snowflake.snowpark.functions
from snowflake.snowpark.functions import (
udaf,
Expand All @@ -25,6 +26,8 @@
when,
sproc,
call_table_function,
when_matched,
when_not_matched,
)
from snowflake.snowpark.types import (
DataType,
Expand Down Expand Up @@ -431,9 +434,10 @@ def decode_data_type_expr(
if isinstance(
data_type_expr.sp_pandas_data_frame_type.col_types, Iterable
):
col_types = []
for col_type in data_type_expr.sp_pandas_data_frame_type.col_types:
col_types.append(self.decode_data_type_expr(col_type))
col_types = [
self.decode_data_type_expr(col_type)
for col_type in data_type_expr.sp_pandas_data_frame_type.col_types
]
else:
col_types = [data_type_expr.sp_pandas_data_frame_type.col_types]
if isinstance(
Expand Down Expand Up @@ -565,6 +569,74 @@ def decode_join_type(self, join_type: proto.SpJoinType) -> str:
"Unknown join type: %s" % join_type.WhichOneof("variant")
)

def decode_matched_clause(
self, matched_clause: proto.SpMatchedClause
) -> Union[WhenMatchedClause, WhenNotMatchedClause]:
"""
Decode a matched clause expression to get the clause.

Parameters
----------
matched_clause : proto.SpMatchedClause
The expression to decode.

Returns
-------
WhenMatchedClause or WhenNotMatchedClause
The decoded clause.
"""
match matched_clause.WhichOneof("variant"):
case "sp_merge_delete_when_matched_clause":
condition = (
self.decode_expr(
matched_clause.sp_merge_delete_when_matched_clause.condition
)
if matched_clause.sp_merge_delete_when_matched_clause.HasField(
"condition"
)
else None
)
return when_matched(condition).delete()
case "sp_merge_insert_when_not_matched_clause":
condition = (
self.decode_expr(
matched_clause.sp_merge_insert_when_not_matched_clause.condition
)
if matched_clause.sp_merge_insert_when_not_matched_clause.HasField(
"condition"
)
else None
)
insert_keys = [
self.decode_expr(key)
for key in matched_clause.sp_merge_insert_when_not_matched_clause.insert_keys.list
]
insert_values = [
self.decode_expr(value)
for value in matched_clause.sp_merge_insert_when_not_matched_clause.insert_values.list
]
return when_not_matched(condition).insert(
dict(zip(insert_keys, insert_values))
)
case "sp_merge_update_when_matched_clause":
condition = (
self.decode_expr(
matched_clause.sp_merge_update_when_matched_clause.condition
)
if matched_clause.sp_merge_update_when_matched_clause.HasField(
"condition"
)
else None
)
update_assignments = self.decode_dsl_map_expr(
matched_clause.sp_merge_update_when_matched_clause.update_assignments.list
)
return when_matched(condition).update(update_assignments)
case _:
raise ValueError(
"Unknown matched clause: %s" % matched_clause.WhichOneof("variant")
)

def decode_pivot_value_expr(self, pivot_value_expr: proto.SpPivotValue) -> Any:
"""
Decode expr to get the pivot value.
Expand Down Expand Up @@ -1884,11 +1956,108 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
case "sp_table":
assert expr.sp_table.HasField("name")
table_name = self.decode_name_expr(expr.sp_table.name)
return self.session.table(
table_name,
is_temp_table_for_cleanup=expr.sp_table.is_temp_table_for_cleanup
is_temp_table = (
expr.sp_table.is_temp_table_for_cleanup
if hasattr(expr.sp_table, "is_temp_table_for_cleanup")
else False,
else False
)
match expr.sp_table.variant.WhichOneof("variant"):
case "sp_session_table":
return self.session.table(
table_name,
is_temp_table_for_cleanup=is_temp_table,
)
case "sp_table_init":
return Table(table_name, self.session, is_temp_table)
case _:
raise ValueError(
"Unknown table type: %s"
% expr.sp_table.WhichOneof("variant")
)

case "sp_table_delete":
table = self.symbol_table[expr.sp_table_delete.id.bitfield1][1]
block = expr.sp_table_delete.block
condition = (
self.decode_expr(expr.sp_table_delete.condition)
if expr.sp_table_delete.HasField("condition")
else None
)
source = (
self.decode_expr(expr.sp_table_delete.source)
if expr.sp_table_delete.HasField("source")
else None
)
statement_params = self.get_statement_params(
MessageToDict(expr.sp_table_delete)
)
return table.delete(
condition=condition,
source=source,
statement_params=statement_params,
block=block,
)

case "sp_table_drop_table":
table = self.symbol_table[expr.sp_table_drop_table.id.bitfield1][1]
return table.drop_table()

case "sp_table_merge":
table = self.symbol_table[expr.sp_table_merge.id.bitfield1][1]
block = expr.sp_table_merge.block
clauses = [
self.decode_matched_clause(clause)
for clause in expr.sp_table_merge.clauses
]
join_expr = self.decode_expr(expr.sp_table_merge.join_expr)
source = self.decode_expr(expr.sp_table_merge.source)
statement_params = self.get_statement_params(
MessageToDict(expr.sp_table_merge)
)
return table.merge(
source=source,
join_expr=join_expr,
clauses=clauses,
statement_params=statement_params,
block=block,
)

case "sp_table_sample":
df = self.decode_expr(expr.sp_table_sample.df)
num = expr.sp_table_sample.num.value
probability_fraction = expr.sp_table_sample.probability_fraction.value
sampling_method = expr.sp_table_sample.sampling_method.value
seed = expr.sp_table_sample.seed.value
return df.sample(
frac=probability_fraction,
n=num,
seed=seed,
sampling_method=sampling_method,
)

case "sp_table_update":
table = self.symbol_table[expr.sp_table_update.id.bitfield1][1]
assignments = self.decode_dsl_map_expr(expr.sp_table_update.assignments)
block = expr.sp_table_update.block
condition = (
self.decode_expr(expr.sp_table_update.condition)
if expr.sp_table_update.HasField("condition")
else None
)
source = (
self.decode_expr(expr.sp_table_update.source)
if expr.sp_table_update.HasField("source")
else None
)
statement_params = self.get_statement_params(
MessageToDict(expr.sp_table_update)
)
return table.update(
assignments,
condition,
source,
statement_params=statement_params,
block=block,
)

case "sp_to_snowpark_pandas":
Expand Down
Loading