Skip to content

Commit 5cf5ff5

Browse files
feat[vortex-array] add an array session rule registry (#5414)
Signed-off-by: Joe Isaacs <[email protected]> Signed-off-by: Nicholas Gates <[email protected]> Co-authored-by: Nicholas Gates <[email protected]> Co-authored-by: Nicholas Gates <[email protected]>
1 parent 79d3984 commit 5cf5ff5

File tree

30 files changed

+1041
-216
lines changed

30 files changed

+1041
-216
lines changed

vortex-array/src/array/mod.rs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
pub mod display;
55
mod operator;
6+
pub mod session;
7+
pub mod transform;
68
mod visitor;
79

810
use std::any::Any;
@@ -15,7 +17,7 @@ pub use operator::*;
1517
pub use visitor::*;
1618
use vortex_buffer::ByteBuffer;
1719
use vortex_dtype::{DType, Nullability};
18-
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic};
20+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
1921
use vortex_mask::Mask;
2022
use vortex_scalar::Scalar;
2123

@@ -617,18 +619,9 @@ impl<V: VTable> Array for ArrayAdapter<V> {
617619
}
618620
}
619621

620-
let metadata = self.metadata()?.ok_or_else(|| {
621-
vortex_err!("Cannot replace children for arrays that do not support serialization")
622-
})?;
623-
624622
// Replace the children of the array by re-building the array from parts.
625-
self.encoding().build(
626-
self.dtype(),
627-
self.len(),
628-
&metadata,
629-
&self.buffers(),
630-
&ReplacementChildren { children },
631-
)
623+
self.encoding()
624+
.with_children(self, &ReplacementChildren { children })
632625
}
633626

634627
fn invoke(

vortex-array/src/array/operator.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,6 @@ pub trait ArrayOperator: 'static + Send + Sync {
2727
/// If the array's implementation returns an invalid vector (wrong length, wrong type, etc.).
2828
fn execute_batch(&self, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector>;
2929

30-
/// Optimize the array by running the optimization rules.
31-
fn reduce(&self) -> VortexResult<Option<ArrayRef>>;
32-
33-
/// Optimize the array by pushing down a parent array.
34-
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
35-
3630
/// Returns the array as a pipeline node, if supported.
3731
fn as_pipelined(&self) -> Option<&dyn PipelinedNode>;
3832

@@ -49,14 +43,6 @@ impl ArrayOperator for Arc<dyn Array> {
4943
self.as_ref().execute_batch(ctx)
5044
}
5145

52-
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
53-
self.as_ref().reduce()
54-
}
55-
56-
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
57-
self.as_ref().reduce_parent(parent, child_idx)
58-
}
59-
6046
fn as_pipelined(&self) -> Option<&dyn PipelinedNode> {
6147
self.as_ref().as_pipelined()
6248
}
@@ -88,14 +74,6 @@ impl<V: VTable> ArrayOperator for ArrayAdapter<V> {
8874
Ok(vector)
8975
}
9076

