Skip to content

Commit 7dfbcd6

Browse files
committed
feat: cast expression
Without this, it is impossible to compare two columns with different bit-widths. Signed-off-by: Daniel King <[email protected]>
1 parent 35d12a3 commit 7dfbcd6

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

vortex-python/src/expr/mod.rs

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use pyo3::exceptions::PyValueError;
77
use pyo3::prelude::*;
88
use pyo3::types::*;
99
use vortex::dtype::{DType, Nullability, PType};
10+
use vortex::expr;
1011
use vortex::expr::{Binary, Expression, GetItem, Operator, VTableExt, and, lit, not};
1112

1213
use crate::dtype::PyDType;
@@ -23,6 +24,7 @@ pub(crate) fn init(py: Python, parent: &Bound<PyModule>) -> PyResult<()> {
2324
m.add_function(wrap_pyfunction!(literal, &m)?)?;
2425
m.add_function(wrap_pyfunction!(not_, &m)?)?;
2526
m.add_function(wrap_pyfunction!(and_, &m)?)?;
27+
m.add_function(wrap_pyfunction!(cast, &m)?)?;
2628
m.add_class::<PyExpr>()?;
2729

2830
Ok(())
@@ -215,7 +217,7 @@ pub fn literal<'py>(
215217
#[pyfunction]
216218
pub fn root() -> PyExpr {
217219
PyExpr {
218-
inner: vortex::expr::root(),
220+
inner: expr::root(),
219221
}
220222
}
221223

@@ -249,7 +251,7 @@ pub fn column<'py>(name: &Bound<'py, PyString>) -> PyResult<Bound<'py, PyExpr>>
249251
Bound::new(
250252
py,
251253
PyExpr {
252-
inner: vortex::expr::get_item(name, vortex::expr::root()),
254+
inner: expr::get_item(name, expr::root()),
253255
},
254256
)
255257
}
@@ -301,7 +303,10 @@ pub fn not_(child: PyExpr) -> PyResult<PyExpr> {
301303
///
302304
/// Parameters
303305
/// ----------
304-
/// child : :class:`Any`
306+
/// left : :class:`Expr`
307+
/// A boolean expression.
308+
///
309+
/// right : :class:`Expr`
305310
/// A boolean expression.
306311
///
307312
/// Returns
@@ -323,3 +328,41 @@ pub fn and_(left: PyExpr, right: PyExpr) -> PyResult<PyExpr> {
323328
inner: and(left.inner, right.inner),
324329
})
325330
}
331+
332+
/// Cast an expression to a compatible type.
333+
///
334+
/// Parameters
335+
/// ----------
336+
/// child : :class:`Expr`
337+
/// The expression to cast.
///
/// dtype : :class:`DType`
/// The type to cast ``child`` to, e.g. ``vx.int_(16)``.
338+
///
339+
/// Returns
340+
/// -------
341+
/// :class:`vortex.Expr`
342+
///
343+
/// Examples
344+
/// --------
345+
///
346+
/// Cast to a wider integer type:
347+
///
348+
/// ```python
349+
/// >>> import vortex.expr as ve
350+
/// >>> import vortex as vx
351+
/// >>> ve.cast(ve.literal(vx.int_(8), 1), vx.int_(16))
352+
/// <vortex.Expr object at ...>
353+
/// ```
354+
///
355+
/// Cast to a wider floating-point type:
356+
///
357+
/// ```python
358+
/// >>> import vortex.expr as ve
359+
/// >>> import vortex as vx
360+
/// >>> ve.cast(ve.literal(vx.float_(16), 3.145), vx.float_(64))
361+
/// <vortex.Expr object at ...>
362+
/// ```
363+
#[pyfunction]
364+
pub fn cast(child: PyExpr, dtype: PyDType) -> PyResult<PyExpr> {
365+
// Unwrap both Python-facing wrappers and hand the inner values to `expr::cast`.
Ok(PyExpr {
366+
inner: expr::cast(child.into_inner(), dtype.into_inner()),
367+
})
368+
}

vortex-python/test/test_scan.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
import vortex as vx
11+
import vortex.expr as ve
1112
from vortex.scan import RepeatedScan
1213

1314

@@ -20,14 +21,19 @@ def record(x: int, columns: list[str] | set[str] | None = None) -> dict[str, int
2021

2122

2223
@pytest.fixture(scope="session")
23-
def vxscan(tmpdir_factory) -> vx.RepeatedScan: # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
24+
def vxscan(vxfile: vx.VortexFile) -> vx.RepeatedScan:
25+
return vxfile.to_repeated_scan()
26+
27+
28+
@pytest.fixture(scope="session")
29+
def vxfile(tmpdir_factory) -> vx.VortexFile: # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
2430
fname = tmpdir_factory.mktemp("data") / "foo.vortex" # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
2531

2632
if not os.path.exists(fname): # pyright: ignore[reportUnknownArgumentType]
2733
a = pa.array([record(x) for x in range(1_000)])
2834
arr = vx.compress(vx.array(a))
2935
vx.io.write(arr, str(fname)) # pyright: ignore[reportUnknownArgumentType]
30-
return vx.open(str(fname)).to_repeated_scan() # pyright: ignore[reportUnknownArgumentType]
36+
return vx.open(str(fname)) # pyright: ignore[reportUnknownArgumentType]
3137

3238

3339
def test_execute(vxscan: RepeatedScan):
@@ -50,3 +56,11 @@ def test_scalar_at(vxscan: RepeatedScan):
5056
"bool": True,
5157
"float": math.sqrt(10),
5258
}
59+
60+
61+
# Scan the session-scoped Vortex file with a filter that casts the "index"
# column to a 16-bit integer and compares it to a 16-bit literal 1, then check
# that exactly the one matching record comes back.
def test_scan_with_cast(vxfile: vx.VortexFile):
62+
actual = vxfile.scan(expr=ve.cast(ve.column("index"), vx.int_(16)) == ve.literal(vx.int_(16), 1)).read_all()
63+
expected = pa.array(
64+
[{"index": 1, "string": pa.scalar("1", pa.string_view()), "bool": False, "float": math.sqrt(1)}]
65+
)
66+
# The comparison is made on the string renderings of the two Arrow arrays.
assert str(actual.to_arrow_array()) == str(expected)

0 commit comments

Comments
 (0)