Skip to content

Commit 582a975

Browse files
committed
rename as_primitive -> cast
Signed-off-by: Joe Isaacs <[email protected]>
1 parent f93fe4b commit 582a975

File tree

10 files changed

+76
-84
lines changed

10 files changed

+76
-84
lines changed

encodings/sequence/src/array.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ impl SequenceArray {
118118
let len_t = <P>::from_usize(length - 1)
119119
.ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
120120

121-
let base = base.as_primitive::<P>();
122-
let multiplier = multiplier.as_primitive::<P>();
121+
let base = base.cast::<P>();
122+
let multiplier = multiplier.cast::<P>();
123123

124124
let last = len_t
125125
.checked_mul(multiplier)
@@ -133,8 +133,8 @@ impl SequenceArray {
133133
assert!(idx < self.length, "index_value({idx}): index out of bounds");
134134

135135
match_each_native_ptype!(self.ptype(), |P| {
136-
let base = self.base.as_primitive::<P>();
137-
let multiplier = self.multiplier.as_primitive::<P>();
136+
let base = self.base.cast::<P>();
137+
let multiplier = self.multiplier.cast::<P>();
138138
let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
139139

140140
PValue::from(value)
@@ -206,8 +206,8 @@ impl ArrayVTable<SequenceVTable> for SequenceVTable {
206206
impl CanonicalVTable<SequenceVTable> for SequenceVTable {
207207
fn canonicalize(array: &SequenceArray) -> Canonical {
208208
let prim = match_each_native_ptype!(array.ptype(), |P| {
209-
let base = array.base().as_primitive::<P>();
210-
let multiplier = array.multiplier().as_primitive::<P>();
209+
let base = array.base().cast::<P>();
210+
let multiplier = array.multiplier().cast::<P>();
211211
let values = BufferMut::from_iter(
212212
(0..array.len())
213213
.map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),

encodings/sequence/src/compute/compare.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ pub(crate) fn find_intersection_scalar(
6969
intercept: PValue,
7070
) -> Option<usize> {
7171
match_each_integer_ptype!(base.ptype(), |P| {
72-
let intercept = intercept.as_primitive::<P>();
72+
let intercept = intercept.cast::<P>();
7373

74-
let base = base.as_primitive::<P>();
75-
let multiplier = multiplier.as_primitive::<P>();
74+
let base = base.cast::<P>();
75+
let multiplier = multiplier.cast::<P>();
7676

7777
find_intersection(base, multiplier, len, intercept)
7878
})

encodings/sequence/src/compute/filter.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ impl FilterKernel for SequenceVTable {
1616
fn filter(&self, array: &SequenceArray, selection_mask: &Mask) -> VortexResult<ArrayRef> {
1717
let validity = Validity::from(array.dtype().nullability());
1818
match_each_native_ptype!(array.ptype(), |P| {
19-
let mul = array.multiplier().as_primitive::<P>();
20-
let base = array.base().as_primitive::<P>();
19+
let mul = array.multiplier().cast::<P>();
20+
let base = array.base().cast::<P>();
2121
Ok(filter_impl(mul, base, selection_mask, validity))
2222
})
2323
}

encodings/sequence/src/compute/is_sorted.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,12 @@ use crate::{SequenceArray, SequenceVTable};
1212
impl IsSortedKernel for SequenceVTable {
1313
fn is_sorted(&self, array: &SequenceArray) -> VortexResult<Option<bool>> {
1414
let m = array.multiplier();
15-
match_each_native_ptype!(m.ptype(), |P| {
16-
Ok(Some(m.as_primitive::<P>() >= zero::<P>()))
17-
})
15+
match_each_native_ptype!(m.ptype(), |P| { Ok(Some(m.cast::<P>() >= zero::<P>())) })
1816
}
1917

2018
fn is_strict_sorted(&self, array: &SequenceArray) -> VortexResult<Option<bool>> {
2119
let m = array.multiplier();
22-
match_each_native_ptype!(m.ptype(), |P| {
23-
Ok(Some(m.as_primitive::<P>() > zero::<P>()))
24-
})
20+
match_each_native_ptype!(m.ptype(), |P| { Ok(Some(m.cast::<P>() > zero::<P>())) })
2521
}
2622
}
2723

encodings/sequence/src/compute/take.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ impl TakeKernel for SequenceVTable {
2626
Ok(match_each_integer_ptype!(indices.ptype(), |T| {
2727
let indices = indices.as_slice::<T>();
2828
match_each_native_ptype!(array.ptype(), |S| {
29-
let mul = array.multiplier().as_primitive::<S>();
30-
let base = array.base().as_primitive::<S>();
29+
let mul = array.multiplier().cast::<S>();
30+
let base = array.base().cast::<S>();
3131
take(mul, base, indices, mask, result_nullability)
3232
})
3333
}))

encodings/sequence/src/operator.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ impl OperatorVTable<SequenceVTable> for SequenceVTable {
2424
let selection = ctx.bind_selection(array.len(), selection)?;
2525

2626
Ok(match_each_native_ptype!(array.ptype(), |T| {
27-
if array.multiplier().as_primitive::<T>() == <T as One>::one() {
27+
if array.multiplier().cast::<T>() == <T as One>::one() {
2828
Box::new(SequenceKernel::<T> {
29-
base: array.base().as_primitive::<T>(),
29+
base: array.base().cast::<T>(),
3030
selection,
3131
})
3232
} else {
3333
Box::new(MultiplierSequenceKernel::<T> {
34-
base: array.base().as_primitive::<T>(),
35-
multiplier: array.multiplier().as_primitive::<T>(),
34+
base: array.base().cast::<T>(),
35+
multiplier: array.multiplier().cast::<T>(),
3636
selection,
3737
})
3838
}

vortex-array/src/builders/primitive.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ impl<T: NativePType> ArrayBuilder for PrimitiveBuilder<T> {
148148

149149
let primitive_scalar = PrimitiveScalar::try_from(scalar)?;
150150
match primitive_scalar.pvalue() {
151-
Some(pv) => self.append_value(pv.as_primitive::<T>()),
151+
Some(pv) => self.append_value(pv.cast::<T>()),
152152
None => self.append_null(),
153153
}
154154

vortex-dtype/src/f16.rs

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use half::f16;
5-
use num_traits::FromPrimitive;
5+
use num_traits::{FromPrimitive, ToPrimitive};
66

77
/// A trait for types that can be created from primitive values, including f16.
88
///
@@ -12,7 +12,7 @@ pub trait FromPrimitiveOrF16: FromPrimitive {
1212
fn from_f16(v: f16) -> Option<Self>;
1313
}
1414

15-
macro_rules! from_primitive_or_f16_for_non_floating_point {
15+
macro_rules! from_primitive_or_f16_for_signed {
1616
($T:ty) => {
1717
impl FromPrimitiveOrF16 for $T {
1818
fn from_f16(_: f16) -> Option<Self> {
@@ -22,15 +22,25 @@ macro_rules! from_primitive_or_f16_for_non_floating_point {
2222
};
2323
}
2424

25-
from_primitive_or_f16_for_non_floating_point!(usize);
26-
from_primitive_or_f16_for_non_floating_point!(u8);
27-
from_primitive_or_f16_for_non_floating_point!(u16);
28-
from_primitive_or_f16_for_non_floating_point!(u32);
29-
from_primitive_or_f16_for_non_floating_point!(u64);
30-
from_primitive_or_f16_for_non_floating_point!(i8);
31-
from_primitive_or_f16_for_non_floating_point!(i16);
32-
from_primitive_or_f16_for_non_floating_point!(i32);
33-
from_primitive_or_f16_for_non_floating_point!(i64);
25+
macro_rules! from_primitive_or_f16_for_unsigned {
26+
($T:ty) => {
27+
impl FromPrimitiveOrF16 for $T {
28+
fn from_f16(value: f16) -> Option<Self> {
29+
value.to_u64().and_then(|v| FromPrimitive::from_u64(v))
30+
}
31+
}
32+
};
33+
}
34+
35+
from_primitive_or_f16_for_unsigned!(usize);
36+
from_primitive_or_f16_for_unsigned!(u8);
37+
from_primitive_or_f16_for_unsigned!(u16);
38+
from_primitive_or_f16_for_unsigned!(u32);
39+
from_primitive_or_f16_for_unsigned!(u64);
40+
from_primitive_or_f16_for_signed!(i8);
41+
from_primitive_or_f16_for_signed!(i16);
42+
from_primitive_or_f16_for_signed!(i32);
43+
from_primitive_or_f16_for_signed!(i64);
3444

3545
impl FromPrimitiveOrF16 for f16 {
3646
fn from_f16(v: f16) -> Option<Self> {

vortex-scalar/src/primitive.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ impl<'a> PrimitiveScalar<'a> {
115115
T::PTYPE
116116
);
117117

118-
self.pvalue.map(|pv| pv.as_primitive::<T>())
118+
self.pvalue.map(|pv| pv.cast::<T>())
119119
}
120120

121121
pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
@@ -126,7 +126,7 @@ impl<'a> PrimitiveScalar<'a> {
126126
Ok(match_each_native_ptype!(ptype, |Q| {
127127
Scalar::primitive(
128128
pvalue
129-
.as_primitive_opt::<Q>()
129+
.cast_opt::<Q>()
130130
.ok_or_else(|| vortex_err!("Cannot cast {} to {}", self.ptype, dtype))?,
131131
dtype.nullability(),
132132
)

vortex-scalar/src/pvalue.rs

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use core::fmt::Display;
55
use std::cmp::Ordering;
66
use std::hash::{Hash, Hasher};
77

8-
use num_traits::NumCast;
8+
use num_traits::{NumCast, ToPrimitive};
99
use paste::paste;
1010
use vortex_dtype::half::f16;
1111
use vortex_dtype::{NativePType, PType, ToBytes};
@@ -122,19 +122,7 @@ macro_rules! as_primitive {
122122
paste! {
123123
#[doc = "Access PValue as `" $T "`, returning `None` if conversion is unsuccessful"]
124124
pub fn [<as_ $T>](self) -> Option<$T> {
125-
match self {
126-
PValue::U8(v) => <$T as NumCast>::from(v),
127-
PValue::U16(v) => <$T as NumCast>::from(v),
128-
PValue::U32(v) => <$T as NumCast>::from(v),
129-
PValue::U64(v) => <$T as NumCast>::from(v),
130-
PValue::I8(v) => <$T as NumCast>::from(v),
131-
PValue::I16(v) => <$T as NumCast>::from(v),
132-
PValue::I32(v) => <$T as NumCast>::from(v),
133-
PValue::I64(v) => <$T as NumCast>::from(v),
134-
PValue::F16(v) => <$T as NumCast>::from(v),
135-
PValue::F32(v) => <$T as NumCast>::from(v),
136-
PValue::F64(v) => <$T as NumCast>::from(v),
137-
}
125+
<$T>::try_from(self).ok()
138126
}
139127
}
140128
};
@@ -152,7 +140,7 @@ impl PValue {
152140
PType::I16 => PValue::I16(0),
153141
PType::I32 => PValue::I32(0),
154142
PType::I64 => PValue::I64(0),
155-
PType::F16 => PValue::F16(f16::from_f32(0.0)),
143+
PType::F16 => PValue::F16(f16::ZERO),
156144
PType::F32 => PValue::F32(0.0),
157145
PType::F64 => PValue::F64(0.0),
158146
}
@@ -184,15 +172,15 @@ impl PValue {
184172
///
185173
/// Panics if the conversion is not supported or would overflow.
186174
#[inline]
187-
pub fn as_primitive<T: NativePType>(&self) -> T {
188-
self.as_primitive_opt::<T>().vortex_expect("as_primitive")
175+
pub fn cast<T: NativePType>(&self) -> T {
176+
self.cast_opt::<T>().vortex_expect("as_primitive")
189177
}
190178

191179
/// Converts this value to a specific native primitive type.
192180
///
193181
/// Returns `None` if the conversion is not supported or would overflow.
194182
#[inline]
195-
pub fn as_primitive_opt<T: NativePType>(&self) -> Option<T> {
183+
pub fn cast_opt<T: NativePType>(&self) -> Option<T> {
196184
match *self {
197185
PValue::U8(u) => T::from_u8(u),
198186
PValue::U16(u) => T::from_u16(u),
@@ -297,16 +285,17 @@ macro_rules! int_pvalue {
297285

298286
fn try_from(value: PValue) -> Result<Self, Self::Error> {
299287
match value {
300-
PValue::U8(v) => <$T as NumCast>::from(v),
301-
PValue::U16(v) => <$T as NumCast>::from(v),
302-
PValue::U32(v) => <$T as NumCast>::from(v),
303-
PValue::U64(v) => <$T as NumCast>::from(v),
304-
PValue::I8(v) => <$T as NumCast>::from(v),
305-
PValue::I16(v) => <$T as NumCast>::from(v),
306-
PValue::I32(v) => <$T as NumCast>::from(v),
307-
PValue::I64(v) => <$T as NumCast>::from(v),
288+
PValue::U8(_)
289+
| PValue::U16(_)
290+
| PValue::U32(_)
291+
| PValue::U64(_)
292+
| PValue::I8(_)
293+
| PValue::I16(_)
294+
| PValue::I32(_)
295+
| PValue::I64(_) => Some(value),
308296
_ => None,
309297
}
298+
.and_then(|v| PValue::cast_opt(&v))
310299
.ok_or_else(|| {
311300
vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT)
312301
})
@@ -321,32 +310,29 @@ macro_rules! float_pvalue {
321310
type Error = VortexError;
322311

323312
fn try_from(value: PValue) -> Result<Self, Self::Error> {
324-
match value {
325-
PValue::U8(u) => <Self as NumCast>::from(u),
326-
PValue::U16(u) => <Self as NumCast>::from(u),
327-
PValue::U32(u) => <Self as NumCast>::from(u),
328-
PValue::U64(u) => <Self as NumCast>::from(u),
329-
PValue::I8(i) => <Self as NumCast>::from(i),
330-
PValue::I16(i) => <Self as NumCast>::from(i),
331-
PValue::I32(i) => <Self as NumCast>::from(i),
332-
PValue::I64(i) => <Self as NumCast>::from(i),
333-
PValue::F16(f) => <Self as NumCast>::from(f),
334-
PValue::F32(f) => <Self as NumCast>::from(f),
335-
PValue::F64(f) => <Self as NumCast>::from(f),
336-
}
337-
.ok_or_else(|| {
313+
value.cast_opt().ok_or_else(|| {
338314
vortex_err!("Cannot read primitive value {:?} as {}", value, PType::$PT)
339315
})
340316
}
341317
}
342318
};
343319
}
344320

321+
impl TryFrom<PValue> for usize {
322+
type Error = VortexError;
323+
324+
fn try_from(value: PValue) -> Result<Self, Self::Error> {
325+
value
326+
.cast_opt::<u64>()
327+
.and_then(|v| v.to_usize())
328+
.ok_or_else(|| vortex_err!("Cannot read primitive value {:?} as usize", value))
329+
}
330+
}
331+
345332
int_pvalue!(u8, U8);
346333
int_pvalue!(u16, U16);
347334
int_pvalue!(u32, U32);
348335
int_pvalue!(u64, U64);
349-
int_pvalue!(usize, U64);
350336
int_pvalue!(i8, I8);
351337
int_pvalue!(i16, I16);
352338
int_pvalue!(i32, I32);
@@ -532,8 +518,9 @@ mod test {
532518
use std::cmp::Ordering;
533519
use std::collections::HashSet;
534520

521+
use num_traits::FromPrimitive;
535522
use vortex_dtype::half::f16;
536-
use vortex_dtype::{FromPrimitiveOrF16, NativePType, PType, ToBytes};
523+
use vortex_dtype::{FromPrimitiveOrF16, PType, ToBytes};
537524

538525
use crate::PValue;
539526
use crate::pvalue::CoercePValue;
@@ -915,10 +902,9 @@ mod test {
915902

916903
#[test]
917904
fn test_f16_nans_equal() {
918-
let nan = f16::NAN;
919-
let nan2 = f16::from_le_bytes([154, 253]);
920-
assert!(nan2.is_nan());
921-
let nan3 = f16::from_f16(nan2).unwrap();
922-
assert_eq!(nan2.to_bits(), nan3.to_bits(),);
905+
let nan1 = f16::from_le_bytes([154, 253]);
906+
assert!(nan1.is_nan());
907+
let nan3 = f16::from_f16(nan1).unwrap();
908+
assert_eq!(nan1.to_bits(), nan3.to_bits(),);
923909
}
924910
}

0 commit comments

Comments
 (0)