Skip to content

Commit cf65eae

Browse files
committed
implement compress and decompress
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 39e0a0e commit cf65eae

File tree

4 files changed

+240
-6
lines changed

4 files changed

+240
-6
lines changed

vortex-tensor/src/encodings/norm/array.rs

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use num_traits::Float;
45
use vortex::array::ArrayRef;
56
use vortex::array::ExecutionCtx;
7+
use vortex::array::IntoArray;
8+
use vortex::array::ToCanonical;
9+
use vortex::array::arrays::ExtensionArray;
10+
use vortex::array::arrays::FixedSizeListArray;
11+
use vortex::array::arrays::PrimitiveArray;
12+
use vortex::array::arrays::ScalarFnArray;
13+
use vortex::array::match_each_float_ptype;
14+
use vortex::array::validity::Validity;
615
use vortex::dtype::DType;
716
use vortex::dtype::Nullability;
17+
use vortex::dtype::extension::ExtDType;
818
use vortex::error::VortexResult;
919
use vortex::error::vortex_ensure;
1020
use vortex::error::vortex_ensure_eq;
1121
use vortex::error::vortex_err;
22+
use vortex::extension::EmptyMetadata;
23+
use vortex::scalar_fn::EmptyOptions;
24+
use vortex::scalar_fn::ScalarFn;
1225

26+
use crate::scalar_fns::l2_norm::L2Norm;
1327
use crate::utils::extension_element_ptype;
28+
use crate::utils::extension_list_size;
29+
use crate::utils::extension_storage;
30+
use crate::utils::extract_flat_elements;
1431
use crate::vector::Vector;
1532

1633
/// A normalized array that stores unit-normalized vectors alongside their original L2 norms.
@@ -70,6 +87,63 @@ impl NormVectorArray {
7087
})
7188
}
7289

