Skip to content

Commit 578d83d

Browse files
committed
address comments
Signed-off-by: Andrew Duffy <andrew@a10y.dev>
1 parent c07de7d commit 578d83d

File tree

3 files changed

+147
-48
lines changed

3 files changed

+147
-48
lines changed

vortex-array/src/arrays/patched/compute/compare.rs

Lines changed: 98 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ use vortex_error::VortexResult;
77

88
use crate::ArrayRef;
99
use crate::Canonical;
10+
use crate::DynArray;
1011
use crate::ExecutionCtx;
1112
use crate::IntoArray;
1213
use crate::arrays::BoolArray;
1314
use crate::arrays::ConstantArray;
1415
use crate::arrays::Patched;
1516
use crate::arrays::PrimitiveArray;
1617
use crate::arrays::bool::BoolArrayParts;
17-
use crate::arrays::patched::patch_lanes;
1818
use crate::arrays::primitive::NativeValue;
1919
use crate::builtins::ArrayBuiltins;
2020
use crate::dtype::NativePType;
@@ -39,8 +39,11 @@ impl CompareKernel for Patched {
3939
return Ok(None);
4040
};
4141

42+
// NOTE: due to offset, it's possible that the inner.len != array.len.
43+
// We slice the inner before performing the comparison.
4244
let result = lhs
4345
.inner
46+
.slice(lhs.offset..lhs.offset + lhs.len)?
4447
.binary(
4548
ConstantArray::new(constant.clone(), lhs.len()).into_array(),
4649
operator.into(),
@@ -57,46 +60,13 @@ impl CompareKernel for Patched {
5760

5861
let mut bits = BitBufferMut::from_buffer(bits.unwrap_host().into_mut(), offset, len);
5962

60-
fn apply<V: NativePType, F>(
61-
bits: &mut BitBufferMut,
62-
lane_offsets: &[u32],
63-
indices: &[u16],
64-
values: &[V],
65-
constant: V,
66-
cmp: F,
67-
) -> VortexResult<()>
68-
where
69-
F: Fn(V, V) -> bool,
70-
{
71-
let n_lanes = patch_lanes::<V>();
72-
73-
for index in 0..(lane_offsets.len() - 1) {
74-
let chunk = index / n_lanes;
75-
76-
let lane_start = lane_offsets[index] as usize;
77-
let lane_end = lane_offsets[index + 1] as usize;
78-
79-
for (&patch_index, &patch_value) in std::iter::zip(
80-
&indices[lane_start..lane_end],
81-
&values[lane_start..lane_end],
82-
) {
83-
let bit_index = chunk * 1024 + patch_index as usize;
84-
if cmp(patch_value, constant) {
85-
bits.set(bit_index)
86-
} else {
87-
bits.unset(bit_index)
88-
}
89-
}
90-
}
91-
92-
Ok(())
93-
}
94-
9563
let lane_offsets = lhs.lane_offsets.as_host().reinterpret::<u32>();
9664
let indices = lhs.indices.clone().execute::<PrimitiveArray>(ctx)?;
9765
let values = lhs.values.clone().execute::<PrimitiveArray>(ctx)?;
66+
let n_lanes = lhs.n_lanes;
9867

9968
match_each_native_ptype!(values.ptype(), |V| {
69+
let offset = lhs.offset;
10070
let indices = indices.as_slice::<u16>();
10171
let values = values.as_slice::<V>();
10272
let constant = constant
@@ -108,6 +78,8 @@ impl CompareKernel for Patched {
10878
CompareOperator::Eq => {
10979
apply::<V, _>(
11080
&mut bits,
81+
offset,
82+
n_lanes,
11183
lane_offsets,
11284
indices,
11385
values,
@@ -118,6 +90,8 @@ impl CompareKernel for Patched {
11890
CompareOperator::NotEq => {
11991
apply::<V, _>(
12092
&mut bits,
93+
offset,
94+
n_lanes,
12195
lane_offsets,
12296
indices,
12397
values,
@@ -128,6 +102,8 @@ impl CompareKernel for Patched {
128102
CompareOperator::Gt => {
129103
apply::<V, _>(
130104
&mut bits,
105+
n_lanes,
106+
offset,
131107
lane_offsets,
132108
indices,
133109
values,
@@ -138,6 +114,8 @@ impl CompareKernel for Patched {
138114
CompareOperator::Gte => {
139115
apply::<V, _>(
140116
&mut bits,
117+
n_lanes,
118+
offset,
141119
lane_offsets,
142120
indices,
143121
values,
@@ -148,6 +126,8 @@ impl CompareKernel for Patched {
148126
CompareOperator::Lt => {
149127
apply::<V, _>(
150128
&mut bits,
129+
n_lanes,
130+
offset,
151131
lane_offsets,
152132
indices,
153133
values,
@@ -158,6 +138,8 @@ impl CompareKernel for Patched {
158138
CompareOperator::Lte => {
159139
apply::<V, _>(
160140
&mut bits,
141+
n_lanes,
142+
offset,
161143
lane_offsets,
162144
indices,
163145
values,
@@ -173,11 +155,53 @@ impl CompareKernel for Patched {
173155
}
174156
}
175157

158+
#[allow(clippy::too_many_arguments)]
159+
fn apply<V: NativePType, F>(
160+
bits: &mut BitBufferMut,
161+
offset: usize,
162+
n_lanes: usize,
163+
lane_offsets: &[u32],
164+
indices: &[u16],
165+
values: &[V],
166+
constant: V,
167+
cmp: F,
168+
) -> VortexResult<()>
169+
where
170+
F: Fn(V, V) -> bool,
171+
{
172+
for index in 0..(lane_offsets.len() - 1) {
173+
let chunk = index / n_lanes;
174+
175+
let lane_start = lane_offsets[index] as usize;
176+
let lane_end = lane_offsets[index + 1] as usize;
177+
178+
for (&patch_index, &patch_value) in std::iter::zip(
179+
&indices[lane_start..lane_end],
180+
&values[lane_start..lane_end],
181+
) {
182+
let bit_index = chunk * 1024 + patch_index as usize;
183+
// Skip any indices < the offset.
184+
if bit_index < offset {
185+
continue;
186+
}
187+
let bit_index = bit_index - offset;
188+
if cmp(patch_value, constant) {
189+
bits.set(bit_index)
190+
} else {
191+
bits.unset(bit_index)
192+
}
193+
}
194+
}
195+
196+
Ok(())
197+
}
198+
176199
#[cfg(test)]
177200
mod tests {
178201
use vortex_buffer::buffer;
179202
use vortex_error::VortexResult;
180203

204+
use crate::DynArray;
181205
use crate::ExecutionCtx;
182206
use crate::IntoArray;
183207
use crate::LEGACY_SESSION;
@@ -187,6 +211,7 @@ mod tests {
187211
use crate::arrays::PatchedArray;
188212
use crate::arrays::PrimitiveArray;
189213
use crate::assert_arrays_eq;
214+
use crate::optimizer::ArrayOptimizer;
190215
use crate::patches::Patches;
191216
use crate::scalar_fn::fns::binary::CompareKernel;
192217
use crate::scalar_fn::fns::operators::CompareOperator;
@@ -220,6 +245,43 @@ mod tests {
220245
assert_arrays_eq!(expected, result);
221246
}
222247

248+
#[test]
249+
fn test_with_offset() {
250+
let lhs = PrimitiveArray::from_iter(0u32..512).into_array();
251+
let patches = Patches::new(
252+
512,
253+
0,
254+
buffer![5u16, 510, 511].into_array(),
255+
buffer![u32::MAX; 3].into_array(),
256+
None,
257+
)
258+
.unwrap();
259+
260+
let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());
261+
262+
let lhs = PatchedArray::from_array_and_patches(lhs, &patches, &mut ctx).unwrap();
263+
// Slice the array so that the first patch should be skipped.
264+
let lhs = lhs
265+
.slice(10..lhs.len())
266+
.unwrap()
267+
.optimize()
268+
.unwrap()
269+
.try_into::<Patched>()
270+
.unwrap();
271+
272+
assert_eq!(lhs.len(), 502);
273+
274+
let rhs = ConstantArray::new(u32::MAX, lhs.len()).into_array();
275+
276+
let result = <Patched as CompareKernel>::compare(&lhs, &rhs, CompareOperator::Eq, &mut ctx)
277+
.unwrap()
278+
.unwrap();
279+
280+
let expected = BoolArray::from_indices(502, [500, 501], Validity::NonNullable).into_array();
281+
282+
assert_arrays_eq!(expected, result);
283+
}
284+
223285
#[test]
224286
fn test_subnormal_f32() -> VortexResult<()> {
225287
// Subnormal f32 values are smaller than f32::MIN_POSITIVE but greater than 0

vortex-array/src/arrays/patched/compute/filter.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ mod tests {
6363
use crate::arrays::PatchedArray;
6464
use crate::arrays::PrimitiveArray;
6565
use crate::assert_arrays_eq;
66+
use crate::optimizer::ArrayOptimizer;
6667
use crate::patches::Patches;
6768

6869
#[test]
@@ -142,4 +143,41 @@ mod tests {
142143

143144
Ok(())
144145
}
146+
147+
#[test]
148+
fn test_filter_sliced() -> VortexResult<()> {
149+
// Test filter on a sliced PatchedArray to exercise codepath where offset > 0.
150+
let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());
151+
152+
// Create a larger array (6 chunks) so we can slice and still have room
153+
// for the filter to prune chunks.
154+
let array = buffer![u16::MIN; 6144].into_array();
155+
// Patches at indices 2048 and 2049 (start of chunk 2).
156+
let patched_indices = buffer![2048u16, 2049].into_array();
157+
let patched_values = buffer![u16::MAX, u16::MAX].into_array();
158+
159+
let patches = Patches::new(6144, 0, patched_indices, patched_values, None)?;
160+
161+
let patched = PatchedArray::from_array_and_patches(array, &patches, &mut ctx)?;
162+
163+
// Slice at chunk boundary to create offset > 0. After slicing [1024..5120],
164+
// we have 4096 elements and patches are at relative indices 1024 and 1025.
165+
let sliced = patched.slice(1024..5120)?.into_array();
166+
assert_eq!(sliced.len(), 4096);
167+
168+
// Filter that only touches the middle 2 chunks (chunks 1 and 2).
169+
// Indices 1024 and 1025 fall in chunk 1, and 3000 falls in chunk 2.
170+
let mask = Mask::from_indices(4096, vec![1024, 1025, 3000]);
171+
172+
let filtered = sliced
173+
.filter(mask)?
174+
.optimize()?
175+
.execute::<PrimitiveArray>(&mut ctx)?;
176+
177+
let expected = PrimitiveArray::from_iter([u16::MAX, u16::MAX, u16::MIN]);
178+
179+
assert_arrays_eq!(expected, filtered);
180+
181+
Ok(())
182+
}
145183
}

vortex-array/src/arrays/patched/vtable/mod.rs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use crate::ExecutionResult;
2828
use crate::IntoArray;
2929
use crate::Precision;
3030
use crate::ProstMetadata;
31+
use crate::SerializeMetadata;
3132
use crate::arrays::PrimitiveArray;
3233
use crate::arrays::patched::PatchedArray;
3334
use crate::arrays::patched::compute::rules::PARENT_RULES;
@@ -45,10 +46,10 @@ use crate::serde::ArrayChildren;
4546
use crate::stats::ArrayStats;
4647
use crate::stats::StatsSetRef;
4748
use crate::vtable;
48-
use crate::vtable::ArrayId;
4949
use crate::vtable::VTable;
5050
use crate::vtable::ValidityChild;
5151
use crate::vtable::ValidityVTableFromChild;
52+
use crate::vtable::{Array, ArrayId};
5253

5354
vtable!(Patched);
5455

@@ -140,7 +141,7 @@ impl VTable for Patched {
140141
0 => array.inner.clone(),
141142
1 => array.indices.clone(),
142143
2 => array.values.clone(),
143-
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
144+
_ => vortex_panic!("invalid child index for PatchedArray: {idx}"),
144145
}
145146
}
146147

@@ -149,7 +150,7 @@ impl VTable for Patched {
149150
0 => "inner".to_string(),
150151
1 => "patch_indices".to_string(),
151152
2 => "patch_values".to_string(),
152-
_ => vortex_panic!("invalid buffer index for PatchedArray: {idx}"),
153+
_ => vortex_panic!("invalid child index for PatchedArray: {idx}"),
153154
}
154155
}
155156

@@ -161,8 +162,8 @@ impl VTable for Patched {
161162
}))
162163
}
163164

164-
fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
165-
Ok(Some(vec![]))
165+
fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
166+
Ok(Some(metadata.serialize()))
166167
}
167168

168169
fn deserialize(
@@ -247,7 +248,7 @@ impl VTable for Patched {
247248

248249
let inner = children.get(0, dtype, len)?;
249250
let indices = children.get(1, PType::U16.into(), metadata.n_patches as usize)?;
250-
let values = children.get(1, dtype, metadata.n_patches as usize)?;
251+
let values = children.get(2, dtype, metadata.n_patches as usize)?;
251252

252253
Ok(PatchedArray {
253254
inner,
@@ -269,12 +270,13 @@ impl VTable for Patched {
269270
);
270271

271272
array.inner = children.remove(0);
273+
array.indices = children.remove(0);
272274
array.values = children.remove(0);
273275

274276
Ok(())
275277
}
276278

277-
fn execute(array: Arc<Self::Array>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
279+
fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
278280
let inner = array
279281
.inner
280282
.clone()
@@ -311,9 +313,6 @@ impl VTable for Patched {
311313
values.as_slice::<V>(),
312314
);
313315

314-
// The output will always be aligned to a chunk boundary, we apply the offset/len
315-
// at the end to slice to only the in-bounds values.
316-
let _output = output.as_slice();
317316
let output = output.freeze().slice(offset..offset + len);
318317

319318
PrimitiveArray::from_byte_buffer(output.into_byte_buffer(), ptype, validity)
@@ -323,7 +322,7 @@ impl VTable for Patched {
323322
}
324323

325324
fn execute_parent(
326-
array: &Self::Array,
325+
array: &Array<Self>,
327326
parent: &ArrayRef,
328327
child_idx: usize,
329328
ctx: &mut ExecutionCtx,
@@ -332,7 +331,7 @@ impl VTable for Patched {
332331
}
333332

334333
fn reduce_parent(
335-
array: &Self::Array,
334+
array: &Array<Self>,
336335
parent: &ArrayRef,
337336
child_idx: usize,
338337
) -> VortexResult<Option<ArrayRef>> {

0 commit comments

Comments
 (0)