11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use num_traits:: Float ;
45use vortex:: array:: ArrayRef ;
56use vortex:: array:: ExecutionCtx ;
7+ use vortex:: array:: IntoArray ;
8+ use vortex:: array:: ToCanonical ;
9+ use vortex:: array:: arrays:: ExtensionArray ;
10+ use vortex:: array:: arrays:: FixedSizeListArray ;
11+ use vortex:: array:: arrays:: PrimitiveArray ;
12+ use vortex:: array:: arrays:: ScalarFnArray ;
13+ use vortex:: array:: match_each_float_ptype;
14+ use vortex:: array:: validity:: Validity ;
615use vortex:: dtype:: DType ;
716use vortex:: dtype:: Nullability ;
17+ use vortex:: dtype:: extension:: ExtDType ;
818use vortex:: error:: VortexResult ;
919use vortex:: error:: vortex_ensure;
1020use vortex:: error:: vortex_ensure_eq;
1121use vortex:: error:: vortex_err;
22+ use vortex:: extension:: EmptyMetadata ;
23+ use vortex:: scalar_fn:: EmptyOptions ;
24+ use vortex:: scalar_fn:: ScalarFn ;
1225
26+ use crate :: scalar_fns:: l2_norm:: L2Norm ;
1327use crate :: utils:: extension_element_ptype;
28+ use crate :: utils:: extension_list_size;
29+ use crate :: utils:: extension_storage;
30+ use crate :: utils:: extract_flat_elements;
1431use crate :: vector:: Vector ;
1532
1633/// A normalized array that stores unit-normalized vectors alongside their original L2 norms.
@@ -70,6 +87,63 @@ impl NormVectorArray {
7087 } )
7188 }
7289
90+ /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
91+ /// dividing each vector by its norm.
92+ ///
93+ /// The input must be a [`Vector`] extension array with floating-point elements.
94+ pub fn compress ( vector_array : ArrayRef ) -> VortexResult < Self > {
95+ let ext = vector_array. dtype ( ) . as_extension_opt ( ) . ok_or_else ( || {
96+ vortex_err ! (
97+ "vector_array dtype must be an extension type, got {}" ,
98+ vector_array. dtype( )
99+ )
100+ } ) ?;
101+
102+ vortex_ensure ! (
103+ ext. is:: <Vector >( ) ,
104+ "vector_array must have the Vector extension type, got {}" ,
105+ vector_array. dtype( )
106+ ) ;
107+
108+ let list_size = extension_list_size ( ext) ?;
109+ let row_count = vector_array. len ( ) ;
110+
111+ // Compute L2 norms using the scalar function.
112+ let l2_norm_fn = ScalarFn :: new ( L2Norm , EmptyOptions ) . erased ( ) ;
113+ let norms = ScalarFnArray :: try_new ( l2_norm_fn, vec ! [ vector_array. clone( ) ] , row_count) ?
114+ . to_primitive ( )
115+ . into_array ( ) ;
116+
117+ // Divide each vector element by its corresponding norm.
118+ let storage = extension_storage ( & vector_array) ?;
119+ let flat = extract_flat_elements ( & storage, list_size) ?;
120+ let norms_prim = norms. to_canonical ( ) ?. into_primitive ( ) ;
121+
122+ match_each_float_ptype ! ( flat. ptype( ) , |T | {
123+ let norms_slice = norms_prim. as_slice:: <T >( ) ;
124+
125+ let normalized_elems: PrimitiveArray = ( 0 ..row_count)
126+ . flat_map( |i| {
127+ let inv_norm = safe_inv_norm( norms_slice[ i] ) ;
128+ flat. row:: <T >( i) . iter( ) . map( move |& v| v * inv_norm)
129+ } )
130+ . collect( ) ;
131+
132+ let fsl = FixedSizeListArray :: new(
133+ normalized_elems. into_array( ) ,
134+ u32 :: try_from( list_size) ?,
135+ Validity :: NonNullable ,
136+ row_count,
137+ ) ;
138+
139+ let ext_dtype =
140+ ExtDType :: <Vector >:: try_new( EmptyMetadata , fsl. dtype( ) . clone( ) ) ?. erased( ) ;
141+ let normalized_vector = ExtensionArray :: new( ext_dtype, fsl. into_array( ) ) . into_array( ) ;
142+
143+ Self :: try_new( normalized_vector, norms)
144+ } )
145+ }
146+
73147 /// Returns a reference to the backing vector array that has been unit normalized.
74148 pub fn vector_array ( & self ) -> & ArrayRef {
75149 & self . vector_array
@@ -80,8 +154,58 @@ impl NormVectorArray {
80154 & self . norms
81155 }
82156
83- // TODO docs
84- pub ( super ) fn execute_into_vector ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < ArrayRef > {
85- todo ! ( )
157+ /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm.
158+ pub fn decompress ( & self , _ctx : & mut ExecutionCtx ) -> VortexResult < ArrayRef > {
159+ let ext_dtype = self
160+ . vector_array
161+ . dtype ( )
162+ . as_extension_opt ( )
163+ . ok_or_else ( || {
164+ vortex_err ! (
165+ "expected Vector extension dtype, got {}" ,
166+ self . vector_array. dtype( )
167+ )
168+ } ) ?;
169+
170+ let list_size = extension_list_size ( ext_dtype) ?;
171+ let row_count = self . vector_array . len ( ) ;
172+
173+ let storage = extension_storage ( & self . vector_array ) ?;
174+ let flat = extract_flat_elements ( & storage, list_size) ?;
175+
176+ let norms_prim = self . norms . to_canonical ( ) ?. into_primitive ( ) ;
177+
178+ match_each_float_ptype ! ( flat. ptype( ) , |T | {
179+ let norms_slice = norms_prim. as_slice:: <T >( ) ;
180+
181+ let result_elems: PrimitiveArray = ( 0 ..row_count)
182+ . flat_map( |i| {
183+ let norm = norms_slice[ i] ;
184+ flat. row:: <T >( i) . iter( ) . map( move |& v| v * norm)
185+ } )
186+ . collect( ) ;
187+
188+ let fsl = FixedSizeListArray :: new(
189+ result_elems. into_array( ) ,
190+ u32 :: try_from( list_size) ?,
191+ Validity :: NonNullable ,
192+ row_count,
193+ ) ;
194+
195+ let ext_dtype =
196+ ExtDType :: <Vector >:: try_new( EmptyMetadata , fsl. dtype( ) . clone( ) ) ?. erased( ) ;
197+ Ok ( ExtensionArray :: new( ext_dtype, fsl. into_array( ) ) . into_array( ) )
198+ } )
199+ }
200+ }
201+
202+ /// Returns `1 / norm` if the norm is non-zero, or zero otherwise.
203+ ///
204+ /// This avoids division by zero for zero-length or all-zero vectors.
205+ fn safe_inv_norm < T : Float > ( norm : T ) -> T {
206+ if norm == T :: zero ( ) {
207+ T :: zero ( )
208+ } else {
209+ T :: one ( ) / norm
86210 }
87211}
0 commit comments