91-
fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
92-
<V::OperatorVTable as OperatorVTable<V>>::reduce(&self.0)
93-
}
94-
95-
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
96-
<V::OperatorVTable as OperatorVTable<V>>::reduce_parent(&self.0, parent, child_idx)
97-
}
98-
9977
fn as_pipelined(&self) -> Option<&dyn PipelinedNode> {
10078
<V::OperatorVTable as OperatorVTable<V>>::pipeline_node(&self.0)
10179
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
pub mod rewrite;
5+
6+
pub use rewrite::ArrayRewriteRuleRegistry;
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::marker::PhantomData;
5+
use std::sync::Arc;
6+
7+
use vortex_error::VortexResult;
8+
use vortex_utils::aliases::dash_map::DashMap;
9+
10+
use crate::EncodingId;
11+
use crate::array::ArrayRef;
12+
use crate::array::transform::context::ArrayRuleContext;
13+
use crate::array::transform::rules::{
14+
AnyParent, ArrayParentMatcher, ArrayParentReduceRule, ArrayReduceRule,
15+
};
16+
use crate::vtable::VTable;
17+
18+
/// Dynamic trait for array reduce rules
19+
pub trait DynArrayReduceRule: Send + Sync {
20+
fn reduce(&self, array: &ArrayRef, ctx: &ArrayRuleContext) -> VortexResult<Option<ArrayRef>>;
21+
}
22+
23+
/// Dynamic trait for array parent reduce rules
24+
pub trait DynArrayParentReduceRule: Send + Sync {
25+
fn reduce_parent(
26+
&self,
27+
array: &ArrayRef,
28+
parent: &ArrayRef,
29+
child_idx: usize,
30+
ctx: &ArrayRuleContext,
31+
) -> VortexResult<Option<ArrayRef>>;
32+
}
33+
34+
/// Adapter for ArrayReduceRule
35+
struct ArrayReduceRuleAdapter<V: VTable, R> {
36+
rule: R,
37+
_phantom: PhantomData<V>,
38+
}
39+
40+
/// Adapter for ArrayParentReduceRule
41+
struct ArrayParentReduceRuleAdapter<Child: VTable, Parent: ArrayParentMatcher, R> {
42+
rule: R,
43+
_phantom: PhantomData<(Child, Parent)>,
44+
}
45+
46+
impl<V, R> DynArrayReduceRule for ArrayReduceRuleAdapter<V, R>
47+
where
48+
V: VTable,
49+
R: ArrayReduceRule<V>,
50+
{
51+
fn reduce(&self, array: &ArrayRef, ctx: &ArrayRuleContext) -> VortexResult<Option<ArrayRef>> {
52+
let Some(view) = array.as_opt::<V>() else {
53+
return Ok(None);
54+
};
55+
self.rule.reduce(view, ctx)
56+
}
57+
}
58+
59+
impl<Child, Parent, R> DynArrayParentReduceRule for ArrayParentReduceRuleAdapter<Child, Parent, R>
60+
where
61+
Child: VTable,
62+
Parent: ArrayParentMatcher,
63+
R: ArrayParentReduceRule<Child, Parent>,
64+
{
65+
fn reduce_parent(
66+
&self,
67+
array: &ArrayRef,
68+
parent: &ArrayRef,
69+
child_idx: usize,
70+
ctx: &ArrayRuleContext,
71+
) -> VortexResult<Option<ArrayRef>> {
72+
let Some(view) = array.as_opt::<Child>() else {
73+
return Ok(None);
74+
};
75+
let Some(parent_view) = Parent::try_match(parent) else {
76+
return Ok(None);
77+
};
78+
self.rule.reduce_parent(view, parent_view, child_idx, ctx)
79+
}
80+
}
81+
82+
/// Inner struct that holds all the rule registries.
83+
/// Wrapped in a single Arc by ArrayRewriteRuleRegistry for efficient cloning.
84+
#[derive(Default)]
85+
struct ArrayRewriteRuleRegistryInner {
86+
/// Reduce rules indexed by encoding ID
87+
reduce_rules: DashMap<EncodingId, Vec<Arc<dyn DynArrayReduceRule>>>,
88+
/// Parent reduce rules for specific parent types, indexed by (child_id, parent_id)
89+
parent_rules: DashMap<(EncodingId, EncodingId), Vec<Arc<dyn DynArrayParentReduceRule>>>,
90+
/// Wildcard parent rules (match any parent), indexed by child_id only
91+
any_parent_rules: DashMap<EncodingId, Vec<Arc<dyn DynArrayParentReduceRule>>>,
92+
}
93+
94+
impl std::fmt::Debug for ArrayRewriteRuleRegistryInner {
95+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96+
f.debug_struct("ArrayRewriteRuleRegistryInner")
97+
.field(
98+
"reduce_rules",
99+
&format!("{} encodings", self.reduce_rules.len()),
100+
)
101+
.field(
102+
"parent_rules",
103+
&format!("{} pairs", self.parent_rules.len()),
104+
)
105+
.field(
106+
"any_parent_rules",
107+
&format!("{} encodings", self.any_parent_rules.len()),
108+
)
109+
.finish()
110+
}
111+
}
112+
113+
/// Registry of array rewrite rules.
114+
///
115+
/// Stores rewrite rules indexed by the encoding ID they apply to.
116+
#[derive(Clone, Debug)]
117+
pub struct ArrayRewriteRuleRegistry {
118+
inner: Arc<ArrayRewriteRuleRegistryInner>,
119+
}
120+
121+
impl Default for ArrayRewriteRuleRegistry {
122+
fn default() -> Self {
123+
Self {
124+
inner: Arc::new(ArrayRewriteRuleRegistryInner::default()),
125+
}
126+
}
127+
}
128+
129+
impl ArrayRewriteRuleRegistry {
130+
pub fn new() -> Self {
131+
Self::default()
132+
}
133+
134+
/// Register a reduce rule for a specific array encoding.
135+
pub fn register_reduce_rule<V, R>(&self, encoding: &V::Encoding, rule: R)
136+
where
137+
V: VTable,
138+
R: 'static,
139+
R: ArrayReduceRule<V>,
140+
{
141+
let adapter = ArrayReduceRuleAdapter {
142+
rule,
143+
_phantom: PhantomData,
144+
};
145+
let encoding_id = V::id(encoding);
146+
self.inner
147+
.reduce_rules
148+
.entry(encoding_id)
149+
.or_default()
150+
.push(Arc::new(adapter));
151+
}
152+
153+
/// Register a parent rule for a specific parent type.
154+
pub fn register_parent_rule<Child, Parent, R>(
155+
&self,
156+
child_encoding: &Child::Encoding,
157+
parent_encoding: &Parent::Encoding,
158+
rule: R,
159+
) where
160+
Child: VTable,
161+
Parent: VTable,
162+
R: 'static,
163+
R: ArrayParentReduceRule<Child, Parent>,
164+
{
165+
let adapter = ArrayParentReduceRuleAdapter {
166+
rule,
167+
_phantom: PhantomData,
168+
};
169+
let child_id = Child::id(child_encoding);
170+
let parent_id = Parent::id(parent_encoding);
171+
self.inner
172+
.parent_rules
173+
.entry((child_id, parent_id))
174+
.or_default()
175+
.push(Arc::new(adapter));
176+
}
177+
178+
/// Register a parent rule that matches ANY parent type (wildcard).
179+
pub fn register_any_parent_rule<Child, R>(&self, child_encoding: &Child::Encoding, rule: R)
180+
where
181+
Child: VTable,
182+
R: 'static,
183+
R: ArrayParentReduceRule<Child, AnyParent>,
184+
{
185+
let adapter = ArrayParentReduceRuleAdapter {
186+
rule,
187+
_phantom: PhantomData,
188+
};
189+
let child_id = Child::id(child_encoding);
190+
self.inner
191+
.any_parent_rules
192+
.entry(child_id)
193+
.or_default()
194+
.push(Arc::new(adapter));
195+
}
196+
197+
/// Execute a callback with all reduce rules for a given encoding ID.
198+
pub(crate) fn with_reduce_rules<F, R>(&self, id: &EncodingId, f: F) -> R
199+
where
200+
F: FnOnce(&mut dyn Iterator<Item = &dyn DynArrayReduceRule>) -> R,
201+
{
202+
f(&mut self
203+
.inner
204+
.reduce_rules
205+
.get(id)
206+
.iter()
207+
.flat_map(|v| v.value())
208+
.map(|arc| arc.as_ref()))
209+
}
210+
211+
/// Execute a callback with all parent reduce rules for a given child and parent encoding ID.
212+
///
213+
/// Returns rules from both specific parent rules (if parent_id provided) and "any parent" wildcard rules.
214+
pub(crate) fn with_parent_rules<F, R>(
215+
&self,
216+
child_id: &EncodingId,
217+
parent_id: Option<&EncodingId>,
218+
f: F,
219+
) -> R
220+
where
221+
F: FnOnce(&mut dyn Iterator<Item = &dyn DynArrayParentReduceRule>) -> R,
222+
{
223+
let specific_entry = parent_id.and_then(|pid| {
224+
self.inner
225+
.parent_rules
226+
.get(&(child_id.clone(), pid.clone()))
227+
});
228+
let wildcard_entry = self.inner.any_parent_rules.get(child_id);
229+
230+
f(&mut specific_entry
231+
.iter()
232+
.flat_map(|v| v.value())
233+
.chain(wildcard_entry.iter().flat_map(|v| v.value()))
234+
.map(|arc| arc.as_ref()))
235+
}
236+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use crate::expr::transform::ExprOptimizer;
5+
6+
/// Rule context for array rewrite rules
7+
///
8+
/// Provides access to the expression optimizer for optimizing expressions
9+
/// embedded in arrays. Note that dtype is not included since arrays already
10+
/// have a dtype that can be accessed directly.
11+
#[derive(Debug, Clone)]
12+
pub struct ArrayRuleContext {
13+
expr_optimizer: ExprOptimizer,
14+
}
15+
16+
impl ArrayRuleContext {
17+
pub fn new(expr_optimizer: ExprOptimizer) -> Self {
18+
Self { expr_optimizer }
19+
}
20+
21+
pub fn expr_optimizer(&self) -> &ExprOptimizer {
22+
&self.expr_optimizer
23+
}
24+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
pub mod context;
5+
pub mod optimizer;
6+
pub mod rules;
7+
#[cfg(test)]
8+
mod tests;
9+
10+
pub use context::ArrayRuleContext;
11+
pub use optimizer::ArrayOptimizer;
12+
pub use rules::{AnyParent, ArrayParentMatcher, ArrayParentReduceRule, ArrayReduceRule};

0 commit comments

Comments
 (0)