|
10 | 10 |
|
11 | 11 | from pandas import DataFrame as PandasDataFrame |
12 | 12 |
|
| 13 | +from snowflake.snowpark.table import WhenMatchedClause, WhenNotMatchedClause |
13 | 14 | from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SaveMode |
14 | 15 | from snowflake.snowpark.window import WindowSpec, Window, WindowRelativePosition |
15 | 16 | import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto |
16 | 17 |
|
17 | 18 | from google.protobuf.json_format import MessageToDict |
18 | 19 |
|
19 | 20 | from snowflake.snowpark.relational_grouped_dataframe import GroupingSets |
20 | | -from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row |
| 21 | +from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row, Table |
21 | 22 | import snowflake.snowpark.functions |
22 | 23 | from snowflake.snowpark.functions import ( |
23 | 24 | udaf, |
|
26 | 27 | when, |
27 | 28 | sproc, |
28 | 29 | call_table_function, |
| 30 | + when_matched, |
| 31 | + when_not_matched, |
29 | 32 | ) |
30 | 33 | from snowflake.snowpark.types import ( |
31 | 34 | DataType, |
@@ -432,9 +435,10 @@ def decode_data_type_expr( |
432 | 435 | if isinstance( |
433 | 436 | data_type_expr.sp_pandas_data_frame_type.col_types, Iterable |
434 | 437 | ): |
435 | | - col_types = [] |
436 | | - for col_type in data_type_expr.sp_pandas_data_frame_type.col_types: |
437 | | - col_types.append(self.decode_data_type_expr(col_type)) |
| 438 | + col_types = [ |
| 439 | + self.decode_data_type_expr(col_type) |
| 440 | + for col_type in data_type_expr.sp_pandas_data_frame_type.col_types |
| 441 | + ] |
438 | 442 | else: |
439 | 443 | col_types = [data_type_expr.sp_pandas_data_frame_type.col_types] |
440 | 444 | if isinstance( |
@@ -566,6 +570,74 @@ def decode_join_type(self, join_type: proto.SpJoinType) -> str: |
566 | 570 | "Unknown join type: %s" % join_type.WhichOneof("variant") |
567 | 571 | ) |
568 | 572 |
|
| 573 | + def decode_matched_clause( |
| 574 | + self, matched_clause: proto.SpMatchedClause |
| 575 | + ) -> Union[WhenMatchedClause, WhenNotMatchedClause]: |
| 576 | + """ |
| 577 | + Decode a matched clause expression to get the clause. |
| 578 | +
|
| 579 | + Parameters |
| 580 | + ---------- |
| 581 | + matched_clause : proto.SpMatchedClause |
| 582 | + The expression to decode. |
| 583 | +
|
| 584 | + Returns |
| 585 | + ------- |
| 586 | + WhenMatchedClause or WhenNotMatchedClause |
| 587 | + The decoded clause. |
| 588 | + """ |
| 589 | + match matched_clause.WhichOneof("variant"): |
| 590 | + case "sp_merge_delete_when_matched_clause": |
| 591 | + condition = ( |
| 592 | + self.decode_expr( |
| 593 | + matched_clause.sp_merge_delete_when_matched_clause.condition |
| 594 | + ) |
| 595 | + if matched_clause.sp_merge_delete_when_matched_clause.HasField( |
| 596 | + "condition" |
| 597 | + ) |
| 598 | + else None |
| 599 | + ) |
| 600 | + return when_matched(condition).delete() |
| 601 | + case "sp_merge_insert_when_not_matched_clause": |
| 602 | + condition = ( |
| 603 | + self.decode_expr( |
| 604 | + matched_clause.sp_merge_insert_when_not_matched_clause.condition |
| 605 | + ) |
| 606 | + if matched_clause.sp_merge_insert_when_not_matched_clause.HasField( |
| 607 | + "condition" |
| 608 | + ) |
| 609 | + else None |
| 610 | + ) |
| 611 | + insert_keys = [ |
| 612 | + self.decode_expr(key) |
| 613 | + for key in matched_clause.sp_merge_insert_when_not_matched_clause.insert_keys.list |
| 614 | + ] |
| 615 | + insert_values = [ |
| 616 | + self.decode_expr(value) |
| 617 | + for value in matched_clause.sp_merge_insert_when_not_matched_clause.insert_values.list |
| 618 | + ] |
| 619 | + return when_not_matched(condition).insert( |
| 620 | + dict(zip(insert_keys, insert_values)) |
| 621 | + ) |
| 622 | + case "sp_merge_update_when_matched_clause": |
| 623 | + condition = ( |
| 624 | + self.decode_expr( |
| 625 | + matched_clause.sp_merge_update_when_matched_clause.condition |
| 626 | + ) |
| 627 | + if matched_clause.sp_merge_update_when_matched_clause.HasField( |
| 628 | + "condition" |
| 629 | + ) |
| 630 | + else None |
| 631 | + ) |
| 632 | + update_assignments = self.decode_dsl_map_expr( |
| 633 | + matched_clause.sp_merge_update_when_matched_clause.update_assignments.list |
| 634 | + ) |
| 635 | + return when_matched(condition).update(update_assignments) |
| 636 | + case _: |
| 637 | + raise ValueError( |
| 638 | + "Unknown matched clause: %s" % matched_clause.WhichOneof("variant") |
| 639 | + ) |
| 640 | + |
569 | 641 | def decode_pivot_value_expr(self, pivot_value_expr: proto.SpPivotValue) -> Any: |
570 | 642 | """ |
571 | 643 | Decode expr to get the pivot value. |
@@ -1918,11 +1990,108 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: |
1918 | 1990 | case "sp_table": |
1919 | 1991 | assert expr.sp_table.HasField("name") |
1920 | 1992 | table_name = self.decode_name_expr(expr.sp_table.name) |
1921 | | - return self.session.table( |
1922 | | - table_name, |
1923 | | - is_temp_table_for_cleanup=expr.sp_table.is_temp_table_for_cleanup |
| 1993 | + is_temp_table = ( |
| 1994 | + expr.sp_table.is_temp_table_for_cleanup |
1924 | 1995 | if hasattr(expr.sp_table, "is_temp_table_for_cleanup") |
1925 | | - else False, |
| 1996 | + else False |
| 1997 | + ) |
| 1998 | + match expr.sp_table.variant.WhichOneof("variant"): |
| 1999 | + case "sp_session_table": |
| 2000 | + return self.session.table( |
| 2001 | + table_name, |
| 2002 | + is_temp_table_for_cleanup=is_temp_table, |
| 2003 | + ) |
| 2004 | + case "sp_table_init": |
| 2005 | + return Table(table_name, self.session, is_temp_table) |
| 2006 | + case _: |
| 2007 | + raise ValueError( |
| 2008 | + "Unknown table type: %s" |
| 2009 | + % expr.sp_table.WhichOneof("variant") |
| 2010 | + ) |
| 2011 | + |
| 2012 | + case "sp_table_delete": |
| 2013 | + table = self.symbol_table[expr.sp_table_delete.id.bitfield1][1] |
| 2014 | + block = expr.sp_table_delete.block |
| 2015 | + condition = ( |
| 2016 | + self.decode_expr(expr.sp_table_delete.condition) |
| 2017 | + if expr.sp_table_delete.HasField("condition") |
| 2018 | + else None |
| 2019 | + ) |
| 2020 | + source = ( |
| 2021 | + self.decode_expr(expr.sp_table_delete.source) |
| 2022 | + if expr.sp_table_delete.HasField("source") |
| 2023 | + else None |
| 2024 | + ) |
| 2025 | + statement_params = self.get_statement_params( |
| 2026 | + MessageToDict(expr.sp_table_delete) |
| 2027 | + ) |
| 2028 | + return table.delete( |
| 2029 | + condition=condition, |
| 2030 | + source=source, |
| 2031 | + statement_params=statement_params, |
| 2032 | + block=block, |
| 2033 | + ) |
| 2034 | + |
| 2035 | + case "sp_table_drop_table": |
| 2036 | + table = self.symbol_table[expr.sp_table_drop_table.id.bitfield1][1] |
| 2037 | + return table.drop_table() |
| 2038 | + |
| 2039 | + case "sp_table_merge": |
| 2040 | + table = self.symbol_table[expr.sp_table_merge.id.bitfield1][1] |
| 2041 | + block = expr.sp_table_merge.block |
| 2042 | + clauses = [ |
| 2043 | + self.decode_matched_clause(clause) |
| 2044 | + for clause in expr.sp_table_merge.clauses |
| 2045 | + ] |
| 2046 | + join_expr = self.decode_expr(expr.sp_table_merge.join_expr) |
| 2047 | + source = self.decode_expr(expr.sp_table_merge.source) |
| 2048 | + statement_params = self.get_statement_params( |
| 2049 | + MessageToDict(expr.sp_table_merge) |
| 2050 | + ) |
| 2051 | + return table.merge( |
| 2052 | + source=source, |
| 2053 | + join_expr=join_expr, |
| 2054 | + clauses=clauses, |
| 2055 | + statement_params=statement_params, |
| 2056 | + block=block, |
| 2057 | + ) |
| 2058 | + |
| 2059 | + case "sp_table_sample": |
| 2060 | + df = self.decode_expr(expr.sp_table_sample.df) |
| 2061 | + num = expr.sp_table_sample.num.value |
| 2062 | + probability_fraction = expr.sp_table_sample.probability_fraction.value |
| 2063 | + sampling_method = expr.sp_table_sample.sampling_method.value |
| 2064 | + seed = expr.sp_table_sample.seed.value |
| 2065 | + return df.sample( |
| 2066 | + frac=probability_fraction, |
| 2067 | + n=num, |
| 2068 | + seed=seed, |
| 2069 | + sampling_method=sampling_method, |
| 2070 | + ) |
| 2071 | + |
| 2072 | + case "sp_table_update": |
| 2073 | + table = self.symbol_table[expr.sp_table_update.id.bitfield1][1] |
| 2074 | + assignments = self.decode_dsl_map_expr(expr.sp_table_update.assignments) |
| 2075 | + block = expr.sp_table_update.block |
| 2076 | + condition = ( |
| 2077 | + self.decode_expr(expr.sp_table_update.condition) |
| 2078 | + if expr.sp_table_update.HasField("condition") |
| 2079 | + else None |
| 2080 | + ) |
| 2081 | + source = ( |
| 2082 | + self.decode_expr(expr.sp_table_update.source) |
| 2083 | + if expr.sp_table_update.HasField("source") |
| 2084 | + else None |
| 2085 | + ) |
| 2086 | + statement_params = self.get_statement_params( |
| 2087 | + MessageToDict(expr.sp_table_update) |
| 2088 | + ) |
| 2089 | + return table.update( |
| 2090 | + assignments, |
| 2091 | + condition, |
| 2092 | + source, |
| 2093 | + statement_params=statement_params, |
| 2094 | + block=block, |
1926 | 2095 | ) |
1927 | 2096 |
|
1928 | 2097 | case "sp_to_snowpark_pandas": |
|
0 commit comments