@@ -322,10 +322,10 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover
322322
323323
324324def convert_groupby_apply_dataframe_result_to_standard_schema (
325- func_input_df : native_pd .DataFrame ,
326325 func_output_df : native_pd .DataFrame ,
327326 input_row_positions : native_pd .Series ,
328327 include_index_columns : bool ,
328+ is_transform : bool ,
329329) -> native_pd .DataFrame : # pragma: no cover: this function runs inside a UDTF, so coverage tools can't detect that we are testing it.
330330 """
331331 Take the result of applying the user-provided function to a dataframe, and convert it to a dataframe with known schema that we can output from a vUDTF.
@@ -338,6 +338,7 @@ def convert_groupby_apply_dataframe_result_to_standard_schema(
338338 func_input_df came from.
339339 include_index_columns: Whether to include the result's index columns in
340340 the output.
341+ is_transform: Whether the function is a transform, i.e. whether the index of its result equals the index of its input.
341342
342343 Returns:
343344 A 5-column dataframe that represents the function result per the
@@ -346,7 +347,6 @@ def convert_groupby_apply_dataframe_result_to_standard_schema(
346347 """
347348 result_rows = []
348349 result_index_names = func_output_df .index .names
349- is_transform = func_output_df .index .equals (func_input_df .index )
350350 for row_number , (index_label , row ) in enumerate (func_output_df .iterrows ()):
351351 output_row_number = input_row_positions .iloc [row_number ] if is_transform else - 1
352352 if include_index_columns :
@@ -460,9 +460,7 @@ def apply_groupby_func_to_df(
460460 args : tuple ,
461461 kwargs : dict ,
462462 force_list_like_to_series : bool = False ,
463- ) -> Tuple [
464- native_pd .Series , native_pd .DataFrame , native_pd .DataFrame , bool
465- ]: # pragma: no cover
463+ ) -> Tuple [native_pd .Series , native_pd .DataFrame , bool , bool ]: # pragma: no cover
466464 """
467465 Restore input dataframe received in udtf to original schema.
468466 Args:
@@ -479,9 +477,9 @@ def apply_groupby_func_to_df(
479477 Returns:
480478 A Tuple of
481479 1. rows positions
482- 2. restored input dataframe.
483- 3. Result of applying the function to input dataframe .
484- 4. Whether final result should include index columns .
480+ 2. Result of applying the function to input dataframe.
481+ 3. Whether final result should include index columns.
482+ 4. Whether the index of the result is the same as the index of the input.
485483 """
486484 # The first column is row position. Save it for later.
487485 col_offset = 0
@@ -568,17 +566,20 @@ def apply_groupby_func_to_df(
568566 # columns of `func_result_as_frame`. For SeriesGroupBy, we
569567 # do include the result's index in the result.
570568 include_index_columns = series_groupby
569+ is_transform = input_object .index .equals (func_result_as_frame .index )
571570 elif isinstance (func_result , native_pd .DataFrame ):
572571 include_index_columns = True
573572 func_result_as_frame = func_result
573+ is_transform = input_object .index .equals (func_result_as_frame .index )
574574 else :
575575 # At this point, we know the function result was not a DataFrame
576576 # or Series
577577 include_index_columns = False
578578 func_result_as_frame = native_pd .DataFrame (
579579 {MODIN_UNNAMED_SERIES_LABEL : [func_result ]}
580580 )
581- return row_positions , input_object , func_result_as_frame , include_index_columns
581+ is_transform = False
582+ return row_positions , func_result_as_frame , include_index_columns , is_transform
582583
583584
584585def create_udtf_for_groupby_transform (
@@ -657,7 +658,7 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
657658 A dataframe representing the result of applying the user-provided
658659 function to this group.
659660 """
660- row_positions , _ , func_result , _ = apply_groupby_func_to_df (
661+ row_positions , func_result , _ , _ = apply_groupby_func_to_df (
661662 df ,
662663 num_by ,
663664 index_column_names ,
@@ -879,9 +880,9 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
879880 """
880881 (
881882 row_positions ,
882- input_object ,
883883 func_result ,
884884 include_index_columns ,
885+ is_transform_func ,
885886 ) = apply_groupby_func_to_df (
886887 df ,
887888 num_by ,
@@ -894,7 +895,7 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
894895 force_list_like_to_series ,
895896 )
896897 return convert_groupby_apply_dataframe_result_to_standard_schema (
897- input_object , func_result , row_positions , include_index_columns
898+ func_result , row_positions , include_index_columns , is_transform_func
898899 )
899900
900901 input_types = [
0 commit comments