Skip to content

Commit 4982ebe

Browse files
Add vortex-expr GetItem, and update Select (#1836)
1 parent efd61cf commit 4982ebe

File tree

10 files changed

+255
-68
lines changed

10 files changed

+255
-68
lines changed

pyvortex/src/expr.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
use std::sync::Arc;
2+
13
use pyo3::exceptions::PyValueError;
24
use pyo3::prelude::*;
35
use pyo3::types::*;
46
use vortex::dtype::half::f16;
5-
use vortex::dtype::{DType, Nullability, PType};
6-
use vortex::expr::{col, lit, BinaryExpr, ExprRef, Operator};
7+
use vortex::dtype::{DType, Field, Nullability, PType};
8+
use vortex::expr::{col, lit, BinaryExpr, ExprRef, GetItem, Operator};
79
use vortex::scalar::Scalar;
810

911
use crate::dtype::PyDType;
@@ -115,6 +117,7 @@ use crate::dtype::PyDType;
115117
/// "Angela"
116118
/// ]
117119
#[pyclass(name = "Expr", module = "vortex")]
120+
#[derive(Clone)]
118121
pub struct PyExpr {
119122
inner: ExprRef,
120123
}
@@ -221,6 +224,10 @@ impl PyExpr {
221224
) -> PyResult<Bound<'py, PyExpr>> {
222225
py_binary_opeartor(self_, Operator::Or, coerce_expr(right)?)
223226
}
227+
228+
fn __getitem__(self_: PyRef<'_, Self>, field: PyObject) -> PyResult<PyExpr> {
229+
get_item(self_.py(), field, self_.clone())
230+
}
224231
}
225232

226233
/// A named column.
@@ -303,3 +310,20 @@ pub fn scalar_helper(dtype: DType, value: &Bound<'_, PyAny>) -> PyResult<Scalar>
303310
DType::Extension(..) => todo!(),
304311
}
305312
}
313+
314+
pub fn get_item(py: Python, field: PyObject, child: PyExpr) -> PyResult<PyExpr> {
315+
let field = if let Ok(value) = field.downcast_bound::<PyLong>(py) {
316+
Field::Index(value.extract()?)
317+
} else if let Ok(value) = field.downcast_bound::<PyString>(py) {
318+
Field::Name(Arc::from(value.extract::<String>()?.as_str()))
319+
} else {
320+
return Err(PyValueError::new_err(format!(
321+
"expected int, or str but found: {}",
322+
field
323+
)));
324+
};
325+
326+
Ok(PyExpr {
327+
inner: GetItem::new_expr(field, child.inner),
328+
})
329+
}

