Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion fuzz/fuzz_targets/array_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex_array::arrays::{
BoolEncoding, ConstantArray, ListEncoding, PrimitiveEncoding, StructEncoding, VarBinEncoding,
VarBinViewEncoding,
};
use vortex_array::compute::{compare, filter, take};
use vortex_array::compute::{cast, compare, filter, take};
use vortex_array::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
use vortex_array::{Array, ArrayRef, IntoArray};
use vortex_btrblocks::BtrBlocksCompressor;
Expand Down Expand Up @@ -74,6 +74,16 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
}
current_array = compare_result;
}
Action::Cast(to) => {
let cast_result = cast(&current_array, &to).vortex_unwrap();
if let Err(e) = assert_array_eq(&expected.array(), &cast_result, i) {
vortex_panic!(
"Failed to cast {} to dtype {to}\nError: {e}",
current_array.tree_display()
)
}
current_array = cast_result;
}
}
}
Corpus::Keep
Expand Down
32 changes: 32 additions & 0 deletions fuzz/src/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use vortex_array::arrays::PrimitiveArray;
use vortex_array::validity::Validity;
use vortex_array::{Array, ArrayRef, ToCanonical};
use vortex_buffer::Buffer;
use vortex_dtype::{DType, match_each_integer_ptype};
use vortex_error::VortexResult;

/// Computes the reference result of casting `array` to `target`, for comparison against the
/// real cast kernel.
///
/// Only integer-to-integer casts are handled. If either the source dtype or `target` is not an
/// integer dtype, `Ok(None)` is returned so the caller can skip the action.
///
/// Each element is converted with a plain `as` cast, and the source validity mask is carried
/// over using the nullability of `target`.
///
/// # Errors
///
/// Propagates any error from canonicalizing `array` to a primitive array or from reading its
/// validity mask.
pub fn cast_canonical_array(array: &ArrayRef, target: &DType) -> VortexResult<Option<ArrayRef>> {
    // TODO(joe): support more casting options
    let int_to_int = target.is_int() && array.dtype().is_int();
    if !int_to_int {
        return Ok(None);
    }

    let expected = match_each_integer_ptype!(array.dtype().as_ptype(), |In| {
        match_each_integer_ptype!(target.as_ptype(), |Out| {
            let source = array.to_primitive()?;

            // The `as` conversion is intentional: the cast itself would truncate.
            #[allow(clippy::cast_possible_truncation)]
            let values: Buffer<Out> = source
                .as_slice::<In>()
                .iter()
                .map(|&value| value as Out)
                .collect();

            let validity = Validity::from_mask(array.validity_mask()?, target.nullability());
            PrimitiveArray::new(values, validity).to_array()
        })
    });

    Ok(Some(expected))
}
24 changes: 21 additions & 3 deletions fuzz/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(error_generic_member_access)]

mod cast;
mod compare;
pub mod error;
mod filter;
Expand All @@ -16,18 +17,20 @@ use libfuzzer_sys::arbitrary::Error::EmptyChoose;
use libfuzzer_sys::arbitrary::{Arbitrary, Result, Unstructured};
pub use sort::sort_canonical_array;
use vortex_array::arrays::arbitrary::ArbitraryArray;
use vortex_array::compute::Operator;
use vortex_array::compute::{CastOutcome, Operator, allowed_casting};
use vortex_array::search_sorted::{SearchResult, SearchSortedSide};
use vortex_array::{Array, ArrayRef, IntoArray};
use vortex_btrblocks::BtrBlocksCompressor;
use vortex_buffer::Buffer;
use vortex_dtype::DType;
use vortex_error::{VortexUnwrap, vortex_panic};
use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
use vortex_mask::Mask;
use vortex_scalar::Scalar;
use vortex_scalar::arbitrary::random_scalar;
use vortex_utils::aliases::hash_set::HashSet;

use crate::Action::Cast;
use crate::cast::cast_canonical_array;
use crate::compare::compare_canonical_array;
use crate::filter::filter_canonical_array;
use crate::search_sorted::search_sorted_canonical_array;
Expand Down Expand Up @@ -70,6 +73,7 @@ pub enum Action {
SearchSorted(Scalar, SearchSortedSide),
Filter(Mask),
Compare(Scalar, Operator),
Cast(DType),
}

