Skip to content

Commit 13dd7d4

Browse files
committed
fix scalars
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent cf65eae commit 13dd7d4

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

vortex-tensor/src/encodings/norm/tests.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex::array::IntoArray;
45
use vortex::array::VortexSessionExecute;
56
use vortex::array::arrays::Extension;
67
use vortex::error::VortexResult;
@@ -108,3 +109,27 @@ fn execute_round_trip_zero_vector() -> VortexResult<()> {
108109

109110
Ok(())
110111
}
112+
113+
#[test]
114+
fn scalar_at_returns_original_vector() -> VortexResult<()> {
115+
let arr = vector_array(
116+
2,
117+
&[
118+
3.0, 4.0, // norm = 5.0
119+
6.0, 8.0, // norm = 10.0
120+
],
121+
)?;
122+
123+
let encoded = NormVectorArray::compress(arr)?;
124+
125+
// `scalar_at` on the NormVectorArray should match `scalar_at` on the decompressed result.
126+
let mut ctx = vortex::array::LEGACY_SESSION.create_execution_ctx();
127+
let decompressed = encoded.decompress(&mut ctx)?;
128+
129+
let norm_array = encoded.into_array();
130+
for i in 0..2 {
131+
assert_eq!(norm_array.scalar_at(i)?, decompressed.scalar_at(i)?);
132+
}
133+
134+
Ok(())
135+
}
Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,66 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex::array::IntoArray;
5+
use vortex::array::arrays::ConstantArray;
6+
use vortex::array::arrays::FixedSizeList;
7+
use vortex::array::builtins::ArrayBuiltins;
48
use vortex::array::vtable::OperationsVTable;
9+
use vortex::dtype::Nullability;
510
use vortex::error::VortexResult;
11+
use vortex::error::vortex_err;
612
use vortex::scalar::Scalar;
13+
use vortex::scalar_fn::fns::operators::Operator;
714

815
use crate::encodings::norm::array::NormVectorArray;
916
use crate::encodings::norm::vtable::NormVector;
17+
use crate::utils::extension_list_size;
18+
use crate::utils::extension_storage;
1019

1120
impl OperationsVTable<NormVector> for NormVector {
1221
fn scalar_at(array: &NormVectorArray, index: usize) -> VortexResult<Scalar> {
13-
array.vector_array().scalar_at(index)
22+
let ext = array
23+
.vector_array()
24+
.dtype()
25+
.as_extension_opt()
26+
.ok_or_else(|| {
27+
vortex_err!(
28+
"expected Vector extension dtype, got {}",
29+
array.vector_array().dtype()
30+
)
31+
})?;
32+
let list_size = extension_list_size(ext)?;
33+
34+
// Get the storage (FixedSizeList) and slice out the elements for this row.
35+
let storage = extension_storage(array.vector_array())?;
36+
let fsl = storage
37+
.as_opt::<FixedSizeList>()
38+
.ok_or_else(|| vortex_err!("expected FixedSizeList storage"))?;
39+
let row_elements = fsl.fixed_size_list_elements_at(index)?;
40+
41+
// Multiply all elements by the norm using a ConstantArray broadcast.
42+
let norm_scalar = array.norms().scalar_at(index)?;
43+
let norm_broadcast = ConstantArray::new(norm_scalar, list_size).into_array();
44+
let scaled = row_elements.binary(norm_broadcast, Operator::Mul)?;
45+
46+
// Rebuild the FSL scalar, then wrap in the extension type.
47+
let element_dtype = ext
48+
.storage_dtype()
49+
.as_fixed_size_list_element_opt()
50+
.ok_or_else(|| {
51+
vortex_err!(
52+
"expected FixedSizeList storage dtype, got {}",
53+
ext.storage_dtype()
54+
)
55+
})?;
56+
57+
let children: Vec<Scalar> = (0..list_size)
58+
.map(|i| scaled.scalar_at(i))
59+
.collect::<VortexResult<_>>()?;
60+
61+
let fsl_scalar =
62+
Scalar::fixed_size_list(element_dtype.clone(), children, Nullability::NonNullable);
63+
64+
Ok(Scalar::extension_ref(ext.clone(), fsl_scalar))
1465
}
1566
}

0 commit comments

Comments
 (0)