Skip to content

Commit 454f85a

Browse files
committed
add primitive casting
Signed-off-by: Connor Tsui <[email protected]>
1 parent d90bbda commit 454f85a

File tree

3 files changed

+236
-16
lines changed

3 files changed

+236
-16
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-compute/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ tracing = { workspace = true }
4040

4141
[dev-dependencies]
4242
divan = { workspace = true }
43+
rstest = { workspace = true }
4344

4445
[[bench]]
4546
name = "filter_buffer_mut"

vortex-compute/src/cast/pvector.rs

Lines changed: 234 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use num_traits::NumCast;
5+
use vortex_buffer::Buffer;
6+
use vortex_buffer::BufferMut;
47
use vortex_dtype::DType;
58
use vortex_dtype::NativePType;
9+
use vortex_dtype::match_each_native_ptype;
610
use vortex_error::VortexResult;
711
use vortex_error::vortex_bail;
12+
use vortex_error::vortex_err;
13+
use vortex_mask::AllOr;
14+
use vortex_mask::Mask;
815
use vortex_vector::Scalar;
916
use vortex_vector::ScalarOps;
1017
use vortex_vector::Vector;
1118
use vortex_vector::VectorOps;
1219
use vortex_vector::primitive::PScalar;
1320
use vortex_vector::primitive::PVector;
21+
use vortex_vector::primitive::PrimitiveScalar;
22+
use vortex_vector::primitive::PrimitiveVector;
1423

1524
use crate::cast::Cast;
1625
use crate::cast::try_cast_scalar_common;
@@ -19,26 +28,25 @@ use crate::cast::try_cast_vector_common;
1928
impl<T: NativePType> Cast for PVector<T> {
2029
type Output = Vector;
2130

22-
/// Casts to Primitive (same PType identity).
31+
/// Cast a primitive vector to a different primitive type.
2332
fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
2433
if let Some(result) = try_cast_vector_common(self, target_dtype)? {
2534
return Ok(result);
2635
}
2736

2837
match target_dtype {
29-
// We're already the correct PType, and we have compatible nullability.
38+
// We have the same `PType` and we have compatible nullability.
3039
DType::Primitive(target_ptype, n)
3140
if *target_ptype == T::PTYPE && (n.is_nullable() || self.validity().all_true()) =>
3241
{
3342
Ok(self.clone().into())
3443
}
35-
// We're not the correct PType, but we do have compatible nullability.
44+
// We can possibly convert to the target `PType` and we have compatible nullability.
3645
DType::Primitive(target_ptype, n) if n.is_nullable() || self.validity().all_true() => {
37-
vortex_bail!(
38-
"Casting PVector from PType {} to PType {} not yet implemented",
39-
T::PTYPE,
40-
target_ptype
41-
);
46+
match_each_native_ptype!(*target_ptype, |Dst| {
47+
let result = cast_pvector::<T, Dst>(self)?;
48+
Ok(PrimitiveVector::from(result).into())
49+
})
4250
}
4351
_ => {
4452
vortex_bail!("Cannot cast PVector<{}> to {}", T::PTYPE, target_dtype);
@@ -47,33 +55,243 @@ impl<T: NativePType> Cast for PVector<T> {
4755
}
4856
}
4957

58+
/// Cast a [`PVector<F>`] to a [`PVector<T>`] by converting each element.
59+
///
60+
/// Returns an error if any valid element cannot be converted (e.g., overflow).
61+
fn cast_pvector<Src: NativePType, Dst: NativePType>(
62+
src: &PVector<Src>,
63+
) -> VortexResult<PVector<Dst>> {
64+
let elements: &[Src] = src.as_ref();
65+
match src.validity().bit_buffer() {
66+
AllOr::All => {
67+
let mut buffer = BufferMut::with_capacity(elements.len());
68+
for &item in elements {
69+
let converted = <Dst as NumCast>::from(item).ok_or_else(
70+
|| vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, Dst::PTYPE),
71+
)?;
72+
// SAFETY: We pre-allocated the required capacity.
73+
unsafe { buffer.push_unchecked(converted) }
74+
}
75+
Ok(PVector::from(buffer.freeze()))
76+
}
77+
AllOr::None => Ok(PVector::new(
78+
Buffer::zeroed(elements.len()),
79+
Mask::new_false(elements.len()),
80+
)),
81+
AllOr::Some(bit_buffer) => {
82+
let mut buffer = BufferMut::with_capacity(elements.len());
83+
for (&item, valid) in elements.iter().zip(bit_buffer.iter()) {
84+
if valid {
85+
let converted = <Dst as NumCast>::from(item).ok_or_else(
86+
|| vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, Dst::PTYPE),
87+
)?;
88+
// SAFETY: We pre-allocated the required capacity.
89+
unsafe { buffer.push_unchecked(converted) }
90+
} else {
91+
// SAFETY: We pre-allocated the required capacity.
92+
unsafe { buffer.push_unchecked(Dst::default()) }
93+
}
94+
}
95+
Ok(PVector::new(buffer.freeze(), src.validity().clone()))
96+
}
97+
}
98+
}
99+
50100
impl<T: NativePType> Cast for PScalar<T> {
51101
type Output = Scalar;
52102

53-
/// Casts to Primitive (same PType identity).
103+
/// Cast a primitive scalar to a different primitive type.
54104
fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
55105
if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
56106
return Ok(result);
57107
}
58108

