@@ -5,20 +5,22 @@ use num_traits::Float;
55use vortex:: array:: ArrayRef ;
66use vortex:: array:: ExecutionCtx ;
77use vortex:: array:: IntoArray ;
8- use vortex:: array:: ToCanonical ;
98use vortex:: array:: arrays:: ExtensionArray ;
109use vortex:: array:: arrays:: FixedSizeListArray ;
1110use vortex:: array:: arrays:: PrimitiveArray ;
12- use vortex:: array:: arrays:: ScalarFnArray ;
1311use vortex:: array:: match_each_float_ptype;
12+ use vortex:: array:: stats:: ArrayStats ;
1413use vortex:: array:: validity:: Validity ;
1514use vortex:: dtype:: DType ;
1615use vortex:: dtype:: Nullability ;
1716use vortex:: dtype:: extension:: ExtDType ;
17+ use vortex:: dtype:: extension:: ExtDTypeRef ;
1818use vortex:: error:: VortexResult ;
1919use vortex:: error:: vortex_ensure;
2020use vortex:: error:: vortex_ensure_eq;
2121use vortex:: error:: vortex_err;
22+ use vortex:: expr:: Expression ;
23+ use vortex:: expr:: root;
2224use vortex:: extension:: EmptyMetadata ;
2325use vortex:: scalar_fn:: EmptyOptions ;
2426use vortex:: scalar_fn:: ScalarFn ;
@@ -34,41 +36,39 @@ use crate::vector::Vector;
3436///
3537/// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The
3638/// original norms are stored separately so that the original vectors can be reconstructed.
39+ ///
40+ /// The `vector_array` child carries its own validity and nullability, so a nullable input vector
41+ /// array produces a nullable `NormVectorArray`.
3742#[ derive( Debug , Clone ) ]
3843pub struct NormVectorArray {
3944 /// The backing vector array that has been unit normalized.
4045 ///
41- /// The underlying elements of the vector array must be floating-point.
46+ /// The underlying elements of the vector array must be floating-point. This child may be
47+ /// nullable; its validity determines the validity of the `NormVectorArray`.
4248 pub ( crate ) vector_array : ArrayRef ,
4349
44- /// The L2 (Frobenius) norms of each vector.
50+ /// The L2 norms of each vector.
4551 ///
4652 /// This must have the same dtype as the elements of the vector array.
4753 pub ( crate ) norms : ArrayRef ,
54+
55+ /// Stats set owned by this array.
56+ pub ( crate ) stats_set : ArrayStats ,
4857}
4958
5059impl NormVectorArray {
5160 /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and its L2 norms.
5261 ///
5362 /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and
54- /// `norms` must be a primitive array of the same float type with the same length.
63+ /// `norms` must be a primitive array of the same float type with the same length. The
64+ /// `vector_array` may be nullable.
5565 pub fn try_new ( vector_array : ArrayRef , norms : ArrayRef ) -> VortexResult < Self > {
56- let ext = vector_array. dtype ( ) . as_extension_opt ( ) . ok_or_else ( || {
57- vortex_err ! (
58- "vector_array dtype must be an extension type, got {}" ,
59- vector_array. dtype( )
60- )
61- } ) ?;
62-
63- vortex_ensure ! (
64- ext. is:: <Vector >( ) ,
65- "vector_array must have the Vector extension type, got {}" ,
66- vector_array. dtype( )
67- ) ;
66+ let ext = Self :: validate ( & vector_array) ?;
6867
69- let element_ptype = extension_element_ptype ( ext) ?;
68+ let element_ptype = extension_element_ptype ( & ext) ?;
7069
71- let expected_norms_dtype = DType :: Primitive ( element_ptype, Nullability :: NonNullable ) ;
70+ let nullability = Nullability :: from ( vector_array. dtype ( ) . is_nullable ( ) ) ;
71+ let expected_norms_dtype = DType :: Primitive ( element_ptype, nullability) ;
7272 vortex_ensure_eq ! (
7373 * norms. dtype( ) ,
7474 expected_norms_dtype,
@@ -84,14 +84,13 @@ impl NormVectorArray {
8484 Ok ( Self {
8585 vector_array,
8686 norms,
87+ stats_set : ArrayStats :: default ( ) ,
8788 } )
8889 }
8990
90- /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
91- /// dividing each vector by its norm.
92- ///
93- /// The input must be a [`Vector`] extension array with floating-point elements.
94- pub fn compress ( vector_array : ArrayRef ) -> VortexResult < Self > {
91+ /// Validates that the given array has the [`Vector`] extension type and returns the extension
92+ /// dtype.
93+ fn validate ( vector_array : & ArrayRef ) -> VortexResult < ExtDTypeRef > {
9594 let ext = vector_array. dtype ( ) . as_extension_opt ( ) . ok_or_else ( || {
9695 vortex_err ! (
9796 "vector_array dtype must be an extension type, got {}" ,
@@ -105,19 +104,32 @@ impl NormVectorArray {
105104 vector_array. dtype( )
106105 ) ;
107106
108- let list_size = extension_list_size ( ext) ? ;
109- let row_count = vector_array . len ( ) ;
107+ Ok ( ext. clone ( ) )
108+ }
110109
111- // Compute L2 norms using the scalar function.
112- let l2_norm_fn = ScalarFn :: new ( L2Norm , EmptyOptions ) . erased ( ) ;
113- let norms = ScalarFnArray :: try_new ( l2_norm_fn, vec ! [ vector_array. clone( ) ] , row_count) ?
114- . to_primitive ( )
115- . into_array ( ) ;
110+ /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
111+ /// dividing each vector by its norm.
112+ ///
113+ /// The input must be a [`Vector`] extension array with floating-point elements. Nullable inputs
114+ /// are supported; the validity mask is preserved and the normalized data for null rows is
115+ /// unspecified.
116+ pub fn compress ( vector_array : ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < Self > {
117+ let ext = Self :: validate ( & vector_array) ?;
118+
119+ let list_size = extension_list_size ( & ext) ?;
120+ let row_count = vector_array. len ( ) ;
121+ let nullability = Nullability :: from ( vector_array. dtype ( ) . is_nullable ( ) ) ;
116122
117- // Divide each vector element by its corresponding norm.
123+ // Compute L2 norms using the scalar function. If the input is nullable, the norms will
124+ // also be nullable (null vectors produce null norms).
118125 let storage = extension_storage ( & vector_array) ?;
119- let flat = extract_flat_elements ( & storage, list_size) ?;
120- let norms_prim = norms. to_canonical ( ) ?. into_primitive ( ) ;
126+ let l2_norm_expr =
127+ Expression :: try_new ( ScalarFn :: new ( L2Norm , EmptyOptions ) . erased ( ) , [ root ( ) ] ) ?;
128+ let norms_prim: PrimitiveArray = vector_array. apply ( & l2_norm_expr) ?. execute ( ctx) ?;
129+ let norms_array = norms_prim. clone ( ) . into_array ( ) ;
130+
131+ // Extract flat elements from the (always non-nullable) storage for normalization.
132+ let flat = extract_flat_elements ( & storage, list_size, ctx) ?;
121133
122134 match_each_float_ptype ! ( flat. ptype( ) , |T | {
123135 let norms_slice = norms_prim. as_slice:: <T >( ) ;
@@ -129,18 +141,20 @@ impl NormVectorArray {
129141 } )
130142 . collect( ) ;
131143
144+ // Reconstruct the vector array with the same nullability as the input.
145+ let validity = Validity :: from( nullability) ;
132146 let fsl = FixedSizeListArray :: new(
133147 normalized_elems. into_array( ) ,
134148 u32 :: try_from( list_size) ?,
135- Validity :: NonNullable ,
149+ validity ,
136150 row_count,
137151 ) ;
138152
139153 let ext_dtype =
140154 ExtDType :: <Vector >:: try_new( EmptyMetadata , fsl. dtype( ) . clone( ) ) ?. erased( ) ;
141155 let normalized_vector = ExtensionArray :: new( ext_dtype, fsl. into_array( ) ) . into_array( ) ;
142156
143- Self :: try_new( normalized_vector, norms )
157+ Self :: try_new( normalized_vector, norms_array )
144158 } )
145159 }
146160
@@ -149,31 +163,26 @@ impl NormVectorArray {
149163 & self . vector_array
150164 }
151165
152- /// Returns a reference to the L2 (Frobenius) norms of each vector.
166+ /// Returns a reference to the L2 norms of each vector.
153167 pub fn norms ( & self ) -> & ArrayRef {
154168 & self . norms
155169 }
156170
157171 /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm.
158- pub fn decompress ( & self , _ctx : & mut ExecutionCtx ) -> VortexResult < ArrayRef > {
159- let ext_dtype = self
160- . vector_array
161- . dtype ( )
162- . as_extension_opt ( )
163- . ok_or_else ( || {
164- vortex_err ! (
165- "expected Vector extension dtype, got {}" ,
166- self . vector_array. dtype( )
167- )
168- } ) ?;
169-
170- let list_size = extension_list_size ( ext_dtype) ?;
172+ ///
173+ /// The returned array has the same dtype (including nullability) as the original
174+ /// `vector_array` child.
175+ pub fn decompress ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < ArrayRef > {
176+ let ext = Self :: validate ( & self . vector_array ) ?;
177+ let nullability = Nullability :: from ( self . vector_array . dtype ( ) . is_nullable ( ) ) ;
178+
179+ let list_size = extension_list_size ( & ext) ?;
171180 let row_count = self . vector_array . len ( ) ;
172181
173182 let storage = extension_storage ( & self . vector_array ) ?;
174- let flat = extract_flat_elements ( & storage, list_size) ?;
183+ let flat = extract_flat_elements ( & storage, list_size, ctx ) ?;
175184
176- let norms_prim = self . norms . to_canonical ( ) ? . into_primitive ( ) ;
185+ let norms_prim: PrimitiveArray = self . norms . clone ( ) . execute ( ctx ) ? ;
177186
178187 match_each_float_ptype ! ( flat. ptype( ) , |T | {
179188 let norms_slice = norms_prim. as_slice:: <T >( ) ;
@@ -185,10 +194,11 @@ impl NormVectorArray {
185194 } )
186195 . collect( ) ;
187196
197+ let validity = Validity :: from( nullability) ;
188198 let fsl = FixedSizeListArray :: new(
189199 result_elems. into_array( ) ,
190200 u32 :: try_from( list_size) ?,
191- Validity :: NonNullable ,
201+ validity ,
192202 row_count,
193203 ) ;
194204
0 commit comments