Skip to content

Commit 6bf8da6

Browse files
committed
fix problems
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 6912f0f commit 6bf8da6

File tree

9 files changed

+153
-123
lines changed

9 files changed

+153
-123
lines changed

vortex-tensor/public-api.lock

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ pub type vortex_tensor::encodings::norm::NormVector::Metadata = vortex_array::me
2222

2323
pub type vortex_tensor::encodings::norm::NormVector::OperationsVTable = vortex_tensor::encodings::norm::NormVector
2424

25-
pub type vortex_tensor::encodings::norm::NormVector::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromChild
25+
pub type vortex_tensor::encodings::norm::NormVector::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromValidityHelper
2626

2727
pub fn vortex_tensor::encodings::norm::NormVector::array_eq(array: &vortex_tensor::encodings::norm::NormVectorArray, other: &vortex_tensor::encodings::norm::NormVectorArray, precision: vortex_array::hash::Precision) -> bool
2828

@@ -52,7 +52,7 @@ pub fn vortex_tensor::encodings::norm::NormVector::metadata(_array: &vortex_tens
5252

5353
pub fn vortex_tensor::encodings::norm::NormVector::nbuffers(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize
5454

55-
pub fn vortex_tensor::encodings::norm::NormVector::nchildren(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize
55+
pub fn vortex_tensor::encodings::norm::NormVector::nchildren(array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize
5656

5757
pub fn vortex_tensor::encodings::norm::NormVector::serialize(_metadata: Self::Metadata) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
5858

@@ -66,21 +66,17 @@ impl vortex_array::vtable::operations::OperationsVTable<vortex_tensor::encodings
6666

6767
pub fn vortex_tensor::encodings::norm::NormVector::scalar_at(array: &vortex_tensor::encodings::norm::NormVectorArray, index: usize) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
6868

69-
impl vortex_array::vtable::validity::ValidityChild<vortex_tensor::encodings::norm::NormVector> for vortex_tensor::encodings::norm::NormVector
70-
71-
pub fn vortex_tensor::encodings::norm::NormVector::validity_child(array: &vortex_tensor::encodings::norm::NormVectorArray) -> &vortex_array::array::ArrayRef
72-
7369
pub struct vortex_tensor::encodings::norm::NormVectorArray
7470

7571
impl vortex_tensor::encodings::norm::NormVectorArray
7672

77-
pub fn vortex_tensor::encodings::norm::NormVectorArray::compress(vector_array: vortex_array::array::ArrayRef) -> vortex_error::VortexResult<Self>
73+
pub fn vortex_tensor::encodings::norm::NormVectorArray::compress(vector_array: vortex_array::array::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<Self>
7874

79-
pub fn vortex_tensor::encodings::norm::NormVectorArray::decompress(&self, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
75+
pub fn vortex_tensor::encodings::norm::NormVectorArray::decompress(&self, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
8076

8177
pub fn vortex_tensor::encodings::norm::NormVectorArray::norms(&self) -> &vortex_array::array::ArrayRef
8278

83-
pub fn vortex_tensor::encodings::norm::NormVectorArray::try_new(vector_array: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef) -> vortex_error::VortexResult<Self>
79+
pub fn vortex_tensor::encodings::norm::NormVectorArray::try_new(vector_array: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef, validity: vortex_array::validity::Validity) -> vortex_error::VortexResult<Self>
8480

8581
pub fn vortex_tensor::encodings::norm::NormVectorArray::vector_array(&self) -> &vortex_array::array::ArrayRef
8682

@@ -114,6 +110,10 @@ impl vortex_array::array::IntoArray for vortex_tensor::encodings::norm::NormVect
114110

115111
pub fn vortex_tensor::encodings::norm::NormVectorArray::into_array(self) -> vortex_array::array::ArrayRef
116112

113+
impl vortex_array::vtable::validity::ValidityHelper for vortex_tensor::encodings::norm::NormVectorArray
114+
115+
pub fn vortex_tensor::encodings::norm::NormVectorArray::validity(&self) -> &vortex_array::validity::Validity
116+
117117
pub mod vortex_tensor::fixed_shape
118118

119119
pub struct vortex_tensor::fixed_shape::FixedShapeTensor
@@ -250,7 +250,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&se
250250

251251
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
252252

253-
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
253+
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
254254

255255
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
256256

@@ -280,7 +280,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self:
280280

281281
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName
282282

283-
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
283+
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
284284

285285
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
286286

vortex-tensor/src/encodings/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,4 +2,4 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
pub mod norm;
5-
// mod spherical;
5+
// TODO: Spherical coordinate encoding.

vortex-tensor/src/encodings/norm/array.rs

Lines changed: 63 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -5,20 +5,22 @@ use num_traits::Float;
55
use vortex::array::ArrayRef;
66
use vortex::array::ExecutionCtx;
77
use vortex::array::IntoArray;
8-
use vortex::array::ToCanonical;
98
use vortex::array::arrays::ExtensionArray;
109
use vortex::array::arrays::FixedSizeListArray;
1110
use vortex::array::arrays::PrimitiveArray;
12-
use vortex::array::arrays::ScalarFnArray;
1311
use vortex::array::match_each_float_ptype;
12+
use vortex::array::stats::ArrayStats;
1413
use vortex::array::validity::Validity;
1514
use vortex::dtype::DType;
1615
use vortex::dtype::Nullability;
1716
use vortex::dtype::extension::ExtDType;
17+
use vortex::dtype::extension::ExtDTypeRef;
1818
use vortex::error::VortexResult;
1919
use vortex::error::vortex_ensure;
2020
use vortex::error::vortex_ensure_eq;
2121
use vortex::error::vortex_err;
22+
use vortex::expr::Expression;
23+
use vortex::expr::root;
2224
use vortex::extension::EmptyMetadata;
2325
use vortex::scalar_fn::EmptyOptions;
2426
use vortex::scalar_fn::ScalarFn;
@@ -34,41 +36,39 @@ use crate::vector::Vector;
3436
///
3537
/// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The
3638
/// original norms are stored separately so that the original vectors can be reconstructed.
39+
///
40+
/// The `vector_array` child carries its own validity and nullability, so a nullable input vector
41+
/// array produces a nullable `NormVectorArray`.
3742
#[derive(Debug, Clone)]
3843
pub struct NormVectorArray {
3944
/// The backing vector array that has been unit normalized.
4045
///
41-
/// The underlying elements of the vector array must be floating-point.
46+
/// The underlying elements of the vector array must be floating-point. This child may be
47+
/// nullable; its validity determines the validity of the `NormVectorArray`.
4248
pub(crate) vector_array: ArrayRef,
4349

44-
/// The L2 (Frobenius) norms of each vector.
50+
/// The L2 norms of each vector.
4551
///
4652
/// This must have the same dtype as the elements of the vector array.
4753
pub(crate) norms: ArrayRef,
54+
55+
/// Stats set owned by this array.
56+
pub(crate) stats_set: ArrayStats,
4857
}
4958

5059
impl NormVectorArray {
5160
/// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms.
5261
///
5362
/// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and
54-
/// `norms` must be a primitive array of the same float type with the same length.
63+
/// `norms` must be a primitive array of the same float type with the same length. The
64+
/// `vector_array` may be nullable.
5565
pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult<Self> {
56-
let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| {
57-
vortex_err!(
58-
"vector_array dtype must be an extension type, got {}",
59-
vector_array.dtype()
60-
)
61-
})?;
62-
63-
vortex_ensure!(
64-
ext.is::<Vector>(),
65-
"vector_array must have the Vector extension type, got {}",
66-
vector_array.dtype()
67-
);
66+
let ext = Self::validate(&vector_array)?;
6867

69-
let element_ptype = extension_element_ptype(ext)?;
68+
let element_ptype = extension_element_ptype(&ext)?;
7069

71-
let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable);
70+
let nullability = Nullability::from(vector_array.dtype().is_nullable());
71+
let expected_norms_dtype = DType::Primitive(element_ptype, nullability);
7272
vortex_ensure_eq!(
7373
*norms.dtype(),
7474
expected_norms_dtype,
@@ -84,14 +84,13 @@ impl NormVectorArray {
8484
Ok(Self {
8585
vector_array,
8686
norms,
87+
stats_set: ArrayStats::default(),
8788
})
8889
}
8990

90-
/// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
91-
/// dividing each vector by its norm.
92-
///
93-
/// The input must be a [`Vector`] extension array with floating-point elements.
94-
pub fn compress(vector_array: ArrayRef) -> VortexResult<Self> {
91+
/// Validates that the given array has the [`Vector`] extension type and returns the extension
92+
/// dtype.
93+
fn validate(vector_array: &ArrayRef) -> VortexResult<ExtDTypeRef> {
9594
let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| {
9695
vortex_err!(
9796
"vector_array dtype must be an extension type, got {}",
@@ -105,19 +104,32 @@ impl NormVectorArray {
105104
vector_array.dtype()
106105
);
107106

108-
let list_size = extension_list_size(ext)?;
109-
let row_count = vector_array.len();
107+
Ok(ext.clone())
108+
}
110109

111-
// Compute L2 norms using the scalar function.
112-
let l2_norm_fn = ScalarFn::new(L2Norm, EmptyOptions).erased();
113-
let norms = ScalarFnArray::try_new(l2_norm_fn, vec![vector_array.clone()], row_count)?
114-
.to_primitive()
115-
.into_array();
110+
/// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
111+
/// dividing each vector by its norm.
112+
///
113+
/// The input must be a [`Vector`] extension array with floating-point elements. Nullable inputs
114+
/// are supported; the validity mask is preserved and the normalized data for null rows is
115+
/// unspecified.
116+
pub fn compress(vector_array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Self> {
117+
let ext = Self::validate(&vector_array)?;
118+
119+
let list_size = extension_list_size(&ext)?;
120+
let row_count = vector_array.len();
121+
let nullability = Nullability::from(vector_array.dtype().is_nullable());
116122

117-
// Divide each vector element by its corresponding norm.
123+
// Compute L2 norms using the scalar function. If the input is nullable, the norms will
124+
// also be nullable (null vectors produce null norms).
118125
let storage = extension_storage(&vector_array)?;
119-
let flat = extract_flat_elements(&storage, list_size)?;
120-
let norms_prim = norms.to_canonical()?.into_primitive();
126+
let l2_norm_expr =
127+
Expression::try_new(ScalarFn::new(L2Norm, EmptyOptions).erased(), [root()])?;
128+
let norms_prim: PrimitiveArray = vector_array.apply(&l2_norm_expr)?.execute(ctx)?;
129+
let norms_array = norms_prim.clone().into_array();
130+
131+
// Extract flat elements from the (always non-nullable) storage for normalization.
132+
let flat = extract_flat_elements(&storage, list_size, ctx)?;
121133

122134
match_each_float_ptype!(flat.ptype(), |T| {
123135
let norms_slice = norms_prim.as_slice::<T>();
@@ -129,18 +141,20 @@ impl NormVectorArray {
129141
})
130142
.collect();
131143

144+
// Reconstruct the vector array with the same nullability as the input.
145+
let validity = Validity::from(nullability);
132146
let fsl = FixedSizeListArray::new(
133147
normalized_elems.into_array(),
134148
u32::try_from(list_size)?,
135-
Validity::NonNullable,
149+
validity,
136150
row_count,
137151
);
138152

139153
let ext_dtype =
140154
ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
141155
let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array();
142156

143-
Self::try_new(normalized_vector, norms)
157+
Self::try_new(normalized_vector, norms_array)
144158
})
145159
}
146160

@@ -149,31 +163,26 @@ impl NormVectorArray {
149163
&self.vector_array
150164
}
151165

152-
/// Returns a reference to the L2 (Frobenius) norms of each vector.
166+
/// Returns a reference to the L2 norms of each vector.
153167
pub fn norms(&self) -> &ArrayRef {
154168
&self.norms
155169
}
156170

157171
/// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm.
158-
pub fn decompress(&self, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
159-
let ext_dtype = self
160-
.vector_array
161-
.dtype()
162-
.as_extension_opt()
163-
.ok_or_else(|| {
164-
vortex_err!(
165-
"expected Vector extension dtype, got {}",
166-
self.vector_array.dtype()
167-
)
168-
})?;
169-
170-
let list_size = extension_list_size(ext_dtype)?;
172+
///
173+
/// The returned array has the same dtype (including nullability) as the original
174+
/// `vector_array` child.
175+
pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
176+
let ext = Self::validate(&self.vector_array)?;
177+
let nullability = Nullability::from(self.vector_array.dtype().is_nullable());
178+
179+
let list_size = extension_list_size(&ext)?;
171180
let row_count = self.vector_array.len();
172181

173182
let storage = extension_storage(&self.vector_array)?;
174-
let flat = extract_flat_elements(&storage, list_size)?;
183+
let flat = extract_flat_elements(&storage, list_size, ctx)?;
175184

176-
let norms_prim = self.norms.to_canonical()?.into_primitive();
185+
let norms_prim: PrimitiveArray = self.norms.clone().execute(ctx)?;
177186

178187
match_each_float_ptype!(flat.ptype(), |T| {
179188
let norms_slice = norms_prim.as_slice::<T>();
@@ -185,10 +194,11 @@ impl NormVectorArray {
185194
})
186195
.collect();
187196

197+
let validity = Validity::from(nullability);
188198
let fsl = FixedSizeListArray::new(
189199
result_elems.into_array(),
190200
u32::try_from(list_size)?,
191-
Validity::NonNullable,
201+
validity,
192202
row_count,
193203
);
194204

vortex-tensor/src/encodings/norm/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
mod array;
55
pub use array::NormVectorArray;
66

7-
// pub(crate) mod compute;
7+
// TODO: Compute operations for NormVector.
88

99
mod vtable;
1010
pub use vtable::NormVector;

0 commit comments

Comments
 (0)