vortex-expr/src/get_item.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use std::any::Any;
2+
use std::fmt::{Debug, Display, Formatter};
3+
use std::sync::Arc;
4+
5+
use vortex_array::ArrayData;
6+
use vortex_dtype::Field;
7+
use vortex_error::{vortex_err, VortexResult};
8+
9+
use crate::{ExprRef, VortexExpr};
10+
11+
#[derive(Debug, Clone, Eq)]
12+
pub struct GetItem {
13+
field: Field,
14+
child: ExprRef,
15+
}
16+
17+
impl GetItem {
18+
pub fn new_expr(field: impl Into<Field>, child: ExprRef) -> ExprRef {
19+
Arc::new(Self {
20+
field: field.into(),
21+
child,
22+
})
23+
}
24+
25+
pub fn field(&self) -> &Field {
26+
&self.field
27+
}
28+
29+
pub fn child(&self) -> &ExprRef {
30+
&self.child
31+
}
32+
}
33+
34+
pub fn get_item(field: impl Into<Field>, child: ExprRef) -> ExprRef {
35+
GetItem::new_expr(field, child)
36+
}
37+
38+
impl PartialEq<dyn Any> for GetItem {
39+
fn eq(&self, other: &dyn Any) -> bool {
40+
other
41+
.downcast_ref::<GetItem>()
42+
.map(|item| self.field == item.field && self.child.eq(&item.child))
43+
.unwrap_or(false)
44+
}
45+
}
46+
47+
impl Display for GetItem {
48+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
49+
write!(f, "{}.{}", self.child, self.field)
50+
}
51+
}
52+
53+
impl VortexExpr for GetItem {
54+
fn as_any(&self) -> &dyn Any {
55+
self
56+
}
57+
58+
fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
59+
let child = self.child.evaluate(batch)?;
60+
child
61+
.as_struct_array()
62+
.ok_or_else(|| vortex_err!("GetItem: child array into struct"))?
63+
// TODO(joe): apply struct validity
64+
.maybe_null_field(self.field())
65+
.ok_or_else(|| vortex_err!("Field {} not found", self.field))
66+
}
67+
68+
fn children(&self) -> Vec<&ExprRef> {
69+
vec![self.child()]
70+
}
71+
72+
fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
73+
assert_eq!(children.len(), 1);
74+
Self::new_expr(self.field().clone(), children[0].clone())
75+
}
76+
}
77+
78+
impl PartialEq for GetItem {
79+
fn eq(&self, other: &GetItem) -> bool {
80+
self.field == other.field && self.child.eq(&other.child)
81+
}
82+
}
83+
84+
#[cfg(test)]
85+
mod tests {
86+
use vortex_array::array::StructArray;
87+
use vortex_array::{ArrayDType, IntoArrayData};
88+
use vortex_buffer::buffer;
89+
use vortex_dtype::DType;
90+
use vortex_dtype::PType::{I32, I64};
91+
92+
use crate::get_item::get_item;
93+
use crate::ident;
94+
95+
fn test_array() -> StructArray {
96+
StructArray::from_fields(&[
97+
("a", buffer![0i32, 1, 2].into_array()),
98+
("b", buffer![4i64, 5, 6].into_array()),
99+
])
100+
.unwrap()
101+
}
102+
103+
#[test]
104+
pub fn get_item_by_name() {
105+
let st = test_array();
106+
let get_item = get_item("a", ident());
107+
let item = get_item.evaluate(st.as_ref()).unwrap();
108+
assert_eq!(item.dtype(), &DType::from(I32))
109+
}
110+
111+
#[test]
112+
pub fn get_item_by_name_none() {
113+
let st = test_array();
114+
let get_item = get_item("c", ident());
115+
assert!(get_item.evaluate(st.as_ref()).is_err());
116+
}
117+
118+
#[test]
119+
pub fn get_item_by_idx() {
120+
let st = test_array();
121+
let get_item = get_item(1, ident());
122+
let item = get_item.evaluate(st.as_ref()).unwrap();
123+
assert_eq!(item.dtype(), &DType::from(I64))
124+
}
125+
}

vortex-expr/src/lib.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::sync::Arc;
55
mod binary;
66
mod column;
77
pub mod datafusion;
8+
mod get_item;
89
mod identity;
910
mod like;
1011
mod literal;
@@ -20,6 +21,7 @@ mod traversal;
2021

