Skip to content

Commit 4d01c25

Browse files
authored
Don't remove double casts in cudf_polars (#20773)
closes pola-rs/polars#25564 This was added back in #16192, but it probably pre-dated some improvements in aggregation expressions. Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: #20773
1 parent 2b1da04 commit 4d01c25

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

python/cudf_polars/cudf_polars/dsl/translate.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,11 +1011,8 @@ def _(
10111011
# Push casts into literals so we can handle Cast(Literal(Null))
10121012
if isinstance(inner, expr.Literal):
10131013
return inner.astype(dtype)
1014-
elif isinstance(inner, expr.Cast):
1015-
# Translation of Len/Count-agg put in a cast, remove double
1016-
# casts if we have one.
1017-
(inner,) = inner.children
1018-
return expr.Cast(dtype, inner)
1014+
else:
1015+
return expr.Cast(dtype, inner)
10191016

10201017

10211018
@_translate_expr.register

python/cudf_polars/tests/expressions/test_casting.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,9 @@ def test_cast_unsupported(tests):
5252
assert_ir_translation_raises(
5353
df.select(pl.col("a").cast(totype)), NotImplementedError
5454
)
55+
56+
57+
def test_allow_double_cast():
58+
df = pl.LazyFrame({"c0": [1000]})
59+
query = df.select(pl.col("c0").cast(pl.Boolean).cast(pl.Int8))
60+
assert_gpu_result_equal(query)

0 commit comments

Comments
 (0)