Skip to content

Commit b3b0a55

Browse files
Got UDTFs to work, changed the UDTF test, got RelationalGroupedDataFrame to work, added a way to reset the recorded callables
1 parent a9baad2 commit b3b0a55

File tree

8 files changed

+540
-327
lines changed

8 files changed

+540
-327
lines changed

src/snowflake/snowpark/_internal/ast/batch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ def reset_id_gen(self) -> None:
7272
"""Resets the ID generator."""
7373
self._id_gen = itertools.count(start=1)
7474

75+
def reset_callables(self) -> None:
76+
"""Resets the callables."""
77+
self._callables = {}
78+
7579
def assign(self, symbol: Optional[str] = None) -> proto.Assign:
7680
"""
7781
Creates a new assignment statement.

src/snowflake/snowpark/_internal/udf_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,8 @@ def extract_return_input_types(
595595
return True, False, res_return_type, res_input_types
596596
elif isinstance(
597597
return_type_from_type_hints, PandasDataFrameType
598+
) and not isinstance(
599+
return_type, PandasDataFrameType
598600
): # vectorized UDTF
599601
return_type = PandasDataFrameType(
600602
[x.datatype for x in return_type], [x.name for x in return_type]

src/snowflake/snowpark/relational_grouped_dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame:
423423
_apply_in_pandas_udtf = self._dataframe._session.udtf.register(
424424
_ApplyInPandas,
425425
output_schema=output_schema,
426-
_emit_ast=_emit_ast,
426+
_emit_ast=False,
427427
**kwargs,
428428
)
429429
partition_by = [Column(expr, _emit_ast=False) for expr in self._grouping_exprs]

src/snowflake/snowpark/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ def json_value(self) -> Dict[str, Any]:
784784
def _fill_ast(self, ast: proto.SpDataType) -> None:
785785
ast.sp_struct_type.structured = self.structured
786786
for field in self.fields:
787-
field._fill_ast(ast.sp_struct_type.fields.add())
787+
field._fill_ast(ast.sp_struct_type.fields.list.add())
788788

789789

790790
class VariantType(DataType):

0 commit comments

Comments
 (0)