Skip to content

Commit 5b99fa6

Browse files
committed
feat: added validator functions for extension arrays
Signed-off-by: Pratham Agarwal <[email protected]>
1 parent 0e7ab69 commit 5b99fa6

File tree

6 files changed

+255
-31
lines changed

6 files changed

+255
-31
lines changed

fuzz/src/array/mod.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,9 @@ pub fn assert_array_eq(
662662
rhs: &ArrayRef,
663663
step: usize,
664664
) -> crate::error::VortexFuzzResult<()> {
665+
use vortex_array::ToCanonical;
666+
use vortex_array::arrays::validator_for_ext_type;
667+
665668
use crate::error::Backtrace;
666669
use crate::error::VortexFuzzError;
667670

@@ -699,7 +702,22 @@ pub fn assert_array_eq(
699702
Backtrace::capture(),
700703
));
701704
}
705+
706+
// Also validate the expected array's domain constraints for extension types
707+
if matches!(lhs.dtype(), DType::Extension(..)) {
708+
let validator = validator_for_ext_type(lhs.to_extension().ext_dtype());
709+
if !validator(&l) {
710+
return Err(VortexFuzzError::DomainValidationFailed(
711+
l,
712+
idx,
713+
lhs.clone(),
714+
step,
715+
Backtrace::capture(),
716+
));
717+
}
718+
}
702719
}
720+
703721
Ok(())
704722
}
705723

fuzz/src/error.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ pub enum VortexFuzzError {
5959
LengthMismatch(usize, usize, ArrayRef, ArrayRef, usize, Backtrace),
6060

6161
VortexError(VortexError, Backtrace),
62+
63+
DomainValidationFailed(Scalar, usize, ArrayRef, usize, Backtrace),
6264
}
6365

6466
impl Debug for VortexFuzzError {
@@ -121,6 +123,13 @@ impl Display for VortexFuzzError {
121123
rhs.display_tree(),
122124
)
123125
}
126+
VortexFuzzError::DomainValidationFailed(expected, idx, lhs, step, backtrace) => {
127+
write!(
128+
f,
129+
"Domain validation failed:\n Scalar: {expected}\n Index: {idx}\n Array: {}\n Step: {step}\nBacktrace:\n{backtrace}",
130+
lhs.display_tree(),
131+
)
132+
}
124133
VortexFuzzError::VortexError(err, backtrace) => {
125134
write!(f, "{err}\nBacktrace:\n{backtrace}")
126135
}
@@ -137,6 +146,7 @@ impl Error for VortexFuzzError {
137146
| VortexFuzzError::LengthMismatch(..)
138147
| VortexFuzzError::ScalarMismatch(..)
139148
| VortexFuzzError::MinMaxMismatch(..)
149+
| VortexFuzzError::DomainValidationFailed(..)
140150
| VortexFuzzError::DTypeMismatch(..) => None,
141151
}
142152
}

vortex-array/src/arrays/arbitrary.rs

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use crate::IntoArray;
3131
use crate::ToCanonical;
3232
use crate::arrays::VarBinArray;
3333
use crate::arrays::VarBinViewArray;
34+
use crate::arrays::validator_for_ext_type;
3435
use crate::builders::ArrayBuilder;
3536
use crate::builders::DecimalBuilder;
3637
use crate::builders::FixedSizeListBuilder;
@@ -170,7 +171,6 @@ fn random_array_chunk(
170171
}
171172