impl<'a> Arbitrary<'a> for FuzzArrayAction {
Expand Down Expand Up @@ -186,7 +190,21 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
ExpectedValue::Array(current_array.to_array()),
)
}
_ => unreachable!(),
6 => {
let to: DType = u.arbitrary()?;
if Some(CastOutcome::Infallible) == allowed_casting(current_array.dtype(), &to)
{
return Err(EmptyChoose);
}
let Some(result) = cast_canonical_array(&current_array, &to)
.vortex_expect("should fail to create array")
else {
return Err(EmptyChoose);
};

(Cast(to), ExpectedValue::Array(result))
}
7.. => unreachable!(),
})
}

Expand Down
67 changes: 66 additions & 1 deletion vortex-array/src/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::sync::LazyLock;

use arcref::ArcRef;
use vortex_dtype::DType;
use vortex_dtype::Nullability::Nullable;
use vortex_dtype::{DType, PType};
use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};

use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
Expand Down Expand Up @@ -137,3 +138,67 @@ impl<V: VTable + CastKernel> Kernel for CastKernelAdapter<V> {
Ok(Some(V::cast(&self.0, array, dtype)?.into()))
}
}

/// Classification of a permitted cast, as reported by [`allowed_casting`] and
/// [`allowed_casting_ptype`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CastOutcome {
    /// The cast is permitted but is not guaranteed to succeed for every value, e.g. integer
    /// narrowing, signed/unsigned conversion, or float narrowing.
    Fallible,
    /// The cast is permitted and is expected to succeed for every value, e.g. an identity cast
    /// or a widening cast.
    Infallible,
}

/// Reports whether a cast from `from` to `to` is permitted and, if so, whether it can fail.
///
/// Returns:
/// * `Some(CastOutcome::Infallible)` when `to` is identical to `from`, or differs from it only
///   by being nullable.
/// * For primitive dtypes, the outcome of [`allowed_casting_ptype`] for the two ptypes,
///   downgraded to `Fallible` when the cast drops nullability (nullable -> non-nullable), since
///   such a cast cannot succeed for an array that actually contains nulls.
/// * `None` for every other combination.
pub fn allowed_casting(from: &DType, to: &DType) -> Option<CastOutcome> {
    // An identity cast, or one that only adds nullability, can never fail.
    if from == to || &from.with_nullability(Nullable) == to {
        return Some(CastOutcome::Infallible);
    }

    match (from, to) {
        (
            DType::Primitive(from_ptype, from_nullability),
            DType::Primitive(to_ptype, to_nullability),
        ) => {
            // Going from nullable to non-nullable depends on the data (any null makes it fail),
            // so even a widening ptype cast must not be reported as infallible.
            let drops_nullability =
                matches!(from_nullability, Nullable) && !matches!(to_nullability, Nullable);

            allowed_casting_ptype(*from_ptype, *to_ptype).map(|outcome| {
                if drops_nullability {
                    CastOutcome::Fallible
                } else {
                    outcome
                }
            })
        }
        _ => None,
    }
}

pub fn allowed_casting_ptype(from: PType, to: PType) -> Option<CastOutcome> {
use CastOutcome::*;
use PType::*;

match (from, to) {
// Identity casts
(a, b) if a == b => Some(Infallible),

// Integer widening (always infallible)
(U8, U16 | U32 | U64)
| (U16, U32 | U64)
| (U32, U64)
| (I8, I16 | I32 | I64)
| (I16, I32 | I64)
| (I32, I64) => Some(Infallible),

// Integer narrowing (may truncate)
(U16 | U32 | U64, U8)
| (U32 | U64, U16)
| (U64, U32)
| (I16 | I32 | I64, I8)
| (I32 | I64, I16)
| (I64, I32) => Some(Fallible),

// Between signed and unsigned (fallible if negative or too big)
(I8 | I16 | I32 | I64, U8 | U16 | U32 | U64)
| (U8 | U16 | U32 | U64, I8 | I16 | I32 | I64) => Some(Fallible),

// TODO(joe): shall we allow float/int casting?
// Integer -> Float
// (U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64, F16 | F32 | F64) => Some(Fallible),

// Float -> Integer (truncates, overflows possible)
// (F16 | F32 | F64, U8 | U16 | U32 | U64 | I8 | I16 | I32 | I64) => Some(Fallible),

// Float widening (safe)
(F16, F32 | F64) | (F32, F64) => Some(Infallible),

// Float narrowing (lossy)
(F64, F32 | F16) | (F32, F16) => Some(Fallible),

_ => None,
}
}
Loading