2122
pub use binary::*;
2223
pub use column::*;
24+
pub use get_item::*;
2325
pub use identity::*;
2426
pub use like::*;
2527
pub use literal::*;
@@ -184,21 +186,21 @@ mod tests {
184186
assert_eq!(Not::new_expr(col1.clone()).to_string(), "!$col1");
185187

186188
assert_eq!(
187-
Select::include(vec![Field::from("col1")]).to_string(),
188-
"Include($col1)"
189+
Select::include_expr(vec![Field::from("col1")], ident()).to_string(),
190+
"select +($col1) []"
189191
);
190192
assert_eq!(
191-
Select::include(vec![Field::from("col1"), Field::from("col2")]).to_string(),
192-
"Include($col1,$col2)"
193+
Select::include_expr(vec![Field::from("col1"), Field::from("col2")], ident())
194+
.to_string(),
195+
"select +($col1,$col2) []"
193196
);
194197
assert_eq!(
195-
Select::exclude(vec![
196-
Field::from("col1"),
197-
Field::from("col2"),
198-
Field::Index(1),
199-
])
198+
Select::exclude_expr(
199+
vec![Field::from("col1"), Field::from("col2"), Field::Index(1),],
200+
ident()
201+
)
200202
.to_string(),
201-
"Exclude($col1,$col2,[1])"
203+
"select -($col1,$col2,[1]) []"
202204
);
203205

204206
assert_eq!(lit(Scalar::from(0_u8)).to_string(), "0_u8");

vortex-expr/src/project.rs

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use vortex_dtype::Field;
55

66
use crate::{
77
col, lit, BinaryExpr, Column, ExprRef, Identity, Like, Literal, Not, Operator, RowFilter,
8-
Select, VortexExpr, VortexExprExt,
8+
Select, SelectField, VortexExpr, VortexExprExt,
99
};
1010

1111
/// Restrict expression to only the fields that appear in projection
@@ -17,8 +17,8 @@ pub fn expr_project(expr: &ExprRef, projection: &[Field]) -> Option<ExprRef> {
1717
} else if expr.as_any().downcast_ref::<Literal>().is_some() {
1818
Some(expr.clone())
1919
} else if let Some(s) = expr.as_any().downcast_ref::<Select>() {
20-
match s {
21-
Select::Include(i) => {
20+
match s.fields() {
21+
SelectField::Include(i) => {
2222
let fields = i
2323
.iter()
2424
.filter(|f| projection.contains(f))
@@ -27,10 +27,10 @@ pub fn expr_project(expr: &ExprRef, projection: &[Field]) -> Option<ExprRef> {
2727
if projection.len() == 1 {
2828
Some(Arc::new(Identity))
2929
} else {
30-
(!fields.is_empty()).then(|| Arc::new(Select::include(fields)) as _)
30+
(!fields.is_empty()).then(|| Select::include_expr(fields, s.child().clone()))
3131
}
3232
}
33-
Select::Exclude(e) => {
33+
SelectField::Exclude(e) => {
3434
let fields = projection
3535
.iter()
3636
.filter(|f| !e.contains(f))
@@ -39,7 +39,7 @@ pub fn expr_project(expr: &ExprRef, projection: &[Field]) -> Option<ExprRef> {
3939
if projection.len() == 1 {
4040
Some(Arc::new(Identity))
4141
} else {
42-
(!fields.is_empty()).then(|| Arc::new(Select::include(fields)) as _)
42+
(!fields.is_empty()).then(|| Select::include_expr(fields, s.child().clone()))
4343
}
4444
}
4545
}
@@ -103,7 +103,7 @@ mod tests {
103103
use vortex_dtype::Field;
104104

105105
use super::*;
106-
use crate::{and, lt, or, Identity, Not, Select};
106+
use crate::{and, ident, lt, or, Identity, Not, Select};
107107

108108
#[test]
109109
fn project_and() {
@@ -141,29 +141,27 @@ mod tests {
141141

142142
#[test]
143143
fn project_select() {
144-
let include = Arc::new(Select::include(vec![
145-
Field::from("a"),
146-
Field::from("b"),
147-
Field::from("c"),
148-
])) as _;
144+
let include = Select::include_expr(
145+
vec![Field::from("a"), Field::from("b"), Field::from("c")],
146+
ident(),
147+
);
149148
let projection = vec![Field::from("a"), Field::from("b")];
150149
assert_eq!(
151-
&expr_project(&include, &projection).unwrap(),
152-
&(Select::include_expr(projection) as _)
150+
*expr_project(&include, &projection).unwrap(),
151+
*Select::include_expr(projection, ident())
153152
);
154153
}
155154

156155
#[test]
157156
fn project_select_extra_columns() {
158-
let include = Arc::new(Select::include(vec![
159-
Field::from("a"),
160-
Field::from("b"),
161-
Field::from("c"),
162-
])) as _;
157+
let include = Select::include_expr(
158+
vec![Field::from("a"), Field::from("b"), Field::from("c")],
159+
ident(),
160+
);
163161
let projection = vec![Field::from("c"), Field::from("d")];
164162
assert_eq!(
165-
&expr_project(&include, &projection).unwrap(),
166-
&(Select::include_expr(vec![Field::from("c")]) as _)
163+
*expr_project(&include, &projection).unwrap(),
164+
*Select::include_expr(vec![Field::from("c")], ident())
167165
);
168166
}
169167

vortex-expr/src/pruning.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ mod tests {
445445
use crate::pruning::{
446446
convert_to_pruning_expression, stat_column_field, FieldOrIdentity, PruningPredicate,
447447
};
448-
use crate::{and, col, eq, gt, gt_eq, lit, lt, lt_eq, not_eq, or, Identity, Not};
448+
use crate::{and, col, eq, gt, gt_eq, ident, lit, lt, lt_eq, not_eq, or, Not};
449449

450450
#[test]
451451
pub fn pruning_equals() {
@@ -693,7 +693,7 @@ mod tests {
693693

694694
#[test]
695695
fn pruning_identity() {
696-
let column = Identity::new_expr();
696+
let column = ident();
697697
let expr = or(lt(column.clone(), lit(10)), gt(column.clone(), lit(50)));
698698

699699
let expected = HashMap::from([(

0 commit comments

Comments
 (0)