|
10 | 10 |
|
11 | 11 | from pandas import DataFrame as PandasDataFrame |
12 | 12 |
|
13 | | -from snowflake.snowpark.window import WindowSpec, Window, WindowRelativePosition |
| 13 | +from snowflake.snowpark.table import WhenMatchedClause, WhenNotMatchedClause |
| 14 | +from snowflake.snowpark.window import Window, WindowRelativePosition |
14 | 15 | import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto |
15 | 16 |
|
16 | 17 | from google.protobuf.json_format import MessageToDict |
17 | 18 |
|
18 | 19 | from snowflake.snowpark.relational_grouped_dataframe import GroupingSets |
19 | | -from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row |
| 20 | +from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row, Table |
20 | 21 | import snowflake.snowpark.functions |
21 | 22 | from snowflake.snowpark.functions import ( |
22 | 23 | udaf, |
|
25 | 26 | when, |
26 | 27 | sproc, |
27 | 28 | call_table_function, |
| 29 | + when_matched, |
| 30 | + when_not_matched, |
28 | 31 | ) |
29 | 32 | from snowflake.snowpark.types import ( |
30 | 33 | DataType, |
@@ -431,9 +434,10 @@ def decode_data_type_expr( |
431 | 434 | if isinstance( |
432 | 435 | data_type_expr.sp_pandas_data_frame_type.col_types, Iterable |
433 | 436 | ): |
434 | | - col_types = [] |
435 | | - for col_type in data_type_expr.sp_pandas_data_frame_type.col_types: |
436 | | - col_types.append(self.decode_data_type_expr(col_type)) |
| 437 | + col_types = [ |
| 438 | + self.decode_data_type_expr(col_type) |
| 439 | + for col_type in data_type_expr.sp_pandas_data_frame_type.col_types |
| 440 | + ] |
437 | 441 | else: |
438 | 442 | col_types = [data_type_expr.sp_pandas_data_frame_type.col_types] |
439 | 443 | if isinstance( |
@@ -565,6 +569,74 @@ def decode_join_type(self, join_type: proto.SpJoinType) -> str: |
565 | 569 | "Unknown join type: %s" % join_type.WhichOneof("variant") |
566 | 570 | ) |
567 | 571 |
|
| 572 | + def decode_matched_clause( |
| 573 | + self, matched_clause: proto.SpMatchedClause |
| 574 | + ) -> Union[WhenMatchedClause, WhenNotMatchedClause]: |
| 575 | + """ |
| 576 | + Decode a matched clause expression to get the clause. |
| 577 | +
|
| 578 | + Parameters |
| 579 | + ---------- |
| 580 | + matched_clause : proto.SpMatchedClause |
| 581 | + The expression to decode. |
| 582 | +
|
| 583 | + Returns |
| 584 | + ------- |
| 585 | + WhenMatchedClause or WhenNotMatchedClause |
| 586 | + The decoded clause. |
| 587 | + """ |
| 588 | + match matched_clause.WhichOneof("variant"): |
| 589 | + case "sp_merge_delete_when_matched_clause": |
| 590 | + condition = ( |
| 591 | + self.decode_expr( |
| 592 | + matched_clause.sp_merge_delete_when_matched_clause.condition |
| 593 | + ) |
| 594 | + if matched_clause.sp_merge_delete_when_matched_clause.HasField( |
| 595 | + "condition" |
| 596 | + ) |
| 597 | + else None |
| 598 | + ) |
| 599 | + return when_matched(condition).delete() |
| 600 | + case "sp_merge_insert_when_not_matched_clause": |
| 601 | + condition = ( |
| 602 | + self.decode_expr( |
| 603 | + matched_clause.sp_merge_insert_when_not_matched_clause.condition |
| 604 | + ) |
| 605 | + if matched_clause.sp_merge_insert_when_not_matched_clause.HasField( |
| 606 | + "condition" |
| 607 | + ) |
| 608 | + else None |
| 609 | + ) |
| 610 | + insert_keys = [ |
| 611 | + self.decode_expr(key) |
| 612 | + for key in matched_clause.sp_merge_insert_when_not_matched_clause.insert_keys.list |
| 613 | + ] |
| 614 | + insert_values = [ |
| 615 | + self.decode_expr(value) |
| 616 | + for value in matched_clause.sp_merge_insert_when_not_matched_clause.insert_values.list |
| 617 | + ] |
| 618 | + return when_not_matched(condition).insert( |
| 619 | + dict(zip(insert_keys, insert_values)) |
| 620 | + ) |
| 621 | + case "sp_merge_update_when_matched_clause": |
| 622 | + condition = ( |
| 623 | + self.decode_expr( |
| 624 | + matched_clause.sp_merge_update_when_matched_clause.condition |
| 625 | + ) |
| 626 | + if matched_clause.sp_merge_update_when_matched_clause.HasField( |
| 627 | + "condition" |
| 628 | + ) |
| 629 | + else None |
| 630 | + ) |
| 631 | + update_assignments = self.decode_dsl_map_expr( |
| 632 | + matched_clause.sp_merge_update_when_matched_clause.update_assignments.list |
| 633 | + ) |
| 634 | + return when_matched(condition).update(update_assignments) |
| 635 | + case _: |
| 636 | + raise ValueError( |
| 637 | + "Unknown matched clause: %s" % matched_clause.WhichOneof("variant") |
| 638 | + ) |
| 639 | + |
568 | 640 | def decode_pivot_value_expr(self, pivot_value_expr: proto.SpPivotValue) -> Any: |
569 | 641 | """ |
570 | 642 | Decode expr to get the pivot value. |
@@ -1884,11 +1956,108 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any: |
1884 | 1956 | case "sp_table": |
1885 | 1957 | assert expr.sp_table.HasField("name") |
1886 | 1958 | table_name = self.decode_name_expr(expr.sp_table.name) |
1887 | | - return self.session.table( |
1888 | | - table_name, |
1889 | | - is_temp_table_for_cleanup=expr.sp_table.is_temp_table_for_cleanup |
| 1959 | + is_temp_table = ( |
| 1960 | + expr.sp_table.is_temp_table_for_cleanup |
1890 | 1961 | if hasattr(expr.sp_table, "is_temp_table_for_cleanup") |
1891 | | - else False, |
| 1962 | + else False |
| 1963 | + ) |
| 1964 | + match expr.sp_table.variant.WhichOneof("variant"): |
| 1965 | + case "sp_session_table": |
| 1966 | + return self.session.table( |
| 1967 | + table_name, |
| 1968 | + is_temp_table_for_cleanup=is_temp_table, |
| 1969 | + ) |
| 1970 | + case "sp_table_init": |
| 1971 | + return Table(table_name, self.session, is_temp_table) |
| 1972 | + case _: |
| 1973 | + raise ValueError( |
| 1974 | + "Unknown table type: %s" |
| 1975 | + % expr.sp_table.WhichOneof("variant") |
| 1976 | + ) |
| 1977 | + |
| 1978 | + case "sp_table_delete": |
| 1979 | + table = self.symbol_table[expr.sp_table_delete.id.bitfield1][1] |
| 1980 | + block = expr.sp_table_delete.block |
| 1981 | + condition = ( |
| 1982 | + self.decode_expr(expr.sp_table_delete.condition) |
| 1983 | + if expr.sp_table_delete.HasField("condition") |
| 1984 | + else None |
| 1985 | + ) |
| 1986 | + source = ( |
| 1987 | + self.decode_expr(expr.sp_table_delete.source) |
| 1988 | + if expr.sp_table_delete.HasField("source") |
| 1989 | + else None |
| 1990 | + ) |
| 1991 | + statement_params = self.get_statement_params( |
| 1992 | + MessageToDict(expr.sp_table_delete) |
| 1993 | + ) |
| 1994 | + return table.delete( |
| 1995 | + condition=condition, |
| 1996 | + source=source, |
| 1997 | + statement_params=statement_params, |
| 1998 | + block=block, |
| 1999 | + ) |
| 2000 | + |
| 2001 | + case "sp_table_drop_table": |
| 2002 | + table = self.symbol_table[expr.sp_table_drop_table.id.bitfield1][1] |
| 2003 | + return table.drop_table() |
| 2004 | + |
| 2005 | + case "sp_table_merge": |
| 2006 | + table = self.symbol_table[expr.sp_table_merge.id.bitfield1][1] |
| 2007 | + block = expr.sp_table_merge.block |
| 2008 | + clauses = [ |
| 2009 | + self.decode_matched_clause(clause) |
| 2010 | + for clause in expr.sp_table_merge.clauses |
| 2011 | + ] |
| 2012 | + join_expr = self.decode_expr(expr.sp_table_merge.join_expr) |
| 2013 | + source = self.decode_expr(expr.sp_table_merge.source) |
| 2014 | + statement_params = self.get_statement_params( |
| 2015 | + MessageToDict(expr.sp_table_merge) |
| 2016 | + ) |
| 2017 | + return table.merge( |
| 2018 | + source=source, |
| 2019 | + join_expr=join_expr, |
| 2020 | + clauses=clauses, |
| 2021 | + statement_params=statement_params, |
| 2022 | + block=block, |
| 2023 | + ) |
| 2024 | + |
| 2025 | + case "sp_table_sample": |
| 2026 | + df = self.decode_expr(expr.sp_table_sample.df) |
| 2027 | + num = expr.sp_table_sample.num.value |
| 2028 | + probability_fraction = expr.sp_table_sample.probability_fraction.value |
| 2029 | + sampling_method = expr.sp_table_sample.sampling_method.value |
| 2030 | + seed = expr.sp_table_sample.seed.value |
| 2031 | + return df.sample( |
| 2032 | + frac=probability_fraction, |
| 2033 | + n=num, |
| 2034 | + seed=seed, |
| 2035 | + sampling_method=sampling_method, |
| 2036 | + ) |
| 2037 | + |
| 2038 | + case "sp_table_update": |
| 2039 | + table = self.symbol_table[expr.sp_table_update.id.bitfield1][1] |
| 2040 | + assignments = self.decode_dsl_map_expr(expr.sp_table_update.assignments) |
| 2041 | + block = expr.sp_table_update.block |
| 2042 | + condition = ( |
| 2043 | + self.decode_expr(expr.sp_table_update.condition) |
| 2044 | + if expr.sp_table_update.HasField("condition") |
| 2045 | + else None |
| 2046 | + ) |
| 2047 | + source = ( |
| 2048 | + self.decode_expr(expr.sp_table_update.source) |
| 2049 | + if expr.sp_table_update.HasField("source") |
| 2050 | + else None |
| 2051 | + ) |
| 2052 | + statement_params = self.get_statement_params( |
| 2053 | + MessageToDict(expr.sp_table_update) |
| 2054 | + ) |
| 2055 | + return table.update( |
| 2056 | + assignments, |
| 2057 | + condition, |
| 2058 | + source, |
| 2059 | + statement_params=statement_params, |
| 2060 | + block=block, |
1892 | 2061 | ) |
1893 | 2062 |
|
1894 | 2063 | case "sp_to_snowpark_pandas": |
|
0 commit comments