59109
match target_dtype {
60-
// We're already the correct PType, and we have compatible nullability.
110+
// We have the same `PType` and we have compatible nullability.
61111
DType::Primitive(target_ptype, n)
62112
if *target_ptype == T::PTYPE && (n.is_nullable() || self.is_valid()) =>
63113
{
64114
Ok(self.clone().into())
65115
}
66-
// We're not the correct PType, but we do have compatible nullability.
116+
// We can possibly convert to the target `PType` and we have compatible nullability.
67117
DType::Primitive(target_ptype, n) if n.is_nullable() || self.is_valid() => {
68-
vortex_bail!(
69-
"Casting PScalar from PType {} to PType {} not yet implemented",
70-
T::PTYPE,
71-
target_ptype
72-
);
118+
match_each_native_ptype!(*target_ptype, |Dst| {
119+
let result = match self.value() {
120+
None => PScalar::null(),
121+
Some(v) => {
122+
let converted = <Dst as NumCast>::from(v).ok_or_else(|| {
123+
vortex_err!(ComputeError: "Failed to cast {} to {:?}", v, Dst::PTYPE)
124+
})?;
125+
PScalar::new(Some(converted))
126+
}
127+
};
128+
Ok(PrimitiveScalar::from(result).into())
129+
})
73130
}
74131
_ => {
75132
vortex_bail!("Cannot cast PScalar<{}> to {}", T::PTYPE, target_dtype);
76133
}
77134
}
78135
}
79136
}
137+
138+
#[cfg(test)]
139+
mod tests {
140+
use rstest::rstest;
141+
use vortex_buffer::BitBuffer;
142+
use vortex_buffer::buffer;
143+
use vortex_dtype::DType;
144+
use vortex_dtype::Nullability;
145+
use vortex_dtype::PType;
146+
use vortex_dtype::PTypeDowncast;
147+
use vortex_error::VortexError;
148+
use vortex_mask::Mask;
149+
use vortex_vector::ScalarOps;
150+
use vortex_vector::VectorOps;
151+
use vortex_vector::primitive::PScalar;
152+
use vortex_vector::primitive::PVector;
153+
154+
use crate::cast::Cast;
155+
156+
#[rstest]
157+
#[case(PType::U8)]
158+
#[case(PType::U16)]
159+
#[case(PType::U32)]
160+
#[case(PType::U64)]
161+
#[case(PType::I8)]
162+
#[case(PType::I16)]
163+
#[case(PType::I32)]
164+
#[case(PType::I64)]
165+
#[case(PType::F32)]
166+
#[case(PType::F64)]
167+
fn cast_u32_to_ptype(#[case] target: PType) {
168+
// Use values that fit in all target types (including i8: -128..127).
169+
let vec: PVector<u32> = buffer![0u32, 10, 100].into();
170+
let result = vec.cast(&target.into()).unwrap();
171+
assert!(result.as_primitive().validity().all_true());
172+
assert_eq!(result.len(), 3);
173+
}
174+
175+
#[test]
176+
fn cast_various_types_to_f64() {
177+
// Test casting from various primitive types to f64.
178+
let u8_vec: PVector<u8> = buffer![0u8, 1, 2, 3, 255].into();
179+
assert!(u8_vec.cast(&PType::F64.into()).is_ok());
180+
181+
let u16_vec: PVector<u16> = buffer![0u16, 100, 1000].into();
182+
assert!(u16_vec.cast(&PType::F64.into()).is_ok());
183+
184+
let u32_vec: PVector<u32> = buffer![0u32, 100, 1000, 1000000].into();
185+
assert!(u32_vec.cast(&PType::F64.into()).is_ok());
186+
187+
let i8_vec: PVector<i8> = buffer![0i8, -1, 1, 127].into();
188+
assert!(i8_vec.cast(&PType::F64.into()).is_ok());
189+
190+
let i32_vec: PVector<i32> = buffer![-1000000i32, -1, 0, 1, 1000000].into();
191+
assert!(i32_vec.cast(&PType::F64.into()).is_ok());
192+
193+
let f32_vec: PVector<f32> = buffer![0.0f32, 1.5, -2.5, 100.0].into();
194+
assert!(f32_vec.cast(&PType::F64.into()).is_ok());
195+
}
196+
197+
#[test]
198+
fn cast_u32_u8() {
199+
let vec: PVector<u32> = buffer![0u32, 10, 200].into();
200+
201+
// Cast from u32 to u8.
202+
let result = vec.cast(&PType::U8.into()).unwrap();
203+
let p = result.into_primitive().into_u8();
204+
assert_eq!(p.as_ref(), &[0u8, 10, 200]);
205+
assert!(p.validity().all_true());
206+
}
207+
208+
#[test]
209+
fn cast_u32_f32() {
210+
let vec: PVector<u32> = buffer![0u32, 10, 200].into();
211+
let result = vec.cast(&PType::F32.into()).unwrap();
212+
let p = result.into_primitive().into_f32();
213+
assert_eq!(p.as_ref(), &[0.0f32, 10., 200.]);
214+
}
215+
216+
#[test]
217+
fn cast_i32_u32_overflow() {
218+
let vec: PVector<i32> = buffer![-1i32].into();
219+
let error = vec.cast(&PType::U32.into()).err().unwrap();
220+
let VortexError::ComputeError(s, _) = error else {
221+
unreachable!()
222+
};
223+
assert_eq!(s.to_string(), "Failed to cast -1 to U32");
224+
}
225+
226+
#[test]
227+
fn cast_with_invalid_nulls() {
228+
// Create a vector with an invalid value at position 0 (which would overflow).
229+
let vec: PVector<i32> = PVector::new(
230+
buffer![-1i32, 0, 10],
231+
Mask::from(BitBuffer::from(vec![false, true, true])),
232+
);
233+
234+
// Cast to nullable u32 should succeed because the invalid value is masked.
235+
let result = vec
236+
.cast(&DType::Primitive(PType::U32, Nullability::Nullable))
237+
.unwrap();
238+
let p = result.into_primitive().into_u32();
239+
assert_eq!(p.as_ref(), &[0u32, 0, 10]);
240+
assert_eq!(
241+
*p.validity(),
242+
Mask::from(BitBuffer::from(vec![false, true, true]))
243+
);
244+
}
245+
246+
#[test]
247+
fn cast_all_null_vector() {
248+
let vec: PVector<i32> = PVector::new(buffer![-1i32, -2, -3], Mask::new_false(3));
249+
250+
// Cast to nullable u32 should succeed because all values are masked.
251+
let result = vec
252+
.cast(&DType::Primitive(PType::U32, Nullability::Nullable))
253+
.unwrap();
254+
let p = result.into_primitive().into_u32();
255+
assert_eq!(p.as_ref(), &[0u32, 0, 0]);
256+
assert!(p.validity().all_false());
257+
}
258+
259+
#[rstest]
260+
#[case(42i32, PType::U32)]
261+
#[case(0i32, PType::U8)]
262+
#[case(255i32, PType::U8)]
263+
#[case(100i32, PType::F64)]
264+
fn cast_scalar_valid(#[case] value: i32, #[case] target: PType) {
265+
let scalar: PScalar<i32> = PScalar::new(Some(value));
266+
let result = scalar.cast(&target.into()).unwrap();
267+
assert!(result.as_primitive().is_valid());
268+
}
269+
270+
#[test]
271+
fn cast_scalar_i32_u32_overflow() {
272+
let scalar: PScalar<i32> = PScalar::new(Some(-1));
273+
let error = scalar.cast(&PType::U32.into()).err().unwrap();
274+
let VortexError::ComputeError(s, _) = error else {
275+
unreachable!()
276+
};
277+
assert_eq!(s.to_string(), "Failed to cast -1 to U32");
278+
}
279+
280+
#[test]
281+
fn cast_scalar_null() {
282+
let scalar: PScalar<i32> = PScalar::null();
283+
let result = scalar
284+
.cast(&DType::Primitive(PType::U32, Nullability::Nullable))
285+
.unwrap();
286+
let p = result.into_primitive().into_u32();
287+
assert_eq!(p.value(), None);
288+
}
289+
290+
#[test]
291+
fn cast_scalar_u32_f64() {
292+
let scalar: PScalar<u32> = PScalar::new(Some(12345));
293+
let result = scalar.cast(&PType::F64.into()).unwrap();
294+
let p = result.into_primitive().into_f64();
295+
assert_eq!(p.value(), Some(12345.0f64));
296+
}
297+
}

0 commit comments

Comments
 (0)