Skip to content

Commit aae9639

Browse files
committed
CastFn
Signed-off-by: Nicholas Gates <[email protected]>
1 parent a07dc84 commit aae9639

File tree

8 files changed

+140
-64
lines changed

8 files changed

+140
-64
lines changed

vortex-array/src/arrays/bool/vtable/operator.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ impl OperatorVTable<BoolVTable> for BoolVTable {
4949
pub struct BoolMaskedValidityRule;
5050

5151
impl ArrayParentReduceRule<Exact<BoolVTable>, Exact<MaskedVTable>> for BoolMaskedValidityRule {
52+
fn child(&self) -> Exact<BoolVTable> {
53+
Exact::from(&BoolVTable)
54+
}
55+
56+
fn parent(&self) -> Exact<MaskedVTable> {
57+
Exact::from(&MaskedVTable)
58+
}
59+
5260
fn reduce_parent(
5361
&self,
5462
array: &BoolArray,

vortex-array/src/arrays/decimal/vtable/operator.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,14 @@ pub struct DecimalMaskedValidityRule;
5858
impl ArrayParentReduceRule<Exact<DecimalVTable>, Exact<MaskedVTable>>
5959
for DecimalMaskedValidityRule
6060
{
61+
fn child(&self) -> Exact<DecimalVTable> {
62+
Exact::from(&DecimalVTable)
63+
}
64+
65+
fn parent(&self) -> Exact<MaskedVTable> {
66+
Exact::from(&MaskedVTable)
67+
}
68+
6169
fn reduce_parent(
6270
&self,
6371
array: &DecimalArray,

vortex-array/src/arrays/primitive/vtable/operator.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ pub struct PrimitiveMaskedValidityRule;
5656
impl ArrayParentReduceRule<Exact<PrimitiveVTable>, Exact<MaskedVTable>>
5757
for PrimitiveMaskedValidityRule
5858
{
59+
fn child(&self) -> Exact<PrimitiveVTable> {
60+
Exact::from(&PrimitiveVTable)
61+
}
62+
63+
fn parent(&self) -> Exact<MaskedVTable> {
64+
Exact::from(&MaskedVTable)
65+
}
66+
5967
fn reduce_parent(
6068
&self,
6169
array: &PrimitiveArray,

vortex-array/src/optimizer/mod.rs

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,20 @@
44
use std::sync::Arc;
55

66
use vortex_error::VortexResult;
7-
use vortex_session::SessionVar;
87
use vortex_utils::aliases::hash_map::HashMap;
98

109
use crate::Array;
1110
use crate::ArrayVisitor;
1211
use crate::array::ArrayRef;
1312
use crate::optimizer::rules::AnyArray;
1413
use crate::optimizer::rules::ArrayParentReduceRule;
15-
use crate::optimizer::rules::ArrayParentReduceRuleAdapter;
1614
use crate::optimizer::rules::ArrayReduceRule;
17-
use crate::optimizer::rules::ArrayReduceRuleAdapter;
1815
use crate::optimizer::rules::DynArrayParentReduceRule;
1916
use crate::optimizer::rules::DynArrayReduceRule;
2017
use crate::optimizer::rules::MatchKey;
2118
use crate::optimizer::rules::Matcher;
19+
use crate::optimizer::rules::ParentReduceRuleAdapter;
20+
use crate::optimizer::rules::ReduceRuleAdapter;
2221

2322
pub mod rules;
2423

@@ -35,8 +34,6 @@ pub struct ArrayOptimizer {
3534
reduce_rules: HashMap<MatchKey, Vec<Arc<dyn DynArrayReduceRule>>>,
3635
/// Parent reduce rules for specific parent types, indexed by (child, parent)
3736
parent_rules: HashMap<(MatchKey, MatchKey), Vec<Arc<dyn DynArrayParentReduceRule>>>,
38-
/// Wildcard parent rules (match any parent), indexed by child only
39-
any_parent_rules: HashMap<MatchKey, Vec<Arc<dyn DynArrayParentReduceRule>>>,
4037
}
4138

4239
impl ArrayOptimizer {
@@ -167,9 +164,10 @@ impl ArrayOptimizer {
167164
M: Matcher,
168165
R: ArrayReduceRule<M> + 'static,
169166
{
170-
let adapter = ArrayReduceRuleAdapter::new(rule);
167+
let key = rule.matcher().key();
168+
let adapter = ReduceRuleAdapter::new(rule);
171169
self.reduce_rules
172-
.entry(M::key())
170+
.entry(key)
173171
.or_default()
174172
.push(Arc::new(adapter));
175173
}
@@ -181,9 +179,10 @@ impl ArrayOptimizer {
181179
Parent: Matcher,
182180
R: ArrayParentReduceRule<Child, Parent> + 'static,
183181
{
184-
let adapter = ArrayParentReduceRuleAdapter::new(rule);
182+
let key = (rule.child().key(), rule.parent().key());
183+
let adapter = ParentReduceRuleAdapter::new(rule);
185184
self.parent_rules
186-
.entry((Child::key(), Parent::key()))
185+
.entry(key)
187186
.or_default()
188187
.push(Arc::new(adapter));
189188
}
@@ -194,9 +193,10 @@ impl ArrayOptimizer {
194193
Child: Matcher,
195194
R: ArrayParentReduceRule<Child, AnyArray> + 'static,
196195
{
197-
let adapter = ArrayParentReduceRuleAdapter::new(rule);
198-
self.any_parent_rules
199-
.entry(Child::key())
196+
let key = (rule.child().key(), MatchKey::Any);
197+
let adapter = ParentReduceRuleAdapter::new(rule);
198+
self.parent_rules
199+
.entry(key)
200200
.or_default()
201201
.push(Arc::new(adapter));
202202
}
@@ -206,12 +206,13 @@ impl ArrayOptimizer {
206206
where
207207
F: FnOnce(&mut dyn Iterator<Item = &dyn DynArrayReduceRule>) -> R,
208208
{
209-
f(&mut self
210-
.reduce_rules
211-
.get(&MatchKey::Type(array.encoding().as_any().type_id()))
209+
let exact = self.reduce_rules.get(&MatchKey::Array(array.encoding_id()));
210+
let any = self.reduce_rules.get(&MatchKey::Any);
211+
f(&mut exact
212212
.iter()
213+
.chain(any.iter())
213214
.flat_map(|v| v.iter())
214-
.map(|arc| arc.as_ref()))
215+
.map(|v| v.as_ref()))
215216
}
216217

217218
/// Execute a callback with all parent reduce rules for a given child and parent encoding ID.
@@ -226,20 +227,20 @@ impl ArrayOptimizer {
226227
where
227228
F: FnOnce(&mut dyn Iterator<Item = &dyn DynArrayParentReduceRule>) -> R,
228229
{
229-
let specific_entry = parent.and_then(|parent| {
230+
let exact = parent.and_then(|parent| {
230231
self.parent_rules.get(&(
231-
MatchKey::Type(child.encoding().as_any().type_id()),
232-
MatchKey::Type(parent.encoding().as_any().type_id()),
232+
MatchKey::Array(child.encoding_id()),
233+
MatchKey::Array(parent.encoding_id()),
233234
))
234235
});
235-
let wildcard_entry = self
236-
.any_parent_rules
237-
.get(&MatchKey::Type(child.encoding().as_any().type_id()));
236+
let any = self
237+
.parent_rules
238+
.get(&(MatchKey::Array(child.encoding_id()), MatchKey::Any));
238239

239-
f(&mut specific_entry
240+
f(&mut exact
240241
.iter()
242+
.chain(any.iter())
241243
.flat_map(|v| v.iter())
242-
.chain(wildcard_entry.iter().flat_map(|v| v.iter()))
243244
.map(|arc| arc.as_ref()))
244245
}
245246
}

0 commit comments

Comments
 (0)