Skip to content

Commit f93fe4b

Browse files
committed
fix: add FromPrimitiveOrF16 to NativePType replacing FromPrimitive
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 2492c9b commit f93fe4b

File tree

8 files changed

+77
-68
lines changed

8 files changed

+77
-68
lines changed

encodings/fastlanes/src/for/compress.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ use vortex_dtype::{
1414
match_each_unsigned_integer_ptype,
1515
};
1616
use vortex_error::{VortexExpect, VortexResult, vortex_err};
17-
use vortex_scalar::FromPrimitiveOrF16;
1817

1918
use crate::unpack_iter::{UnpackStrategy, UnpackedChunks};
2019
use crate::{BitPackedArray, BitPackedVTable, FoRArray, bitpack_compress};
@@ -102,9 +101,7 @@ pub fn decompress(array: &FoRArray) -> PrimitiveArray {
102101
})
103102
}
104103

105-
fn fused_decompress<
106-
T: PhysicalPType<Physical = T> + UnsignedPType + FoR + FromPrimitiveOrF16 + WrappingAdd,
107-
>(
104+
fn fused_decompress<T: PhysicalPType<Physical = T> + UnsignedPType + FoR + WrappingAdd>(
108105
for_: &FoRArray,
109106
bp: &BitPackedArray,
110107
) -> PrimitiveArray {

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use num_traits::PrimInt;
55
use vortex_dtype::Nullability::Nullable;
66
use vortex_dtype::{DType, DecimalDType, NativePType, i256, match_each_native_ptype};
77
use vortex_error::{VortexResult, vortex_bail, vortex_err};
8-
use vortex_scalar::{DecimalScalar, DecimalValue, FromPrimitiveOrF16, Scalar};
8+
use vortex_scalar::{DecimalScalar, DecimalValue, Scalar};
99

1010
use crate::arrays::{ChunkedArray, ChunkedVTable};
1111
use crate::compute::{SumKernel, SumKernelAdapter, sum};
@@ -39,9 +39,7 @@ impl SumKernel for ChunkedVTable {
3939

4040
register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
4141

42-
fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
43-
chunks: &[ArrayRef],
44-
) -> VortexResult<Option<T>> {
42+
fn sum_int<T: NativePType + PrimInt>(chunks: &[ArrayRef]) -> VortexResult<Option<T>> {
4543
let mut result: T = T::zero();
4644
for chunk in chunks {
4745
let chunk_sum = sum(chunk)?;

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
use num_traits::{CheckedMul, ToPrimitive};
55
use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, i256, match_each_native_ptype};
66
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
7-
use vortex_scalar::{
8-
DecimalScalar, DecimalValue, FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue,
9-
};
7+
use vortex_scalar::{DecimalScalar, DecimalValue, PrimitiveScalar, Scalar, ScalarValue};
108

119
use crate::arrays::{ConstantArray, ConstantVTable};
1210
use crate::compute::{SumKernel, SumKernelAdapter};
@@ -84,7 +82,7 @@ fn sum_integral<T>(
8482
array_len: usize,
8583
) -> VortexResult<Option<T>>
8684
where
87-
T: FromPrimitiveOrF16 + NativePType + CheckedMul,
85+
T: NativePType + CheckedMul,
8886
Scalar: From<Option<T>>,
8987
{
9088
let v = primitive_scalar.as_::<T>();

vortex-dtype/src/f16.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use half::f16;
5+
use num_traits::FromPrimitive;
6+
7+
/// A trait for types that can be created from primitive values, including f16.
8+
///
9+
/// This extends the `FromPrimitive` trait to also support conversion from f16 values.
10+
pub trait FromPrimitiveOrF16: FromPrimitive {
11+
/// Converts an f16 value to this type, returning None if the conversion fails.
12+
fn from_f16(v: f16) -> Option<Self>;
13+
}
14+
15+
macro_rules! from_primitive_or_f16_for_non_floating_point {
16+
($T:ty) => {
17+
impl FromPrimitiveOrF16 for $T {
18+
fn from_f16(_: f16) -> Option<Self> {
19+
None
20+
}
21+
}
22+
};
23+
}
24+
25+
from_primitive_or_f16_for_non_floating_point!(usize);
26+
from_primitive_or_f16_for_non_floating_point!(u8);
27+
from_primitive_or_f16_for_non_floating_point!(u16);
28+
from_primitive_or_f16_for_non_floating_point!(u32);
29+
from_primitive_or_f16_for_non_floating_point!(u64);
30+
from_primitive_or_f16_for_non_floating_point!(i8);
31+
from_primitive_or_f16_for_non_floating_point!(i16);
32+
from_primitive_or_f16_for_non_floating_point!(i32);
33+
from_primitive_or_f16_for_non_floating_point!(i64);
34+
35+
impl FromPrimitiveOrF16 for f16 {
36+
fn from_f16(v: f16) -> Option<Self> {
37+
Some(v)
38+
}
39+
}
40+
41+
impl FromPrimitiveOrF16 for f32 {
42+
fn from_f16(v: f16) -> Option<Self> {
43+
Some(v.to_f32())
44+
}
45+
}
46+
47+
impl FromPrimitiveOrF16 for f64 {
48+
fn from_f16(v: f16) -> Option<Self> {
49+
Some(v.to_f64())
50+
}
51+
}

vortex-dtype/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub mod datetime;
1818
mod decimal;
1919
mod dtype;
2020
mod extension;
21+
mod f16;
2122
mod field;
2223
mod field_mask;
2324
mod field_names;
@@ -31,6 +32,7 @@ pub use bigint::*;
3132
pub use decimal::*;
3233
pub use dtype::{DType, NativeDType};
3334
pub use extension::*;
35+
pub use f16::*;
3436
pub use field::*;
3537
pub use field_mask::*;
3638
pub use field_names::*;

vortex-dtype/src/ptype.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,12 @@ use std::ops::AddAssign;
1010
use std::panic::RefUnwindSafe;
1111

1212
use num_traits::bounds::UpperBounded;
13-
use num_traits::{
14-
AsPrimitive, Bounded, FromPrimitive, Num, NumCast, PrimInt, ToPrimitive, Unsigned,
15-
};
13+
use num_traits::{AsPrimitive, Bounded, Num, NumCast, PrimInt, ToPrimitive, Unsigned};
1614
use vortex_error::{VortexError, VortexResult, vortex_err};
1715

18-
use crate::DType;
1916
use crate::half::f16;
2017
use crate::nullability::Nullability::NonNullable;
18+
use crate::{DType, FromPrimitiveOrF16};
2119

2220
/// Physical type enum, represents the in-memory physical layout but might represent a different logical type.
2321
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Hash, prost::Enumeration)]
@@ -96,7 +94,7 @@ pub trait NativePType:
9694
+ RefUnwindSafe
9795
+ Num
9896
+ NumCast
99-
+ FromPrimitive
97+
+ FromPrimitiveOrF16
10098
+ ToBytes
10199
+ TryFromBytes
102100
+ private::Sealed

vortex-scalar/src/primitive.rs

Lines changed: 5 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ use std::cmp::Ordering;
66
use std::fmt::{Debug, Display, Formatter};
77
use std::ops::{Add, Sub};
88

9-
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive};
9+
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub};
1010
use vortex_dtype::half::f16;
11-
use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
11+
use vortex_dtype::{
12+
DType, FromPrimitiveOrF16, NativePType, Nullability, PType, match_each_native_ptype,
13+
};
1214
use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_err, vortex_panic};
1315