172173
/// Creates a random extension array.
173-
///
174174
/// If the `chunk_len` is specified, the length of the array will be equal to the chunk length.
175175
fn random_extension(
176176
u: &mut Unstructured,
@@ -179,24 +179,34 @@ fn random_extension(
179179
) -> Result<ArrayRef> {
180180
use crate::builders::ExtensionBuilder;
181181

182+
// Get the validator for this extension type
183+
let validator = validator_for_ext_type(ext_dtype);
184+
182185
// Determine array length
183186
let array_length = chunk_len.unwrap_or(u.int_in_range(0..=20)?);
184187

185188
// Create builder for the extension array
186189
let mut builder = ExtensionBuilder::with_capacity(ext_dtype.clone(), array_length);
187190

188-
// Generate random values
191+
// Generate random values, retrying if they don't pass validation
189192
for _ in 0..array_length {
190-
// Wrap in extension scalar using Scalar::extension()
191-
let ext_scalar = Scalar::extension(
192-
ext_dtype.clone(),
193-
random_scalar(u, ext_dtype.storage_dtype())?,
194-
);
195-
196-
// Append to builder
197-
builder
198-
.append_scalar(&ext_scalar)
199-
.vortex_expect("can append extension scalar");
193+
// Retry loop to generate valid values
194+
loop {
195+
// Generate a random storage scalar
196+
let storage_scalar = random_scalar(u, ext_dtype.storage_dtype())?;
197+
198+
// Wrap it in an extension scalar
199+
let ext_scalar = Scalar::extension(ext_dtype.clone(), storage_scalar);
200+
201+
// Validate the scalar
202+
if validator(&ext_scalar) {
203+
// Valid value - append and break
204+
builder
205+
.append_scalar(&ext_scalar)
206+
.vortex_expect("can append extension scalar");
207+
break;
208+
}
209+
}
200210
}
201211

202212
Ok(builder.finish())

vortex-array/src/arrays/extension/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,6 @@ mod compute;
88

99
mod vtable;
1010
pub use vtable::ExtensionVTable;
11+
12+
mod validate;
13+
pub use validate::validator_for_ext_type;
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Domain validation for extension types.
5+
//!
6+
//! This module provides validators to ensure that extension array values
7+
//! are always within the valid domain for their type. For example, temporal
8+
//! arrays should not overflow when converted to Jiff types.
9+
10+
use vortex_dtype::ExtDType;
11+
use vortex_dtype::PType;
12+
use vortex_dtype::datetime::TemporalMetadata;
13+
use vortex_dtype::datetime::is_temporal_ext_type;
14+
use vortex_error::VortexExpect;
15+
use vortex_scalar::Scalar;
16+
17+
/// Type alias for a domain validator function.
18+
///
19+
/// A domain validator checks whether a scalar value is valid for a given extension type.
20+
/// For example, temporal extension types validate that values don't overflow when converted to Jiff types.
21+
pub type DomainValidator = Box<dyn Fn(&Scalar) -> bool + Send + Sync>;
22+
23+
/// Creates a domain validator for the given extension type.
24+
///
25+
/// This function returns a validator that checks if scalar values are in the valid domain
26+
/// for the extension type. For temporal types (date, time, timestamp), it validates that
27+
/// the values can be successfully converted to Jiff types without overflow.
28+
///
29+
/// # Examples
30+
///
31+
/// ```
32+
/// use std::sync::Arc;
33+
/// use vortex_array::arrays::extension::validator_for_ext_type;
34+
/// use vortex_dtype::{ExtDType, ExtMetadata, DType, PType, Nullability};
35+
/// use vortex_dtype::datetime::{TemporalMetadata, TimeUnit, DATE_ID};
36+
/// use vortex_scalar::Scalar;
37+
///
38+
/// let metadata: ExtMetadata = TemporalMetadata::Date(TimeUnit::Days).into();
39+
/// let ext_dtype = ExtDType::new(
40+
/// DATE_ID.clone(),
41+
/// Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
42+
/// Some(metadata),
43+
/// );
44+
///
45+
/// let validator = validator_for_ext_type(&ext_dtype);
46+
///
47+
/// // Valid date value
48+
/// let valid_scalar = Scalar::extension(
49+
/// Arc::new(ext_dtype.clone()),
50+
/// Scalar::primitive(18000i32, Nullability::NonNullable),
51+
/// );
52+
/// assert!(validator(&valid_scalar));
53+
///
54+
/// // Null is always valid
55+
/// let null_scalar = Scalar::null(DType::Extension(Arc::new(ext_dtype)));
56+
/// assert!(validator(&null_scalar));
57+
/// ```
58+
pub fn validator_for_ext_type(ext_dtype: &ExtDType) -> DomainValidator {
59+
if is_temporal_ext_type(ext_dtype.id()) {
60+
// For temporal types, validate that the value can be converted to Jiff
61+
let metadata = TemporalMetadata::try_from(ext_dtype)
62+
.vortex_expect("temporal ext_dtype should have valid metadata");
63+
64+
Box::new(move |scalar: &Scalar| {
65+
if scalar.is_null() {
66+
return true;
67+
}
68+
69+
// Extract the storage value and validate it can be converted to Jiff
70+
let ext_scalar = scalar.as_extension();
71+
let storage = ext_scalar.storage();
72+
let primitive = storage.as_primitive();
73+
74+
// Get the i64 value from the primitive (temporal types use i32 or i64)
75+
let value = match primitive.ptype() {
76+
PType::I32 => primitive.typed_value::<i32>().map(|v| v as i64),
77+
PType::I64 => primitive.typed_value::<i64>(),
78+
_ => None,
79+
};
80+
81+
value.map(|v| metadata.to_jiff(v).is_ok()).unwrap_or(false)
82+
})
83+
} else {
84+
// Unknown extension type - accept all values
85+
Box::new(|_| true)
86+
}
87+
}
88+
89+
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::DType;
    use vortex_dtype::ExtDType;
    use vortex_dtype::ExtMetadata;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::datetime::DATE_ID;
    use vortex_dtype::datetime::TemporalMetadata;
    use vortex_dtype::datetime::TimeUnit;
    use vortex_scalar::Scalar;

    use super::*;

    /// Builds a date extension dtype (day resolution, `i32` storage) with the given nullability.
    fn date_days_ext_dtype(nullability: Nullability) -> ExtDType {
        let metadata: ExtMetadata = TemporalMetadata::Date(TimeUnit::Days).into();
        ExtDType::new(
            DATE_ID.clone(),
            Arc::new(DType::Primitive(PType::I32, nullability)),
            Some(metadata),
        )
    }

    #[test]
    fn test_temporal_validator_accepts_valid_values() {
        let ext_dtype = date_days_ext_dtype(Nullability::NonNullable);
        let validator = validator_for_ext_type(&ext_dtype);

        // 18000 days since the epoch is a valid date.
        let storage = Scalar::primitive(18000i32, Nullability::NonNullable);
        let valid_scalar = Scalar::extension(Arc::new(ext_dtype.clone()), storage);

        assert!(validator(&valid_scalar));
    }

    #[test]
    fn test_temporal_validator_accepts_null() {
        let ext_dtype = date_days_ext_dtype(Nullability::Nullable);
        let validator = validator_for_ext_type(&ext_dtype);

        let null_scalar = Scalar::null(DType::Extension(Arc::new(ext_dtype)));

        assert!(validator(&null_scalar));
    }
}

vortex-dtype/src/arbitrary/mod.rs

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@ use arbitrary::Unstructured;
1010
use crate::DType;
1111
use crate::DecimalDType;
1212
use crate::ExtDType;
13-
use crate::ExtID;
1413
use crate::FieldName;
1514
use crate::FieldNames;
1615
use crate::NativeDecimalType;
1716
use crate::Nullability;
1817
use crate::PType;
1918
use crate::StructFields;
19+
use crate::datetime::DATE_ID;
20+
use crate::datetime::TIME_ID;
21+
use crate::datetime::TIMESTAMP_ID;
22+
use crate::datetime::TemporalMetadata;
23+
use crate::datetime::TimeUnit;
2024
use crate::i256;
2125
mod decimal;
2226

@@ -112,24 +116,65 @@ impl<'a> Arbitrary<'a> for ExtDType {
112116
}
113117

114118
fn random_ext_dtype(u: &mut Unstructured<'_>, _depth: u8) -> Result<ExtDType> {
115-
let id_str = u.choose(&[
116-
"test.ext",
117-
"example.currency",
118-
"example.uuid",
119-
"example.ipv4",
120-
])?;
121-
let ext_id = ExtID::from(*id_str);
122-
123-
// Supports only base types for now
124-
let storage_dtype = match u.int_in_range(1..=3)? {
125-
// base types
126-
1 => DType::Bool(u.arbitrary()?),
127-
2 => DType::Primitive(u.arbitrary()?, u.arbitrary()?),
128-
3 => DType::Decimal(u.arbitrary()?, u.arbitrary()?),
129-
_ => unreachable!("int_in_range(1..=3) returned value out of range"),
130-
};
131-
132-
Ok(ExtDType::new(ext_id, storage_dtype.into(), None))
119+
let choice = u.int_in_range(0..=2)?;
120+
121+
match choice {
122+
0 => {
123+
// DATE: i32 (Days) or i64 (Milliseconds)
124+
let (ptype, time_unit) = match u.int_in_range(0..=1)? {
125+
0 => (PType::I32, TimeUnit::Days),
126+
1 => (PType::I64, TimeUnit::Milliseconds),
127+
_ => unreachable!(),
128+
};
129+
130+
Ok(ExtDType::new(
131+
DATE_ID.clone(),
132+
DType::Primitive(ptype, u.arbitrary()?).into(),
133+
Some(TemporalMetadata::Date(time_unit).into()),
134+
))
135+
}
136+
1 => {
137+
// TIME: i32 for Seconds/Milliseconds, i64 for Microseconds/Nanoseconds
138+
let (ptype, time_unit) = match u.int_in_range(0..=3)? {
139+
0 => (PType::I32, TimeUnit::Seconds),
140+
1 => (PType::I32, TimeUnit::Milliseconds),
141+
2 => (PType::I64, TimeUnit::Microseconds),
142+
3 => (PType::I64, TimeUnit::Nanoseconds),
143+
_ => unreachable!(),
144+
};
145+
146+
Ok(ExtDType::new(
147+
TIME_ID.clone(),
148+
DType::Primitive(ptype, u.arbitrary()?).into(),
149+
Some(TemporalMetadata::Time(time_unit).into()),
150+
))
151+
}
152+
2 => {
153+
// TIMESTAMP: always i64 with time unit and optional timezone
154+
let time_unit = match u.int_in_range(0..=3)? {
155+
0 => TimeUnit::Seconds,
156+
1 => TimeUnit::Milliseconds,
157+
2 => TimeUnit::Microseconds,
158+
3 => TimeUnit::Nanoseconds,
159+
_ => unreachable!(),
160+
};
161+
162+
let time_zone = u
163+
.arbitrary::<bool>()?
164+
.then(|| {
165+
u.choose(&["UTC", "America/New_York", "Europe/London", "Asia/Tokyo"])
166+
.map(|s| s.to_string())
167+
})
168+
.transpose()?;
169+
170+
Ok(ExtDType::new(
171+
TIMESTAMP_ID.clone(),
172+
DType::Primitive(PType::I64, u.arbitrary()?).into(),
173+
Some(TemporalMetadata::Timestamp(time_unit, time_zone).into()),
174+
))
175+
}
176+
_ => unreachable!(),
177+
}
133178
}
134179

135180
impl<'a> Arbitrary<'a> for StructFields {

0 commit comments

Comments
 (0)