Skip to content

Commit 1381dc3

Browse files
SNOW-1830556 Add decoder logic for Session.write_pandas (#2904)
1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1830556 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development) 3. Please describe how your code solves the related issue. Added decoder logic for Session.write_pandas, fixed bug in decoder for sp_table, and verified shadowed_local_name.test works.
1 parent 4a5e57f commit 1381dc3

File tree

1 file changed

+60
-9
lines changed

1 file changed

+60
-9
lines changed

tests/ast/decoder.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from datetime import date, datetime, time, timedelta, timezone
99
from decimal import Decimal
1010

11-
from snowflake.snowpark.window import WindowSpec, Window, WindowRelativePosition
11+
from pandas import DataFrame as PandasDataFrame
1212

13+
from snowflake.snowpark.window import WindowSpec, Window, WindowRelativePosition
1314
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
1415

1516
from google.protobuf.json_format import MessageToDict
@@ -168,7 +169,7 @@ def convert_name_to_list(self, name: any) -> List:
168169
return [name]
169170
return [qualified_name for qualified_name in name]
170171

171-
def decode_name_expr(self, table_name: proto.SpName) -> str:
172+
def decode_name_expr(self, table_name: proto.SpName) -> Union[str, List]:
172173
"""
173174
Decode a table name expression to get the table name.
174175
@@ -185,7 +186,7 @@ def decode_name_expr(self, table_name: proto.SpName) -> str:
185186
if table_name.name.HasField("sp_name_flat"):
186187
return table_name.name.sp_name_flat.name
187188
elif table_name.name.HasField("sp_name_structured"):
188-
return table_name.name.sp_name_structured.name
189+
return [name for name in table_name.name.sp_name_structured.name]
189190
else:
190191
raise ValueError("Table name not found in proto.SpTableName")
191192

@@ -236,7 +237,9 @@ def decode_fn_ref_expr(self, fn_ref_expr: proto.FnRefExpr) -> str:
236237
% fn_ref_expr.WhichOneof("variant")
237238
)
238239

239-
def decode_dataframe_data_expr(self, df_data_expr: proto.SpDataframeData) -> List:
240+
def decode_dataframe_data_expr(
241+
self, df_data_expr: proto.SpDataframeData
242+
) -> Union[List, PandasDataFrame]:
240243
"""
241244
Decode a dataframe data expression to get the underlying data.
242245
@@ -247,7 +250,7 @@ def decode_dataframe_data_expr(self, df_data_expr: proto.SpDataframeData) -> Lis
247250
248251
Returns
249252
-------
250-
List
253+
List or pandas.DataFrame
251254
The decoded data.
252255
"""
253256
match df_data_expr.WhichOneof("sealed_value"):
@@ -265,14 +268,15 @@ def decode_dataframe_data_expr(self, df_data_expr: proto.SpDataframeData) -> Lis
265268
]
266269
else:
267270
return []
268-
# case "sp_dataframe_data__pandas":
269-
# pass
271+
case "sp_dataframe_data__pandas":
272+
# We don't know what pandas DataFrame was passed in, return an empty one.
273+
return PandasDataFrame()
270274
# case "sp_dataframe_data__tuple":
271275
# pass
272276
case _:
273277
raise ValueError(
274278
"Unknown dataframe data type: %s"
275-
% df_data_expr.WhichOneof("variant")
279+
% df_data_expr.WhichOneof("sealed_value")
276280
)
277281

278282
def decode_dataframe_schema_expr(
@@ -1750,7 +1754,12 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
17501754
case "sp_table":
17511755
assert expr.sp_table.HasField("name")
17521756
table_name = self.decode_name_expr(expr.sp_table.name)
1753-
return self.session.table(table_name)
1757+
return self.session.table(
1758+
table_name,
1759+
is_temp_table_for_cleanup=expr.sp_table.is_temp_table_for_cleanup
1760+
if hasattr(expr.sp_table, "is_temp_table_for_cleanup")
1761+
else False,
1762+
)
17541763

17551764
case "sp_to_snowpark_pandas":
17561765
df = self.decode_expr(expr.sp_to_snowpark_pandas.df)
@@ -2118,6 +2127,48 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
21182127
columns, rowcount=row_count, timelimit=time_limit_seconds
21192128
)
21202129

2130+
case "sp_write_pandas":
2131+
df = self.decode_dataframe_data_expr(expr.sp_write_pandas.df)
2132+
table_name = self.decode_name_expr(expr.sp_write_pandas.table_name)
2133+
if isinstance(table_name, str):
2134+
database, schema = None, None
2135+
else:
2136+
database, schema, table_name = (
2137+
table_name[0],
2138+
table_name[1],
2139+
table_name[2],
2140+
)
2141+
chunk_size = (
2142+
expr.sp_write_pandas.chunk_size.value
2143+
if expr.sp_write_pandas.HasField("chunk_size")
2144+
else None
2145+
)
2146+
compression = expr.sp_write_pandas.compression
2147+
on_error = expr.sp_write_pandas.on_error
2148+
parallel = expr.sp_write_pandas.parallel
2149+
quote_identifiers = expr.sp_write_pandas.quote_identifiers
2150+
auto_create_table = expr.sp_write_pandas.auto_create_table
2151+
create_temp_table = expr.sp_write_pandas.create_temp_table
2152+
overwrite = expr.sp_write_pandas.overwrite
2153+
table_type = expr.sp_write_pandas.table_type
2154+
kwargs = self.decode_dsl_map_expr(expr.sp_write_pandas.kwargs)
2155+
return self.session.write_pandas(
2156+
df,
2157+
table_name,
2158+
database=database,
2159+
schema=schema,
2160+
chunk_size=chunk_size,
2161+
compression=compression,
2162+
on_error=on_error,
2163+
parallel=parallel,
2164+
quote_identifiers=quote_identifiers,
2165+
auto_create_table=auto_create_table,
2166+
create_temp_table=create_temp_table,
2167+
overwrite=overwrite,
2168+
table_type=table_type,
2169+
**kwargs,
2170+
)
2171+
21212172
case "sp_row":
21222173
names = [name for name in expr.sp_row.names.list]
21232174
values = [self.decode_expr(value) for value in expr.sp_row.vs]

0 commit comments

Comments
 (0)