Skip to content

Commit c49ed6a

Browse files
committed
fix flaky tests that fail due to sort order
1 parent: 7ceab4a · commit: c49ed6a

File tree

1 file changed

+57
-40
lines changed

1 file changed

+57
-40
lines changed

tests/integ/scala/test_dataframe_aggregate_suite.py

Lines changed: 57 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -591,7 +591,7 @@ def test_group_by_grouping_sets(session):
591591
.with_column("medical_license", lit(None))
592592
.select("medical_license", "radio_license", "count")
593593
)
594-
.sort(col("count"))
594+
.sort(col("count"), col("radio_license"))
595595
.collect()
596596
)
597597

@@ -601,16 +601,16 @@ def test_group_by_grouping_sets(session):
601601
GroupingSets([col("medical_license")], [col("radio_license")])
602602
)
603603
.agg(count(col("*")).as_("count"))
604-
.sort(col("count"))
604+
.sort(col("count"), col("radio_license"))
605605
)
606606

607607
Utils.check_answer(grouping_sets, result, sort=False)
608608

609609
Utils.check_answer(
610610
grouping_sets,
611611
[
612-
Row(None, "General", 1),
613612
Row(None, "Amateur Extra", 1),
613+
Row(None, "General", 1),
614614
Row("RN", None, 2),
615615
Row(None, "Technician", 2),
616616
Row(None, None, 3),
@@ -624,8 +624,8 @@ def test_group_by_grouping_sets(session):
624624
TestData.nurse(session)
625625
.group_by("medical_license", "radio_license")
626626
.agg(count(col("*")).as_("count"))
627-
.sort(col("count"), col("medical_license"), col("radio_license"))
628-
.select("count", "medical_license", "radio_license"),
627+
.select("count", "medical_license", "radio_license")
628+
.sort(col("count"), col("medical_license"), col("radio_license")),
629629
[
630630
Row(1, "LVN", "General"),
631631
Row(1, "RN", None),
@@ -775,14 +775,20 @@ def test_rel_grouped_dataframe_median(session):
775775
def test_builtin_functions(session):
776776
df = session.create_dataframe([(1, 11), (2, 12), (1, 13)]).to_df(["a", "b"])
777777

778-
assert df.group_by("a").builtin("max")(col("a"), col("b")).collect() == [
779-
Row(1, 1, 13),
780-
Row(2, 2, 12),
781-
]
782-
assert df.group_by("a").builtin("max")(col("b")).collect() == [
783-
Row(1, 13),
784-
Row(2, 12),
785-
]
778+
assert Utils.check_answer(
779+
df.group_by("a").builtin("max")(col("a"), col("b")),
780+
[
781+
Row(1, 1, 13),
782+
Row(2, 2, 12),
783+
],
784+
)
785+
assert Utils.check_answer(
786+
df.group_by("a").builtin("max")(col("b")),
787+
[
788+
Row(1, 13),
789+
Row(2, 12),
790+
],
791+
)
786792

787793

788794
def test_non_empty_arg_functions(session):
@@ -828,30 +834,35 @@ def test_non_empty_arg_functions(session):
828834

829835

830836
def test_null_count(session):
831-
assert TestData.test_data3(session).group_by("a").agg(
832-
count(col("b"))
833-
).collect() == [
834-
Row(1, 0),
835-
Row(2, 1),
836-
]
837+
assert Utils.check_answer(
838+
TestData.test_data3(session).group_by("a").agg(count(col("b"))),
839+
[Row(1, 0), Row(2, 1)],
840+
)
837841

838-
assert TestData.test_data3(session).group_by("a").agg(
839-
count(col("a") + col("b"))
840-
).collect() == [Row(1, 0), Row(2, 1)]
842+
assert Utils.check_answer(
843+
TestData.test_data3(session).group_by("a").agg(count(col("a") + col("b"))),
844+
[Row(1, 0), Row(2, 1)],
845+
)
841846

842-
assert TestData.test_data3(session).agg(
843-
[
844-
count(col("a")),
845-
count(col("b")),
846-
count(lit(1)),
847-
count_distinct(col("a")),
848-
count_distinct(col("b")),
849-
]
850-
).collect() == [Row(2, 1, 2, 2, 1)]
847+
assert Utils.check_answer(
848+
TestData.test_data3(session).agg(
849+
[
850+
count(col("a")),
851+
count(col("b")),
852+
count(lit(1)),
853+
count_distinct(col("a")),
854+
count_distinct(col("b")),
855+
]
856+
),
857+
[Row(2, 1, 2, 2, 1)],
858+
)
851859

852-
assert TestData.test_data3(session).agg(
853-
[count(col("b")), count_distinct(col("b")), sum_distinct(col("b"))]
854-
).collect() == [Row(1, 1, 2)]
860+
assert Utils.check_answer(
861+
TestData.test_data3(session).agg(
862+
[count(col("b")), count_distinct(col("b")), sum_distinct(col("b"))]
863+
),
864+
[Row(1, 1, 2)],
865+
)
855866

856867

857868
def test_distinct(session):
@@ -1143,13 +1154,19 @@ def test_aggregate_function_in_groupby(session):
11431154

11441155

11451156
def test_ints_in_agg_exprs_are_taken_as_groupby_ordinal(session):
1146-
assert TestData.test_data2(session).group_by(lit(3), lit(4)).agg(
1147-
[lit(6), lit(7), sum(col("b"))]
1148-
).collect() == [Row(3, 4, 6, 7, 9)]
1157+
assert Utils.check_answer(
1158+
TestData.test_data2(session)
1159+
.group_by(lit(3), lit(4))
1160+
.agg([lit(6), lit(7), sum(col("b"))]),
1161+
[Row(3, 4, 6, 7, 9)],
1162+
)
11491163

1150-
assert TestData.test_data2(session).group_by([lit(3), lit(4)]).agg(
1151-
[lit(6), col("b"), sum(col("b"))]
1152-
).collect() == [Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6)]
1164+
assert Utils.check_answer(
1165+
TestData.test_data2(session)
1166+
.group_by([lit(3), lit(4)])
1167+
.agg([lit(6), col("b"), sum(col("b"))]),
1168+
[Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6)],
1169+
)
11531170

11541171

11551172
@pytest.mark.xfail(

0 commit comments

Comments
 (0)