Skip to content

Commit 23bbef6

Browse files
committed
Clarify docs in session accessors
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 46ebd4c commit 23bbef6

File tree

13 files changed: +142 additions, -20 deletions

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/runend/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ vortex-dtype = { workspace = true }
2525
vortex-error = { workspace = true }
2626
vortex-mask = { workspace = true }
2727
vortex-scalar = { workspace = true }
28+
vortex-session = { workspace = true }
2829

2930
[lints]
3031
workspace = true

encodings/runend/src/lib.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub mod compress;
1111
mod compute;
1212
mod iter;
1313
mod ops;
14+
mod rules;
1415

1516
#[doc(hidden)]
1617
pub mod _benchmarking {
@@ -23,11 +24,17 @@ pub mod _benchmarking {
2324
use vortex_array::ArrayBufferVisitor;
2425
use vortex_array::ArrayChildVisitor;
2526
use vortex_array::Canonical;
27+
use vortex_array::session::ArraySession;
28+
use vortex_array::session::ArraySessionExt;
29+
use vortex_array::vtable::ArrayVTableExt;
2630
use vortex_array::vtable::EncodeVTable;
2731
use vortex_array::vtable::VisitorVTable;
2832
use vortex_error::VortexResult;
33+
use vortex_session::SessionExt;
34+
use vortex_session::VortexSession;
2935

3036
use crate::compress::runend_encode;
37+
use crate::rules::RunEndScalarFnRule;
3138

3239
impl EncodeVTable<RunEndVTable> for RunEndVTable {
3340
fn encode(
@@ -59,6 +66,15 @@ impl VisitorVTable<RunEndVTable> for RunEndVTable {
5966
}
6067
}
6168

69+
/// Initialize run-end encoding in the given session.
///
/// Performs two registrations on `session`:
///
/// 1. registers the `RunEndVTable` encoding with `session.arrays()`, and
/// 2. registers `RunEndScalarFnRule` as a parent rule on the optimizer owned by the
///    session's `ArraySession`.
70+
pub fn initialize(session: &mut VortexSession) {
// Make the run-end encoding known to the session's arrays.
71+
session.arrays().register(RunEndVTable.as_vtable());
// Registering an optimizer rule goes through mutable access to the `ArraySession`.
72+
session
73+
.get_mut::<ArraySession>()
74+
.optimizer_mut()
75+
.register_parent_rule(RunEndScalarFnRule);
76+
}
77+
6278
#[cfg(test)]
6379
mod tests {
6480
use vortex_array::ProstMetadata;

encodings/runend/src/rules.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_array::ArrayRef;
5+
use vortex_array::IntoArray;
6+
use vortex_array::arrays::AnyScalarFn;
7+
use vortex_array::arrays::ConstantArray;
8+
use vortex_array::arrays::ConstantVTable;
9+
use vortex_array::arrays::ScalarFnArray;
10+
use vortex_array::optimizer::rules::ArrayParentReduceRule;
11+
use vortex_array::optimizer::rules::Exact;
12+
use vortex_dtype::DType;
13+
use vortex_error::VortexResult;
14+
15+
use crate::RunEndArray;
16+
use crate::RunEndVTable;
17+
18+
/// A rule to push down scalar functions through run-end encoding into the values array.
19+
///
20+
/// This only works if all other children of the scalar function array are constants.
/// The rule also declines (returns `Ok(None)`) unless the scalar function array's dtype
/// is `DType::Bool` or `DType::Primitive`.
///
/// When it applies, the parent is replaced by a new run-end array that reuses the
/// original ends, offset and length, and whose values are the scalar function applied
/// to the original run-end values.
21+
#[derive(Debug)]
22+
pub(crate) struct RunEndScalarFnRule;
23+
24+
impl ArrayParentReduceRule<Exact<RunEndVTable>, AnyScalarFn> for RunEndScalarFnRule {
/// Matcher for the child side of the rule: exactly the run-end encoding.
25+
fn child(&self) -> Exact<RunEndVTable> {
26+
Exact::from(&RunEndVTable)
27+
}
28+
/// Matcher for the parent side of the rule: any scalar function array.
29+
fn parent(&self) -> AnyScalarFn {
30+
AnyScalarFn
31+
}
32+
/// Attempts to rewrite a scalar function over `run_end` as a run-end array whose
/// values are that scalar function applied to `run_end.values()`.
///
/// `child_idx` is the position of `run_end` among `parent`'s children.
///
/// Returns `Ok(None)` (no rewrite) when any other child of `parent` is not a
/// `ConstantVTable` array, or when `parent.dtype()` is neither `DType::Bool` nor
/// `DType::Primitive`.
///
/// # Errors
///
/// Propagates failures from `ScalarFnArray::try_new` and
/// `RunEndArray::try_new_offset_length`.
33+
fn reduce_parent(
34+
&self,
35+
run_end: &RunEndArray,
36+
parent: &ScalarFnArray,
37+
child_idx: usize,
38+
) -> VortexResult<Option<ArrayRef>> {
// First pass: check the precondition without building anything.
39+
for (idx, child) in parent.children().iter().enumerate() {
40+
if idx == child_idx {
41+
// Skip ourselves
42+
continue;
43+
}
44+
45+
if !child.is::<ConstantVTable>() {
46+
// We can only push down if all other children are constants
47+
return Ok(None);
48+
}
49+
}
50+
51+
// TODO(ngates): relax this constraint and implement run-end decoding for all vector types.
52+
if !matches!(parent.dtype(), DType::Bool(_) | DType::Primitive(..)) {
53+
return Ok(None);
54+
}
55+
// Second pass: rebuild the child list at the length of the run-end values
// rather than the logical length of the run-end array.
56+
let values_len = run_end.values().len();
57+
let mut new_children = parent.children();
58+
for (idx, child) in new_children.iter_mut().enumerate() {
59+
if idx == child_idx {
60+
// Replace ourselves with run end values
61+
*child = run_end.values().clone();
62+
continue;
63+
}
64+
65+
// Replace other children with their constant scalar value with length adjusted
66+
// to the length of the run end values.
// The first pass established that every other child is a `ConstantVTable` array.
67+
let constant = child.as_::<ConstantVTable>();
68+
*child = ConstantArray::new(constant.scalar().clone(), values_len).into_array();
69+
}
70+
// Apply the same scalar function to the rewritten children.
71+
let new_values =
72+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)?
73+
.into_array();
74+
// Wrap the new values with the original ends, offset and length.
75+
Ok(Some(
76+
RunEndArray::try_new_offset_length(
77+
run_end.ends().clone(),
78+
new_values,
79+
run_end.offset(),
80+
run_end.len(),
81+
)?
82+
.into_array(),
83+
))
84+
}
85+
}

vortex-array/src/arrays/scalar_fn/array.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,9 @@ impl ScalarFnArray {
4444
stats: Default::default(),
4545
})
4646
}
47+
48+
/// Returns a reference to the scalar function bound to this array.
///
/// The reference borrows from `self`; callers that need an owned value can clone it.
49+
pub fn scalar_fn(&self) -> &ScalarFn {
50+
&self.scalar_fn
51+
}
4752
}

vortex-array/src/arrays/scalar_fn/metadata.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ pub struct ScalarFnMetadata {
1717
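// The array tree display formats metadata with `Debug` when it should use `Display`,
// so this `Debug` impl writes `Display`-formatted (`{}`) output instead.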
1818
impl Debug for ScalarFnMetadata {
1919
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
20-
write!(f, "{}", self.scalar_fn)
20+
write!(f, "{}", self.scalar_fn.options())
2121
}
2222
}

vortex-array/src/arrays/scalar_fn/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
mod array;
55
mod kernel;
66
mod metadata;
7-
mod rules;
7+
pub(crate) mod rules;
88
mod vtable;
99

1010
pub use array::*;

vortex-array/src/arrays/scalar_fn/rules.rs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_error::VortexExpect;
54
use vortex_error::VortexResult;
65
use vortex_scalar::Scalar;
76
use vortex_vector::Datum;
7+
use vortex_vector::VectorOps;
88
use vortex_vector::scalar_matches_dtype;
99

1010
use crate::Array;
@@ -37,17 +37,23 @@ impl ArrayReduceRule<AnyScalarFn> for ScalarFnConstantRule {
3737
.collect();
3838
let input_dtypes = array.children.iter().map(|c| c.dtype().clone()).collect();
3939

40-
let result = array
41-
.scalar_fn
42-
.execute(ExecutionArgs {
43-
datums: input_datums,
44-
dtypes: input_dtypes,
45-
row_count: array.len,
46-
return_dtype: array.dtype.clone(),
47-
})?
48-
.into_scalar()
49-
.vortex_expect("Scalar inputs should produce scalar output");
40+
let result = array.scalar_fn.execute(ExecutionArgs {
41+
datums: input_datums,
42+
dtypes: input_dtypes,
43+
row_count: array.len,
44+
return_dtype: array.dtype.clone(),
45+
})?;
5046

47+
let result = match result {
48+
Datum::Scalar(s) => s,
49+
Datum::Vector(v) => {
50+
log::warn!(
51+
"Scalar function {} returned vector from execution over all scalar inputs",
52+
array.scalar_fn
53+
);
54+
v.scalar_at(0)
55+
}
56+
};
5157
assert!(scalar_matches_dtype(&result, &array.dtype));
5258

5359
Ok(Some(

vortex-array/src/expr/exprs/binary.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ impl VTable for Binary {
136136

137137
match op {
138138
Operator::And => {
139+
// FIXME(ngates): implement logical compute over datums
139140
let lhs = lhs.ensure_vector(args.row_count).into_bool().into_arrow()?;
140141
let rhs = rhs.ensure_vector(args.row_count).into_bool().into_arrow()?;
141142
return Ok(Datum::Vector(
@@ -145,6 +146,7 @@ impl VTable for Binary {
145146
));
146147
}
147148
Operator::Or => {
149+
// FIXME(ngates): implement logical compute over datums
148150
let lhs = lhs.ensure_vector(args.row_count).into_bool().into_arrow()?;
149151
let rhs = rhs.ensure_vector(args.row_count).into_bool().into_arrow()?;
150152
return Ok(Datum::Vector(

vortex-array/src/session/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::arrays::StructGetItemRule;
2323
use crate::arrays::StructVTable;
2424
use crate::arrays::VarBinVTable;
2525
use crate::arrays::VarBinViewVTable;
26+
use crate::arrays::rules::ScalarFnConstantRule;
2627
use crate::optimizer::ArrayOptimizer;
2728
use crate::vtable::ArrayVTable;
2829
use crate::vtable::ArrayVTableExt;
@@ -94,6 +95,9 @@ impl Default for ArraySession {
9495
};
9596

9697
let optimizer = session.optimizer_mut();
98+
99+
optimizer.register_reduce_rule(ScalarFnConstantRule);
100+
97101
optimizer.register_parent_rule(BoolMaskedValidityRule);
98102
optimizer.register_parent_rule(PrimitiveMaskedValidityRule);
99103
optimizer.register_parent_rule(DecimalMaskedValidityRule);

0 commit comments

Comments
 (0)