Skip to content

Commit 54c09f6

Browse files
authored
chore: better match_each_XYZ macros (#3401)
1 parent 1b89b2f commit 54c09f6

File tree

83 files changed

+911
-575
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+911
-575
lines changed

encodings/alp/src/alp/compress.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,22 @@ use crate::alp::{ALPArray, ALPFloat};
1515

1616
#[macro_export]
1717
macro_rules! match_each_alp_float_ptype {
18-
($self:expr, | $_:tt $enc:ident | $($body:tt)*) => ({
19-
macro_rules! __with__ {( $_ $enc:ident ) => ( $($body)* )}
18+
($self:expr, | $enc:ident | $body:block) => {{
2019
use vortex_dtype::PType;
2120
use vortex_error::vortex_panic;
2221
let ptype = $self;
2322
match ptype {
24-
PType::F32 => __with__! { f32 },
25-
PType::F64 => __with__! { f64 },
23+
PType::F32 => {
24+
type $enc = f32;
25+
$body
26+
}
27+
PType::F64 => {
28+
type $enc = f64;
29+
$body
30+
}
2631
_ => vortex_panic!("ALP can only encode f32 and f64, got {}", ptype),
2732
}
28-
})
33+
}};
2934
}
3035

3136
pub fn alp_encode(parray: &PrimitiveArray, exponents: Option<Exponents>) -> VortexResult<ALPArray> {
@@ -100,9 +105,9 @@ pub fn decompress(array: &ALPArray) -> VortexResult<PrimitiveArray> {
100105
let validity = encoded.validity().clone();
101106
let ptype = array.dtype().try_into()?;
102107

103-
let decoded = match_each_alp_float_ptype!(ptype, |$T| {
104-
PrimitiveArray::new::<$T>(
105-
<$T>::decode_buffer(encoded.into_buffer_mut(), array.exponents()),
108+
let decoded = match_each_alp_float_ptype!(ptype, |T| {
109+
PrimitiveArray::new::<T>(
110+
<T>::decode_buffer(encoded.into_buffer_mut(), array.exponents()),
106111
validity,
107112
)
108113
});

encodings/alp/src/alp/compute/between.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,16 @@ impl BetweenKernel for ALPVTable {
3030
let nullability =
3131
array.dtype().nullability() | lower.dtype().nullability() | upper.dtype().nullability();
3232

33-
match_each_alp_float_ptype!(array.ptype(), |$F| {
34-
between_impl::<$F>(array, $F::try_from(lower)?, $F::try_from(upper)?, nullability, options)
33+
match_each_alp_float_ptype!(array.ptype(), |F| {
34+
between_impl::<F>(
35+
array,
36+
F::try_from(lower)?,
37+
F::try_from(upper)?,
38+
nullability,
39+
options,
40+
)
3541
})
36-
.map(Some)
42+
.map(Some)
3743
}
3844
}
3945

encodings/alp/src/alp/compute/compare.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,14 @@ impl CompareKernel for ALPVTable {
3030
if let Some(const_scalar) = rhs.as_constant() {
3131
let pscalar = PrimitiveScalar::try_from(&const_scalar)?;
3232

33-
match_each_alp_float_ptype!(pscalar.ptype(), |$T| {
34-
match pscalar.typed_value::<$T>() {
33+
match_each_alp_float_ptype!(pscalar.ptype(), |T| {
34+
match pscalar.typed_value::<T>() {
3535
Some(value) => return alp_scalar_compare(lhs, value, operator),
36-
None => vortex_bail!("Failed to convert scalar {:?} to ALP type {:?}", pscalar, pscalar.ptype()),
36+
None => vortex_bail!(
37+
"Failed to convert scalar {:?} to ALP type {:?}",
38+
pscalar,
39+
pscalar.ptype()
40+
),
3741
}
3842
});
3943
}

encodings/alp/src/alp/ops.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use vortex_array::vtable::OperationsVTable;
22
use vortex_array::{Array, ArrayRef, IntoArray};
3-
use vortex_error::VortexResult;
3+
use vortex_error::{VortexExpect, VortexResult};
44
use vortex_scalar::Scalar;
55

66
use crate::{ALPArray, ALPFloat, ALPVTable, match_each_alp_float_ptype};
@@ -32,12 +32,15 @@ impl OperationsVTable<ALPVTable> for ALPVTable {
3232

3333
let encoded_val = array.encoded().scalar_at(index)?;
3434

35-
Ok(match_each_alp_float_ptype!(array.ptype(), |$T| {
36-
let encoded_val: <$T as ALPFloat>::ALPInt = encoded_val.as_ref().try_into().unwrap();
37-
Scalar::primitive(<$T as ALPFloat>::decode_single(
38-
encoded_val,
39-
array.exponents(),
40-
), array.dtype().nullability())
35+
Ok(match_each_alp_float_ptype!(array.ptype(), |T| {
36+
let encoded_val: <T as ALPFloat>::ALPInt = encoded_val
37+
.as_ref()
38+
.try_into()
39+
.vortex_expect("invalid ALPInt");
40+
Scalar::primitive(
41+
<T as ALPFloat>::decode_single(encoded_val, array.exponents()),
42+
array.dtype().nullability(),
43+
)
4144
}))
4245
}
4346
}

encodings/alp/src/alp_rd/mod.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,7 @@ impl RDEncoder {
167167
///
168168
/// Each value will be split into a left and right component, which are compressed individually.
169169
pub fn encode(&self, array: &PrimitiveArray) -> ALPRDArray {
170-
match_each_alp_float_ptype!(array.ptype(), |$P| {
171-
self.encode_generic::<$P>(array)
172-
})
170+
match_each_alp_float_ptype!(array.ptype(), |P| { self.encode_generic::<P>(array) })
173171
}
174172

175173
fn encode_generic<T>(&self, array: &PrimitiveArray) -> ALPRDArray
@@ -288,12 +286,12 @@ pub fn alp_rd_decode<T: ALPRDFloat>(
288286
if let Some(patches) = left_parts_patches {
289287
let indices = patches.indices().to_primitive()?;
290288
let patch_values = patches.values().to_primitive()?;
291-
match_each_integer_ptype!(indices.ptype(), |$T| {
289+
match_each_integer_ptype!(indices.ptype(), |T| {
292290
indices
293-
.as_slice::<$T>()
291+
.as_slice::<T>()
294292
.iter()
295293
.copied()
296-
.map(|idx| idx - patches.offset() as $T)
294+
.map(|idx| idx - patches.offset() as T)
297295
.zip(patch_values.as_slice::<u16>().iter())
298296
.for_each(|(idx, v)| values[idx as usize] = *v);
299297
})

encodings/bytebool/src/compute.rs

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,30 @@ impl TakeKernel for ByteBoolVTable {
2626
// have fallible is_valid function.
2727
let arr = match validity {
2828
Mask::AllTrue(_) => {
29-
let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
30-
indices.as_slice::<$I>()
31-
.iter()
32-
.map(|&idx| {
33-
let idx: usize = idx.as_();
34-
bools[idx]
35-
})
36-
.collect::<Vec<_>>()
29+
let bools = match_each_integer_ptype!(indices.ptype(), |I| {
30+
indices
31+
.as_slice::<I>()
32+
.iter()
33+
.map(|&idx| {
34+
let idx: usize = idx.as_();
35+
bools[idx]
36+
})
37+
.collect::<Vec<_>>()
3738
});
3839

3940
ByteBoolArray::from(bools).into_array()
4041
}
4142
Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
4243
Mask::Values(values) => {
43-
let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
44-
indices.as_slice::<$I>()
45-
.iter()
46-
.map(|&idx| {
47-
let idx = idx.as_();
48-
if values.value(idx) {
49-
Some(bools[idx])
50-
} else {
51-
None
52-
}
53-
})
54-
.collect::<Vec<Option<_>>>()
44+
let bools = match_each_integer_ptype!(indices.ptype(), |I| {
45+
indices
46+
.as_slice::<I>()
47+
.iter()
48+
.map(|&idx| {
49+
let idx = idx.as_();
50+
values.value(idx).then(|| bools[idx])
51+
})
52+
.collect::<Vec<Option<_>>>()
5553
});
5654

5755
ByteBoolArray::from(bools).into_array()

encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ fn decimal_value_wrapper_to_primitive(
4646
decimal_value: DecimalValue,
4747
ptype: PType,
4848
) -> Option<ScalarValue> {
49-
match_each_integer_ptype!(ptype, |$P| {
50-
decimal_value_to_primitive::<$P>(decimal_value)
49+
match_each_integer_ptype!(ptype, |P| {
50+
decimal_value_to_primitive::<P>(decimal_value)
5151
})
5252
}
5353

encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,10 @@ impl CanonicalVTable<DecimalBytePartsVTable> for DecimalBytePartsVTable {
121121
// Depending on the decimal type and the min/max of the primitive array we can choose
122122
// the correct buffer size
123123

124-
let res = match_each_signed_integer_ptype!(prim.ptype(), |$P| {
125-
Canonical::Decimal(DecimalArray::new(
126-
prim.buffer::<$P>(),
127-
array.decimal_dtype().clone(),
124+
let res = match_each_signed_integer_ptype!(prim.ptype(), |P| {
125+
Canonical::Decimal(DecimalArray::new(
126+
prim.buffer::<P>(),
127+
*array.decimal_dtype(),
128128
prim.validity().clone(),
129129
))
130130
});
@@ -147,6 +147,7 @@ impl OperationsVTable<DecimalBytePartsVTable> for DecimalBytePartsVTable {
147147
.map(|d| d.to_array())
148148
}
149149

150+
#[allow(clippy::useless_conversion)]
150151
fn scalar_at(array: &DecimalBytePartsArray, index: usize) -> VortexResult<Scalar> {
151152
// TODO(joe): support parts len != 1
152153
assert!(array.lower_parts.is_empty());
@@ -155,8 +156,12 @@ impl OperationsVTable<DecimalBytePartsVTable> for DecimalBytePartsVTable {
155156
// Note. values in msp, can only be signed integers upto size i64.
156157
let primitive_scalar = scalar.as_primitive();
157158
// TODO(joe): extend this to support multiple parts.
158-
let value = match_each_signed_integer_ptype!(primitive_scalar.ptype(), |$P| {
159-
i64::from(primitive_scalar.typed_value::<$P>().vortex_expect("scalar must have correct ptype"))
159+
let value = match_each_signed_integer_ptype!(primitive_scalar.ptype(), |P| {
160+
i64::from(
161+
primitive_scalar
162+
.typed_value::<P>()
163+
.vortex_expect("scalar must have correct ptype"),
164+
)
160165
});
161166
Ok(Scalar::new(
162167
array.dtype.clone(),

encodings/dict/src/array.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,11 @@ impl ValidityVTable<DictVTable> for DictVTable {
145145
AllOr::All => {
146146
let primitive_codes = array.codes().to_primitive()?;
147147
let values_mask = array.values().validity_mask()?;
148-
let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
149-
let codes_slice = primitive_codes.as_slice::<$P>();
148+
let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
149+
let codes_slice = primitive_codes.as_slice::<P>();
150150
BooleanBuffer::collect_bool(array.len(), |idx| {
151-
values_mask.value(codes_slice[idx] as usize)
151+
#[allow(clippy::cast_possible_truncation)]
152+
values_mask.value(codes_slice[idx] as usize)
152153
})
153154
});
154155
Ok(Mask::from_buffer(is_valid_buffer))
@@ -157,10 +158,11 @@ impl ValidityVTable<DictVTable> for DictVTable {
157158
AllOr::Some(validity_buff) => {
158159
let primitive_codes = array.codes().to_primitive()?;
159160
let values_mask = array.values().validity_mask()?;
160-
let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
161-
let codes_slice = primitive_codes.as_slice::<$P>();
161+
let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
162+
let codes_slice = primitive_codes.as_slice::<P>();
163+
#[allow(clippy::cast_possible_truncation)]
162164
BooleanBuffer::collect_bool(array.len(), |idx| {
163-
validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
165+
validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
164166
})
165167
});
166168
Ok(Mask::from_buffer(is_valid_buffer))

encodings/dict/src/builders/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ pub fn dict_encoder(
3333
constraints: &DictConstraints,
3434
) -> VortexResult<Box<dyn DictEncoder>> {
3535
let dict_builder: Box<dyn DictEncoder> = if let Some(pa) = array.as_opt::<PrimitiveVTable>() {
36-
match_each_native_ptype!(pa.ptype(), |$P| {
37-
primitive_dict_builder::<$P>(pa.dtype().nullability(), &constraints)
36+
match_each_native_ptype!(pa.ptype(), |P| {
37+
primitive_dict_builder::<P>(pa.dtype().nullability(), constraints)
3838
})
3939
} else if let Some(vbv) = array.as_opt::<VarBinViewVTable>() {
4040
bytes_dict_builder(vbv.dtype().clone(), constraints)

0 commit comments

Comments
 (0)