1416
use crate::pvalue::{CoercePValue, PValue};
@@ -104,7 +106,7 @@ impl<'a> PrimitiveScalar<'a> {
104106
/// # Panics
105107
///
106108
/// Panics if the primitive type of this scalar does not match the requested type.
107-
pub fn typed_value<T: NativePType + TryFrom<PValue, Error = VortexError>>(&self) -> Option<T> {
109+
pub fn typed_value<T: NativePType>(&self) -> Option<T> {
108110
assert_eq!(
109111
self.ptype,
110112
T::PTYPE,
@@ -215,52 +217,6 @@ impl<'a> PrimitiveScalar<'a> {
215217
}
216218
}
217219

218-
/// A trait for types that can be created from primitive values, including f16.
219-
///
220-
/// This extends the `FromPrimitive` trait to also support conversion from f16 values.
221-
pub trait FromPrimitiveOrF16: FromPrimitive {
222-
/// Converts an f16 value to this type, returning None if the conversion fails.
223-
fn from_f16(v: f16) -> Option<Self>;
224-
}
225-
226-
macro_rules! from_primitive_or_f16_for_non_floating_point {
227-
($T:ty) => {
228-
impl FromPrimitiveOrF16 for $T {
229-
fn from_f16(_: f16) -> Option<Self> {
230-
None
231-
}
232-
}
233-
};
234-
}
235-
236-
from_primitive_or_f16_for_non_floating_point!(usize);
237-
from_primitive_or_f16_for_non_floating_point!(u8);
238-
from_primitive_or_f16_for_non_floating_point!(u16);
239-
from_primitive_or_f16_for_non_floating_point!(u32);
240-
from_primitive_or_f16_for_non_floating_point!(u64);
241-
from_primitive_or_f16_for_non_floating_point!(i8);
242-
from_primitive_or_f16_for_non_floating_point!(i16);
243-
from_primitive_or_f16_for_non_floating_point!(i32);
244-
from_primitive_or_f16_for_non_floating_point!(i64);
245-
246-
impl FromPrimitiveOrF16 for f16 {
247-
fn from_f16(v: f16) -> Option<Self> {
248-
Some(v)
249-
}
250-
}
251-
252-
impl FromPrimitiveOrF16 for f32 {
253-
fn from_f16(v: f16) -> Option<Self> {
254-
Some(v.to_f32())
255-
}
256-
}
257-
258-
impl FromPrimitiveOrF16 for f64 {
259-
fn from_f16(v: f16) -> Option<Self> {
260-
Some(v.to_f64())
261-
}
262-
}
263-
264220
impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> {
265221
type Error = VortexError;
266222

vortex-scalar/src/pvalue.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ impl PValue {
202202
PValue::I16(i) => T::from_i16(i),
203203
PValue::I32(i) => T::from_i32(i),
204204
PValue::I64(i) => T::from_i64(i),
205-
PValue::F16(f) => <T as NumCast>::from(f),
205+
PValue::F16(f) => T::from_f16(f),
206206
PValue::F32(f) => T::from_f32(f),
207207
PValue::F64(f) => T::from_f64(f),
208208
}
@@ -533,7 +533,7 @@ mod test {
533533
use std::collections::HashSet;
534534

535535
use vortex_dtype::half::f16;
536-
use vortex_dtype::{PType, ToBytes};
536+
use vortex_dtype::{FromPrimitiveOrF16, NativePType, PType, ToBytes};
537537

538538
use crate::PValue;
539539
use crate::pvalue::CoercePValue;
@@ -912,4 +912,13 @@ mod test {
912912
1.0f64 // 0x3ff0000000000000 is 1.0 in f64
913913
);
914914
}
915+
916+
#[test]
917+
fn test_f16_nans_equal() {
918+
let nan = f16::NAN;
919+
let nan2 = f16::from_le_bytes([154, 253]);
920+
assert!(nan2.is_nan());
921+
let nan3 = f16::from_f16(nan2).unwrap();
922+
assert_eq!(nan2.to_bits(), nan3.to_bits(),);
923+
}
915924
}

0 commit comments

Comments
 (0)