Skip to content

Commit 0391c5c

Browse files
SNOW-1830580 Decoder logic for tables (#2907)
1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1830580 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development) 3. Please describe how your code solves the related issue. Added decoder logic for tests starting with "Tables.*", fixed Table.merge test issues.
1 parent 636c2b8 commit 0391c5c

File tree

2 files changed

+195
-24
lines changed

2 files changed

+195
-24
lines changed

tests/ast/data/Table.merge.test

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -165,25 +165,27 @@ body {
165165
sp_dataframe_schema__struct {
166166
v {
167167
fields {
168-
column_identifier {
169-
name: "num"
170-
}
171-
data_type {
172-
sp_integer_type: true
173-
}
174-
nullable: true
175-
}
176-
fields {
177-
column_identifier {
178-
name: "str"
168+
list {
169+
column_identifier {
170+
name: "num"
171+
}
172+
data_type {
173+
sp_integer_type: true
174+
}
175+
nullable: true
179176
}
180-
data_type {
181-
sp_string_type {
182-
length {
177+
list {
178+
column_identifier {
179+
name: "str"
180+
}
181+
data_type {
182+
sp_string_type {
183+
length {
184+
}
183185
}
184186
}
187+
nullable: true
185188
}
186-
nullable: true
187189
}
188190
}
189191
}

tests/ast/decoder.py

Lines changed: 178 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010

1111
from pandas import DataFrame as PandasDataFrame
1212

13-
from snowflake.snowpark.window import WindowSpec, Window, WindowRelativePosition
13+
from snowflake.snowpark.table import WhenMatchedClause, WhenNotMatchedClause
14+
from snowflake.snowpark.window import Window, WindowRelativePosition
1415
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
1516

1617
from google.protobuf.json_format import MessageToDict
1718

1819
from snowflake.snowpark.relational_grouped_dataframe import GroupingSets
19-
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row
20+
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row, Table
2021
import snowflake.snowpark.functions
2122
from snowflake.snowpark.functions import (
2223
udaf,
@@ -25,6 +26,8 @@
2526
when,
2627
sproc,
2728
call_table_function,
29+
when_matched,
30+
when_not_matched,
2831
)
2932
from snowflake.snowpark.types import (
3033
DataType,
@@ -431,9 +434,10 @@ def decode_data_type_expr(
431434
if isinstance(
432435
data_type_expr.sp_pandas_data_frame_type.col_types, Iterable
433436
):
434-
col_types = []
435-
for col_type in data_type_expr.sp_pandas_data_frame_type.col_types:
436-
col_types.append(self.decode_data_type_expr(col_type))
437+
col_types = [
438+
self.decode_data_type_expr(col_type)
439+
for col_type in data_type_expr.sp_pandas_data_frame_type.col_types
440+
]
437441
else:
438442
col_types = [data_type_expr.sp_pandas_data_frame_type.col_types]
439443
if isinstance(
@@ -565,6 +569,74 @@ def decode_join_type(self, join_type: proto.SpJoinType) -> str:
565569
"Unknown join type: %s" % join_type.WhichOneof("variant")
566570
)
567571

572+
def decode_matched_clause(
573+
self, matched_clause: proto.SpMatchedClause
574+
) -> Union[WhenMatchedClause, WhenNotMatchedClause]:
575+
"""
576+
Decode a matched clause expression to get the clause.
577+
578+
Parameters
579+
----------
580+
matched_clause : proto.SpMatchedClause
581+
The expression to decode.
582+
583+
Returns
584+
-------
585+
WhenMatchedClause or WhenNotMatchedClause
586+
The decoded clause.
587+
"""
588+
match matched_clause.WhichOneof("variant"):
589+
case "sp_merge_delete_when_matched_clause":
590+
condition = (
591+
self.decode_expr(
592+
matched_clause.sp_merge_delete_when_matched_clause.condition
593+
)
594+
if matched_clause.sp_merge_delete_when_matched_clause.HasField(
595+
"condition"
596+
)
597+
else None
598+
)
599+
return when_matched(condition).delete()
600+
case "sp_merge_insert_when_not_matched_clause":
601+
condition = (
602+
self.decode_expr(
603+
matched_clause.sp_merge_insert_when_not_matched_clause.condition
604+
)
605+
if matched_clause.sp_merge_insert_when_not_matched_clause.HasField(
606+
"condition"
607+
)
608+
else None
609+
)
610+
insert_keys = [
611+
self.decode_expr(key)
612+
for key in matched_clause.sp_merge_insert_when_not_matched_clause.insert_keys.list
613+
]
614+
insert_values = [
615+
self.decode_expr(value)
616+
for value in matched_clause.sp_merge_insert_when_not_matched_clause.insert_values.list
617+
]
618+
return when_not_matched(condition).insert(
619+
dict(zip(insert_keys, insert_values))
620+
)
621+
case "sp_merge_update_when_matched_clause":
622+
condition = (
623+
self.decode_expr(
624+
matched_clause.sp_merge_update_when_matched_clause.condition
625+
)
626+
if matched_clause.sp_merge_update_when_matched_clause.HasField(
627+
"condition"
628+
)
629+
else None
630+
)
631+
update_assignments = self.decode_dsl_map_expr(
632+
matched_clause.sp_merge_update_when_matched_clause.update_assignments.list
633+
)
634+
return when_matched(condition).update(update_assignments)
635+
case _:
636+
raise ValueError(
637+
"Unknown matched clause: %s" % matched_clause.WhichOneof("variant")
638+
)
639+
568640
def decode_pivot_value_expr(self, pivot_value_expr: proto.SpPivotValue) -> Any:
569641
"""
570642
Decode expr to get the pivot value.
@@ -1884,11 +1956,108 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
18841956
case "sp_table":
18851957
assert expr.sp_table.HasField("name")
18861958
table_name = self.decode_name_expr(expr.sp_table.name)
1887-
return self.session.table(
1888-
table_name,
1889-
is_temp_table_for_cleanup=expr.sp_table.is_temp_table_for_cleanup
1959+
is_temp_table = (
1960+
expr.sp_table.is_temp_table_for_cleanup
18901961
if hasattr(expr.sp_table, "is_temp_table_for_cleanup")
1891-
else False,
1962+
else False
1963+
)
1964+
match expr.sp_table.variant.WhichOneof("variant"):
1965+
case "sp_session_table":
1966+
return self.session.table(
1967+
table_name,
1968+
is_temp_table_for_cleanup=is_temp_table,
1969+
)
1970+
case "sp_table_init":
1971+
return Table(table_name, self.session, is_temp_table)
1972+
case _:
1973+
raise ValueError(
1974+
"Unknown table type: %s"
1975+
% expr.sp_table.WhichOneof("variant")
1976+
)
1977+
1978+
case "sp_table_delete":
1979+
table = self.symbol_table[expr.sp_table_delete.id.bitfield1][1]
1980+
block = expr.sp_table_delete.block
1981+
condition = (
1982+
self.decode_expr(expr.sp_table_delete.condition)
1983+
if expr.sp_table_delete.HasField("condition")
1984+
else None
1985+
)
1986+
source = (
1987+
self.decode_expr(expr.sp_table_delete.source)
1988+
if expr.sp_table_delete.HasField("source")
1989+
else None
1990+
)
1991+
statement_params = self.get_statement_params(
1992+
MessageToDict(expr.sp_table_delete)
1993+
)
1994+
return table.delete(
1995+
condition=condition,
1996+
source=source,
1997+
statement_params=statement_params,
1998+
block=block,
1999+
)
2000+
2001+
case "sp_table_drop_table":
2002+
table = self.symbol_table[expr.sp_table_drop_table.id.bitfield1][1]
2003+
return table.drop_table()
2004+
2005+
case "sp_table_merge":
2006+
table = self.symbol_table[expr.sp_table_merge.id.bitfield1][1]
2007+
block = expr.sp_table_merge.block
2008+
clauses = [
2009+
self.decode_matched_clause(clause)
2010+
for clause in expr.sp_table_merge.clauses
2011+
]
2012+
join_expr = self.decode_expr(expr.sp_table_merge.join_expr)
2013+
source = self.decode_expr(expr.sp_table_merge.source)
2014+
statement_params = self.get_statement_params(
2015+
MessageToDict(expr.sp_table_merge)
2016+
)
2017+
return table.merge(
2018+
source=source,
2019+
join_expr=join_expr,
2020+
clauses=clauses,
2021+
statement_params=statement_params,
2022+
block=block,
2023+
)
2024+
2025+
case "sp_table_sample":
2026+
df = self.decode_expr(expr.sp_table_sample.df)
2027+
num = expr.sp_table_sample.num.value
2028+
probability_fraction = expr.sp_table_sample.probability_fraction.value
2029+
sampling_method = expr.sp_table_sample.sampling_method.value
2030+
seed = expr.sp_table_sample.seed.value
2031+
return df.sample(
2032+
frac=probability_fraction,
2033+
n=num,
2034+
seed=seed,
2035+
sampling_method=sampling_method,
2036+
)
2037+
2038+
case "sp_table_update":
2039+
table = self.symbol_table[expr.sp_table_update.id.bitfield1][1]
2040+
assignments = self.decode_dsl_map_expr(expr.sp_table_update.assignments)
2041+
block = expr.sp_table_update.block
2042+
condition = (
2043+
self.decode_expr(expr.sp_table_update.condition)
2044+
if expr.sp_table_update.HasField("condition")
2045+
else None
2046+
)
2047+
source = (
2048+
self.decode_expr(expr.sp_table_update.source)
2049+
if expr.sp_table_update.HasField("source")
2050+
else None
2051+
)
2052+
statement_params = self.get_statement_params(
2053+
MessageToDict(expr.sp_table_update)
2054+
)
2055+
return table.update(
2056+
assignments,
2057+
condition,
2058+
source,
2059+
statement_params=statement_params,
2060+
block=block,
18922061
)
18932062

18942063
case "sp_to_snowpark_pandas":

0 commit comments

Comments
 (0)