diff --git a/Cargo.lock b/Cargo.lock index 5676cfef0a3..508bc4bb84c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6851,6 +6851,7 @@ dependencies = [ "arrow-ord 55.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "futures-util", "libfuzzer-sys", + "strum 0.25.0", "thiserror 2.0.12", "tokio", "vortex-array", diff --git a/Cargo.toml b/Cargo.toml index 414dd5bb4a5..d01c84bd720 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -159,6 +159,7 @@ similar = "2.7.0" simplelog = "0.12" sketches-ddsketch = "0.3.0" static_assertions = "1.1" +strum = "0.25" tabled = { version = "0.19.0", default-features = false } taffy = "0.8.0" tar = "0.4" diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 1a630c928a5..5d46212b224 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -22,6 +22,7 @@ arrow-buffer = { workspace = true } arrow-ord = { workspace = true } futures-util = { workspace = true } libfuzzer-sys = { workspace = true } +strum = { workspace = true, features = ["derive"] } thiserror = { workspace = true } tokio = { workspace = true, features = ["full"] } vortex-array = { workspace = true, features = ["arbitrary"] } diff --git a/fuzz/fuzz_targets/array_ops.rs b/fuzz/fuzz_targets/array_ops.rs index b89dc964f33..29cf651c7c3 100644 --- a/fuzz/fuzz_targets/array_ops.rs +++ b/fuzz/fuzz_targets/array_ops.rs @@ -6,7 +6,7 @@ use vortex_array::arrays::{ BoolEncoding, ConstantArray, ListEncoding, PrimitiveEncoding, StructEncoding, VarBinEncoding, VarBinViewEncoding, }; -use vortex_array::compute::{compare, filter, take}; +use vortex_array::compute::{cast, compare, filter, take}; use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide}; use vortex_array::{Array, ArrayRef, IntoArray}; use vortex_btrblocks::BtrBlocksCompressor; @@ -74,6 +74,16 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus { } current_array = compare_result; } + Action::Cast(to) => { + let cast_result = cast(¤t_array, &to).vortex_unwrap(); + if let Err(e) = assert_array_eq(&expected.array(), &cast_result, i) { + vortex_panic!( + "Failed to cast {} to dtype {to}\nError: {e}", + current_array.tree_display() + ) + } + current_array = cast_result; + } } } Corpus::Keep diff --git a/fuzz/src/cast.rs b/fuzz/src/cast.rs new file mode 100644 index 00000000000..94d80a4d2e3 --- /dev/null +++ b/fuzz/src/cast.rs @@ -0,0 +1,32 @@ +use vortex_array::arrays::PrimitiveArray; +use vortex_array::validity::Validity; +use vortex_array::{Array, ArrayRef, ToCanonical}; +use vortex_buffer::Buffer; +use vortex_dtype::{DType, match_each_integer_ptype}; +use vortex_error::VortexResult; + +pub fn cast_canonical_array(array: &ArrayRef, target: &DType) -> VortexResult> { + // TODO(joe): support more casting options + if !target.is_int() || !array.dtype().is_int() { + return Ok(None); + } + Ok(Some(match_each_integer_ptype!( + array.dtype().as_ptype(), + |In| { + match_each_integer_ptype!(target.as_ptype(), |Out| { + // Since the cast itself would truncate. + #[allow(clippy::cast_possible_truncation)] + PrimitiveArray::new( + array + .to_primitive()? + .as_slice::() + .iter() + .map(|v| *v as Out) + .collect::>(), + Validity::from_mask(array.validity_mask()?, target.nullability()), + ) + .to_array() + }) + } + ))) +} diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs index 99c8c71423b..7120d2eaadd 100644 --- a/fuzz/src/lib.rs +++ b/fuzz/src/lib.rs @@ -1,5 +1,6 @@ #![feature(error_generic_member_access)] +mod cast; mod compare; pub mod error; mod filter; @@ -10,24 +11,27 @@ mod take; use std::fmt::Debug; use std::iter; -use std::ops::{Range, RangeInclusive}; +use std::ops::Range; use libfuzzer_sys::arbitrary::Error::EmptyChoose; use libfuzzer_sys::arbitrary::{Arbitrary, Result, Unstructured}; pub use sort::sort_canonical_array; +use strum::EnumCount; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::arbitrary::ArbitraryArray; -use vortex_array::compute::Operator; +use vortex_array::compute::{CastOutcome, Operator, allowed_casting}; use vortex_array::search_sorted::{SearchResult, SearchSortedSide}; use vortex_array::{Array, ArrayRef, IntoArray}; use vortex_btrblocks::BtrBlocksCompressor; use vortex_dtype::{DType, Nullability}; -use vortex_error::{VortexUnwrap, vortex_panic}; +use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic}; use vortex_mask::Mask; use vortex_scalar::Scalar; use vortex_scalar::arbitrary::random_scalar; use vortex_utils::aliases::hash_set::HashSet; +use crate::Action::Cast; +use crate::cast::cast_canonical_array; use crate::compare::compare_canonical_array; use crate::filter::filter_canonical_array; use crate::search_sorted::search_sorted_canonical_array; @@ -62,7 +66,7 @@ pub struct FuzzArrayAction { pub actions: Vec<(Action, ExpectedValue)>, } -#[derive(Debug)] +#[derive(Debug, EnumCount)] pub enum Action { Compress, Slice(Range), @@ -70,6 +74,7 @@ pub enum Action { SearchSorted(Scalar, SearchSortedSide), Filter(Mask), Compare(Scalar, Operator), + Cast(DType), } impl<'a> Arbitrary<'a> for FuzzArrayAction { @@ -188,7 +193,21 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { ExpectedValue::Array(current_array.to_array()), ) } - _ => unreachable!(), + 6 => { + let to: DType = u.arbitrary()?; + if Some(CastOutcome::Infallible) == allowed_casting(current_array.dtype(), &to) + { + return Err(EmptyChoose); + } + let Some(result) = cast_canonical_array(¤t_array, &to) + .vortex_expect("should fail to create array") + else { + return Err(EmptyChoose); + }; + + (Cast(to), ExpectedValue::Array(result)) + } + 7.. => unreachable!(), }) } @@ -217,15 +236,15 @@ fn random_value_from_list(u: &mut Unstructured<'_>, vec: &[usize]) -> Result = 0..=5; +const ALL_ACTIONS: Range = 0..Action::COUNT; fn actions_for_dtype(dtype: &DType) -> HashSet { match dtype { - // All but compare DType::Struct(sdt, _) => sdt .fields() .map(|child| actions_for_dtype(&child)) - .fold((0..=4).collect(), |acc, actions| { + // exclude compare + .fold((0..=4).chain(iter::once(6)).collect(), |acc, actions| { acc.intersection(&actions).copied().collect() }), // Once we support more list operations also recurse here on child dtype diff --git a/vortex-array/src/compute/cast.rs b/vortex-array/src/compute/cast.rs index a1f35ffd3a2..c6fea5d036c 100644 --- a/vortex-array/src/compute/cast.rs +++ b/vortex-array/src/compute/cast.rs @@ -1,7 +1,8 @@ use std::sync::LazyLock; use arcref::ArcRef; -use vortex_dtype::DType; +use vortex_dtype::Nullability::Nullable; +use vortex_dtype::{DType, PType}; use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err}; use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output}; @@ -137,3 +138,67 @@ impl Kernel for CastKernelAdapter { Ok(Some(V::cast(&self.0, array, dtype)?.into())) } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CastOutcome { + Fallible, + Infallible, +} + +pub fn allowed_casting(from: &DType, to: &DType) -> Option { + // Can cast to include nullability + if &from.with_nullability(Nullable) == to { + return Some(CastOutcome::Infallible); + } + match (from, to) { + (DType::Primitive(from_ptype, _), DType::Primitive(to_ptype, _)) => { + allowed_casting_ptype(*from_ptype, *to_ptype) + } + _ => None, + } +} + +pub fn allowed_casting_ptype(from: PType, to: PType) -> Option { + use CastOutcome::*; + use PType::*; + + match (from, to) { + // Identity casts + (a, b) if a == b => Some(Infallible), + + // Integer widening (always infallible) + (U8, U16 | U32 | U64) + | (U16, U32 | U64) + | (U32, U64) + | (I8, I16 | I32 | I64) + | (I16, I32 | I64) + | (I32, I64) => Some(Infallible), + + // Integer narrowing (may truncate) + (U16 | U32 | U64, U8) + | (U32 | U64, U16) + | (U64, U32) + | (I16 | I32 | I64, I8) + | (I32 | I64, I16) + | (I64, I32) => Some(Fallible), + + // Between signed and unsigned (fallible if negative or too big) + (I8 | I16 | I32 | I64, U8 | U16 | U32 | U64) + | (U8 | U16 | U32 | U64, I8 | I16 | I32 | I64) => Some(Fallible), + + // TODO(joe): shall we allow float/int casting? + // Integer -> Float + // (U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64, F16 | F32 | F64) => Some(Fallible), + + // Float -> Integer (truncates, overflows possible) + // (F16 | F32 | F64, U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64) => Some(Fallible), + + // Float widening (safe) + (F16, F32 | F64) | (F32, F64) => Some(Infallible), + + // Float narrowing (lossy) + (F64, F32 | F16) | (F32, F16) => Some(Fallible), + + _ => None, + } +}