@@ -591,7 +591,7 @@ def test_group_by_grouping_sets(session):
591591 .with_column ("medical_license" , lit (None ))
592592 .select ("medical_license" , "radio_license" , "count" )
593593 )
594- .sort (col ("count" ))
594+ .sort (col ("count" ), col ( "radio_license" ) )
595595 .collect ()
596596 )
597597
@@ -601,16 +601,16 @@ def test_group_by_grouping_sets(session):
601601 GroupingSets ([col ("medical_license" )], [col ("radio_license" )])
602602 )
603603 .agg (count (col ("*" )).as_ ("count" ))
604- .sort (col ("count" ))
604+ .sort (col ("count" ), col ( "radio_license" ) )
605605 )
606606
607607 Utils .check_answer (grouping_sets , result , sort = False )
608608
609609 Utils .check_answer (
610610 grouping_sets ,
611611 [
612- Row (None , "General" , 1 ),
613612 Row (None , "Amateur Extra" , 1 ),
613+ Row (None , "General" , 1 ),
614614 Row ("RN" , None , 2 ),
615615 Row (None , "Technician" , 2 ),
616616 Row (None , None , 3 ),
@@ -624,8 +624,8 @@ def test_group_by_grouping_sets(session):
624624 TestData .nurse (session )
625625 .group_by ("medical_license" , "radio_license" )
626626 .agg (count (col ("*" )).as_ ("count" ))
627- .sort ( col ( "count" ), col ( "medical_license" ), col ( "radio_license" ) )
628- .select ( "count" , "medical_license" , "radio_license" ),
627+ .select ( "count" , "medical_license" , "radio_license" )
628+ .sort ( col ( "count" ), col ( "medical_license" ), col ( "radio_license" ) ),
629629 [
630630 Row (1 , "LVN" , "General" ),
631631 Row (1 , "RN" , None ),
@@ -775,14 +775,20 @@ def test_rel_grouped_dataframe_median(session):
775775def test_builtin_functions (session ):
776776 df = session .create_dataframe ([(1 , 11 ), (2 , 12 ), (1 , 13 )]).to_df (["a" , "b" ])
777777
778- assert df .group_by ("a" ).builtin ("max" )(col ("a" ), col ("b" )).collect () == [
779- Row (1 , 1 , 13 ),
780- Row (2 , 2 , 12 ),
781- ]
782- assert df .group_by ("a" ).builtin ("max" )(col ("b" )).collect () == [
783- Row (1 , 13 ),
784- Row (2 , 12 ),
785- ]
778+ assert Utils .check_answer (
779+ df .group_by ("a" ).builtin ("max" )(col ("a" ), col ("b" )),
780+ [
781+ Row (1 , 1 , 13 ),
782+ Row (2 , 2 , 12 ),
783+ ],
784+ )
785+ assert Utils .check_answer (
786+ df .group_by ("a" ).builtin ("max" )(col ("b" )),
787+ [
788+ Row (1 , 13 ),
789+ Row (2 , 12 ),
790+ ],
791+ )
786792
787793
788794def test_non_empty_arg_functions (session ):
@@ -828,30 +834,35 @@ def test_non_empty_arg_functions(session):
828834
829835
830836def test_null_count (session ):
831- assert TestData .test_data3 (session ).group_by ("a" ).agg (
832- count (col ("b" ))
833- ).collect () == [
834- Row (1 , 0 ),
835- Row (2 , 1 ),
836- ]
837+ assert Utils .check_answer (
838+ TestData .test_data3 (session ).group_by ("a" ).agg (count (col ("b" ))),
839+ [Row (1 , 0 ), Row (2 , 1 )],
840+ )
837841
838- assert TestData .test_data3 (session ).group_by ("a" ).agg (
839- count (col ("a" ) + col ("b" ))
840- ).collect () == [Row (1 , 0 ), Row (2 , 1 )]
842+ assert Utils .check_answer (
843+ TestData .test_data3 (session ).group_by ("a" ).agg (count (col ("a" ) + col ("b" ))),
844+ [Row (1 , 0 ), Row (2 , 1 )],
845+ )
841846
842- assert TestData .test_data3 (session ).agg (
843- [
844- count (col ("a" )),
845- count (col ("b" )),
846- count (lit (1 )),
847- count_distinct (col ("a" )),
848- count_distinct (col ("b" )),
849- ]
850- ).collect () == [Row (2 , 1 , 2 , 2 , 1 )]
847+ assert Utils .check_answer (
848+ TestData .test_data3 (session ).agg (
849+ [
850+ count (col ("a" )),
851+ count (col ("b" )),
852+ count (lit (1 )),
853+ count_distinct (col ("a" )),
854+ count_distinct (col ("b" )),
855+ ]
856+ ),
857+ [Row (2 , 1 , 2 , 2 , 1 )],
858+ )
851859
852- assert TestData .test_data3 (session ).agg (
853- [count (col ("b" )), count_distinct (col ("b" )), sum_distinct (col ("b" ))]
854- ).collect () == [Row (1 , 1 , 2 )]
860+ assert Utils .check_answer (
861+ TestData .test_data3 (session ).agg (
862+ [count (col ("b" )), count_distinct (col ("b" )), sum_distinct (col ("b" ))]
863+ ),
864+ [Row (1 , 1 , 2 )],
865+ )
855866
856867
857868def test_distinct (session ):
@@ -1143,13 +1154,19 @@ def test_aggregate_function_in_groupby(session):
11431154
11441155
11451156def test_ints_in_agg_exprs_are_taken_as_groupby_ordinal (session ):
1146- assert TestData .test_data2 (session ).group_by (lit (3 ), lit (4 )).agg (
1147- [lit (6 ), lit (7 ), sum (col ("b" ))]
1148- ).collect () == [Row (3 , 4 , 6 , 7 , 9 )]
1157+ assert Utils .check_answer (
1158+ TestData .test_data2 (session )
1159+ .group_by (lit (3 ), lit (4 ))
1160+ .agg ([lit (6 ), lit (7 ), sum (col ("b" ))]),
1161+ [Row (3 , 4 , 6 , 7 , 9 )],
1162+ )
11491163
1150- assert TestData .test_data2 (session ).group_by ([lit (3 ), lit (4 )]).agg (
1151- [lit (6 ), col ("b" ), sum (col ("b" ))]
1152- ).collect () == [Row (3 , 4 , 6 , 1 , 3 ), Row (3 , 4 , 6 , 2 , 6 )]
1164+ assert Utils .check_answer (
1165+ TestData .test_data2 (session )
1166+ .group_by ([lit (3 ), lit (4 )])
1167+ .agg ([lit (6 ), col ("b" ), sum (col ("b" ))]),
1168+ [Row (3 , 4 , 6 , 1 , 3 ), Row (3 , 4 , 6 , 2 , 6 )],
1169+ )
11531170
11541171
11551172@pytest .mark .xfail (
0 commit comments