Skip to content

Commit d17d54f

Browse files
authored
feat: cast expression (#5389)
Without this, it is impossible to compare two columns with different bit-widths.

Signed-off-by: Daniel King <[email protected]>
1 parent 7c76e43 commit d17d54f

File tree

4 files changed

+65
-7
lines changed

4 files changed

+65
-7
lines changed

vortex-python/python/vortex/_lib/expr.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,4 @@ def root() -> Expr: ...
3333
def literal(dtype: DType, value: ScalarPyType) -> Expr: ...
3434
def not_(child: Expr) -> Expr: ...
3535
def and_(left: Expr, right: Expr) -> Expr: ...
36+
def cast(child: Expr, dtype: DType) -> Expr: ...

vortex-python/python/vortex/expr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
# SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44

5-
from ._lib.expr import Expr, and_, column, literal, not_, root # pyright: ignore[reportMissingModuleSource]
5+
from ._lib.expr import Expr, and_, cast, column, literal, not_, root # pyright: ignore[reportMissingModuleSource]
66

7-
__all__ = ["Expr", "column", "literal", "root", "not_", "and_"]
7+
__all__ = ["Expr", "column", "literal", "root", "not_", "and_", "cast"]

vortex-python/src/expr/mod.rs

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use pyo3::exceptions::PyValueError;
77
use pyo3::prelude::*;
88
use pyo3::types::*;
99
use vortex::dtype::{DType, Nullability, PType};
10+
use vortex::expr;
1011
use vortex::expr::{Binary, Expression, GetItem, Operator, VTableExt, and, lit, not};
1112

1213
use crate::arrays::PyArrayRef;
@@ -25,6 +26,7 @@ pub(crate) fn init(py: Python, parent: &Bound<PyModule>) -> PyResult<()> {
2526
m.add_function(wrap_pyfunction!(literal, &m)?)?;
2627
m.add_function(wrap_pyfunction!(not_, &m)?)?;
2728
m.add_function(wrap_pyfunction!(and_, &m)?)?;
29+
m.add_function(wrap_pyfunction!(cast, &m)?)?;
2830
m.add_class::<PyExpr>()?;
2931

3032
Ok(())
@@ -256,7 +258,7 @@ pub fn literal<'py>(
256258
#[pyfunction]
257259
pub fn root() -> PyExpr {
258260
PyExpr {
259-
inner: vortex::expr::root(),
261+
inner: expr::root(),
260262
}
261263
}
262264

@@ -290,7 +292,7 @@ pub fn column<'py>(name: &Bound<'py, PyString>) -> PyResult<Bound<'py, PyExpr>>
290292
Bound::new(
291293
py,
292294
PyExpr {
293-
inner: vortex::expr::get_item(name, vortex::expr::root()),
295+
inner: expr::get_item(name, expr::root()),
294296
},
295297
)
296298
}
@@ -342,7 +344,10 @@ pub fn not_(child: PyExpr) -> PyResult<PyExpr> {
342344
///
343345
/// Parameters
344346
/// ----------
345-
/// child : :class:`Any`
347+
/// left : :class:`Expr`
348+
/// A boolean expression.
349+
///
350+
/// right : :class:`Expr`
346351
/// A boolean expression.
347352
///
348353
/// Returns
@@ -364,3 +369,41 @@ pub fn and_(left: PyExpr, right: PyExpr) -> PyResult<PyExpr> {
364369
inner: and(left.inner, right.inner),
365370
})
366371
}
372+
373+
/// Cast an expression to a compatible type.
374+
///
375+
/// Parameters
376+
/// ----------
377+
/// child : :class:`Expr`
378+
/// The expression to cast.
///
/// dtype : :class:`DType`
/// The data type to cast the expression to.
379+
///
380+
/// Returns
381+
/// -------
382+
/// :class:`vortex.Expr`
383+
///
384+
/// Examples
385+
/// --------
386+
///
387+
/// Cast to a wider integer type:
388+
///
389+
/// ```python
390+
/// >>> import vortex.expr as ve
391+
/// >>> import vortex as vx
392+
/// >>> ve.cast(ve.literal(vx.int_(8), 1), vx.int_(16))
393+
/// <vortex.Expr object at ...>
394+
/// ```
395+
///
396+
/// Cast to a wider floating-point type:
397+
///
398+
/// ```python
399+
/// >>> import vortex.expr as ve
400+
/// >>> import vortex as vx
401+
/// >>> ve.cast(ve.literal(vx.float_(16), 3.145), vx.float_(64))
402+
/// <vortex.Expr object at ...>
403+
/// ```
404+
#[pyfunction]
405+
pub fn cast(child: PyExpr, dtype: PyDType) -> PyResult<PyExpr> {
406+
Ok(PyExpr {
407+
inner: expr::cast(child.into_inner(), dtype.into_inner()),
408+
})
409+
}

vortex-python/test/test_scan.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
import vortex as vx
11+
import vortex.expr as ve
1112
from vortex.scan import RepeatedScan
1213

1314

@@ -20,14 +21,19 @@ def record(x: int, columns: list[str] | set[str] | None = None) -> dict[str, int
2021

2122

2223
@pytest.fixture(scope="session")
23-
def vxscan(tmpdir_factory) -> vx.RepeatedScan: # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
24+
def vxscan(vxfile: vx.VortexFile) -> vx.RepeatedScan:
25+
return vxfile.to_repeated_scan()
26+
27+
28+
@pytest.fixture(scope="session")
29+
def vxfile(tmpdir_factory) -> vx.VortexFile: # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
2430
fname = tmpdir_factory.mktemp("data") / "foo.vortex" # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
2531

2632
if not os.path.exists(fname): # pyright: ignore[reportUnknownArgumentType]
2733
a = pa.array([record(x) for x in range(1_000)])
2834
arr = vx.compress(vx.array(a))
2935
vx.io.write(arr, str(fname)) # pyright: ignore[reportUnknownArgumentType]
30-
return vx.open(str(fname)).to_repeated_scan() # pyright: ignore[reportUnknownArgumentType]
36+
return vx.open(str(fname)) # pyright: ignore[reportUnknownArgumentType]
3137

3238

3339
def test_execute(vxscan: RepeatedScan):
@@ -50,3 +56,11 @@ def test_scalar_at(vxscan: RepeatedScan):
5056
"bool": True,
5157
"float": math.sqrt(10),
5258
}
59+
60+
61+
def test_scan_with_cast(vxfile: vx.VortexFile):
62+
actual = vxfile.scan(expr=ve.cast(ve.column("index"), vx.int_(16)) == ve.literal(vx.int_(16), 1)).read_all()
63+
expected = pa.array(
64+
[{"index": 1, "string": pa.scalar("1", pa.string_view()), "bool": False, "float": math.sqrt(1)}]
65+
)
66+
assert str(actual.to_arrow_array()) == str(expected)

0 commit comments

Comments
 (0)