90+
/// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and
91+
/// dividing each vector by its norm.
92+
///
93+
/// The input must be a [`Vector`] extension array with floating-point elements.
94+
pub fn compress(vector_array: ArrayRef) -> VortexResult<Self> {
95+
let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| {
96+
vortex_err!(
97+
"vector_array dtype must be an extension type, got {}",
98+
vector_array.dtype()
99+
)
100+
})?;
101+
102+
vortex_ensure!(
103+
ext.is::<Vector>(),
104+
"vector_array must have the Vector extension type, got {}",
105+
vector_array.dtype()
106+
);
107+
108+
let list_size = extension_list_size(ext)?;
109+
let row_count = vector_array.len();
110+
111+
// Compute L2 norms using the scalar function.
112+
let l2_norm_fn = ScalarFn::new(L2Norm, EmptyOptions).erased();
113+
let norms = ScalarFnArray::try_new(l2_norm_fn, vec![vector_array.clone()], row_count)?
114+
.to_primitive()
115+
.into_array();
116+
117+
// Divide each vector element by its corresponding norm.
118+
let storage = extension_storage(&vector_array)?;
119+
let flat = extract_flat_elements(&storage, list_size)?;
120+
let norms_prim = norms.to_canonical()?.into_primitive();
121+
122+
match_each_float_ptype!(flat.ptype(), |T| {
123+
let norms_slice = norms_prim.as_slice::<T>();
124+
125+
let normalized_elems: PrimitiveArray = (0..row_count)
126+
.flat_map(|i| {
127+
let inv_norm = safe_inv_norm(norms_slice[i]);
128+
flat.row::<T>(i).iter().map(move |&v| v * inv_norm)
129+
})
130+
.collect();
131+
132+
let fsl = FixedSizeListArray::new(
133+
normalized_elems.into_array(),
134+
u32::try_from(list_size)?,
135+
Validity::NonNullable,
136+
row_count,
137+
);
138+
139+
let ext_dtype =
140+
ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
141+
let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array();
142+
143+
Self::try_new(normalized_vector, norms)
144+
})
145+
}
146+
73147
/// Returns a reference to the backing vector array that has been unit normalized.
74148
pub fn vector_array(&self) -> &ArrayRef {
75149
&self.vector_array
@@ -80,8 +154,58 @@ impl NormVectorArray {
80154
&self.norms
81155
}
82156

83-
// TODO docs
84-
pub(super) fn execute_into_vector(&self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
85-
todo!()
157+
/// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm.
158+
pub fn decompress(&self, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
159+
let ext_dtype = self
160+
.vector_array
161+
.dtype()
162+
.as_extension_opt()
163+
.ok_or_else(|| {
164+
vortex_err!(
165+
"expected Vector extension dtype, got {}",
166+
self.vector_array.dtype()
167+
)
168+
})?;
169+
170+
let list_size = extension_list_size(ext_dtype)?;
171+
let row_count = self.vector_array.len();
172+
173+
let storage = extension_storage(&self.vector_array)?;
174+
let flat = extract_flat_elements(&storage, list_size)?;
175+
176+
let norms_prim = self.norms.to_canonical()?.into_primitive();
177+
178+
match_each_float_ptype!(flat.ptype(), |T| {
179+
let norms_slice = norms_prim.as_slice::<T>();
180+
181+
let result_elems: PrimitiveArray = (0..row_count)
182+
.flat_map(|i| {
183+
let norm = norms_slice[i];
184+
flat.row::<T>(i).iter().map(move |&v| v * norm)
185+
})
186+
.collect();
187+
188+
let fsl = FixedSizeListArray::new(
189+
result_elems.into_array(),
190+
u32::try_from(list_size)?,
191+
Validity::NonNullable,
192+
row_count,
193+
);
194+
195+
let ext_dtype =
196+
ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
197+
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
198+
})
199+
}
200+
}
201+
202+
/// Returns `1 / norm` if the norm is non-zero, or zero otherwise.
203+
///
204+
/// This avoids division by zero for zero-length or all-zero vectors.
205+
fn safe_inv_norm<T: Float>(norm: T) -> T {
206+
if norm == T::zero() {
207+
T::zero()
208+
} else {
209+
T::one() / norm
86210
}
87211
}

vortex-tensor/src/encodings/norm/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ pub use array::NormVectorArray;
99
mod vtable;
1010
pub use vtable::NormVector;
1111

12-
// #[cfg(test)]
13-
// mod tests;
12+
#[cfg(test)]
13+
mod tests;
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex::array::VortexSessionExecute;
5+
use vortex::array::arrays::Extension;
6+
use vortex::error::VortexResult;
7+
8+
use crate::encodings::norm::NormVectorArray;
9+
use crate::utils::extension_list_size;
10+
use crate::utils::extension_storage;
11+
use crate::utils::extract_flat_elements;
12+
use crate::utils::test_helpers::assert_close;
13+
use crate::utils::test_helpers::vector_array;
14+
15+
#[test]
fn encode_unit_vectors() -> VortexResult<()> {
    // Both rows already have unit length, so compression should report norms of 1.0 and leave
    // the elements as they are.
    let input = vector_array(
        3,
        &[
            1.0, 0.0, 0.0, // norm = 1.0
            0.0, 1.0, 0.0, // norm = 1.0
        ],
    )?;

    let encoded = NormVectorArray::compress(input)?;

    let norm_values = encoded.norms().to_canonical()?.into_primitive();
    assert_close(norm_values.as_slice::<f64>(), &[1.0, 1.0]);

    // Unpack the normalized vectors down to their flat element rows.
    let unit_vectors = encoded.vector_array();
    let ext_dtype = unit_vectors.dtype().as_extension_opt().unwrap();
    let dims = extension_list_size(ext_dtype)?;
    let storage = extension_storage(unit_vectors)?;
    let elements = extract_flat_elements(&storage, dims)?;

    assert_close(elements.row::<f64>(0), &[1.0, 0.0, 0.0]);
    assert_close(elements.row::<f64>(1), &[0.0, 1.0, 0.0]);

    Ok(())
}
40+
41+
#[test]
fn encode_non_unit_vectors() -> VortexResult<()> {
    // One ordinary row and one all-zero row, so the zero-norm case is exercised as well.
    let input = vector_array(
        2,
        &[
            3.0, 4.0, // norm = 5.0
            0.0, 0.0, // norm = 0.0 (zero vector)
        ],
    )?;

    let encoded = NormVectorArray::compress(input)?;

    let norm_values = encoded.norms().to_canonical()?.into_primitive();
    assert_close(norm_values.as_slice::<f64>(), &[5.0, 0.0]);

    // Unpack the normalized vectors down to their flat element rows.
    let unit_vectors = encoded.vector_array();
    let ext_dtype = unit_vectors.dtype().as_extension_opt().unwrap();
    let dims = extension_list_size(ext_dtype)?;
    let storage = extension_storage(unit_vectors)?;
    let elements = extract_flat_elements(&storage, dims)?;

    assert_close(elements.row::<f64>(0), &[3.0 / 5.0, 4.0 / 5.0]);
    assert_close(elements.row::<f64>(1), &[0.0, 0.0]);

    Ok(())
}
65+
66+
#[test]
fn execute_round_trip() -> VortexResult<()> {
    let input = vector_array(
        2,
        &[
            3.0, 4.0, // norm = 5.0
            6.0, 8.0, // norm = 10.0
        ],
    )?;

    let encoded = NormVectorArray::compress(input)?;

    // Decompress to get the original vectors back.
    let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx();
    let decoded = encoded.decompress(&mut ctx)?;

    // The output must be an extension array.
    assert!(decoded.as_opt::<Extension>().is_some());

    let ext_dtype = decoded.dtype().as_extension_opt().unwrap();
    let dims = extension_list_size(ext_dtype)?;
    let storage = extension_storage(&decoded)?;
    let elements = extract_flat_elements(&storage, dims)?;

    assert_close(elements.row::<f64>(0), &[3.0, 4.0]);
    assert_close(elements.row::<f64>(1), &[6.0, 8.0]);

    Ok(())
}
92+
93+
#[test]
fn execute_round_trip_zero_vector() -> VortexResult<()> {
    // A single all-zero row: its norm is zero, and it must come back as all zeros.
    let input = vector_array(2, &[0.0, 0.0])?;
    let encoded = NormVectorArray::compress(input)?;

    let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx();
    let decoded = encoded.decompress(&mut ctx)?;

    let ext_dtype = decoded.dtype().as_extension_opt().unwrap();
    let dims = extension_list_size(ext_dtype)?;
    let storage = extension_storage(&decoded)?;
    let elements = extract_flat_elements(&storage, dims)?;

    assert_close(elements.row::<f64>(0), &[0.0, 0.0]);

    Ok(())
}

vortex-tensor/src/encodings/norm/vtable/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,6 @@ impl VTable for NormVector {
162162
}
163163

164164
/// Executes a `NormVectorArray` and reports the outcome as a finished execution step.
///
/// This block is shown as a rendered diff: the line after the `165-` marker is the call the
/// commit removed (`execute_into_vector`), and the line after the `165+` marker is its
/// replacement (`decompress`).
fn execute(array: &NormVectorArray, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
165-
Ok(ExecutionStep::Done(array.execute_into_vector(ctx)?))
165+
// Errors from `decompress` are propagated with `?`; on success the resulting array is wrapped
// in `ExecutionStep::Done`.
Ok(ExecutionStep::Done(array.decompress(ctx)?))
166166
}
167167
}

0 commit comments

Comments
 (0)