@@ -7,14 +7,14 @@ use vortex_error::VortexResult;
77
88use crate :: ArrayRef ;
99use crate :: Canonical ;
10+ use crate :: DynArray ;
1011use crate :: ExecutionCtx ;
1112use crate :: IntoArray ;
1213use crate :: arrays:: BoolArray ;
1314use crate :: arrays:: ConstantArray ;
1415use crate :: arrays:: Patched ;
1516use crate :: arrays:: PrimitiveArray ;
1617use crate :: arrays:: bool:: BoolArrayParts ;
17- use crate :: arrays:: patched:: patch_lanes;
1818use crate :: arrays:: primitive:: NativeValue ;
1919use crate :: builtins:: ArrayBuiltins ;
2020use crate :: dtype:: NativePType ;
@@ -39,8 +39,11 @@ impl CompareKernel for Patched {
3939 return Ok ( None ) ;
4040 } ;
4141
42+ // NOTE: due to offset, it's possible that the inner.len != array.len.
43+ // We slice the inner before performing the comparison.
4244 let result = lhs
4345 . inner
46+ . slice ( lhs. offset ..lhs. offset + lhs. len ) ?
4447 . binary (
4548 ConstantArray :: new ( constant. clone ( ) , lhs. len ( ) ) . into_array ( ) ,
4649 operator. into ( ) ,
@@ -57,46 +60,13 @@ impl CompareKernel for Patched {
5760
5861 let mut bits = BitBufferMut :: from_buffer ( bits. unwrap_host ( ) . into_mut ( ) , offset, len) ;
5962
60- fn apply < V : NativePType , F > (
61- bits : & mut BitBufferMut ,
62- lane_offsets : & [ u32 ] ,
63- indices : & [ u16 ] ,
64- values : & [ V ] ,
65- constant : V ,
66- cmp : F ,
67- ) -> VortexResult < ( ) >
68- where
69- F : Fn ( V , V ) -> bool ,
70- {
71- let n_lanes = patch_lanes :: < V > ( ) ;
72-
73- for index in 0 ..( lane_offsets. len ( ) - 1 ) {
74- let chunk = index / n_lanes;
75-
76- let lane_start = lane_offsets[ index] as usize ;
77- let lane_end = lane_offsets[ index + 1 ] as usize ;
78-
79- for ( & patch_index, & patch_value) in std:: iter:: zip (
80- & indices[ lane_start..lane_end] ,
81- & values[ lane_start..lane_end] ,
82- ) {
83- let bit_index = chunk * 1024 + patch_index as usize ;
84- if cmp ( patch_value, constant) {
85- bits. set ( bit_index)
86- } else {
87- bits. unset ( bit_index)
88- }
89- }
90- }
91-
92- Ok ( ( ) )
93- }
94-
9563 let lane_offsets = lhs. lane_offsets . as_host ( ) . reinterpret :: < u32 > ( ) ;
9664 let indices = lhs. indices . clone ( ) . execute :: < PrimitiveArray > ( ctx) ?;
9765 let values = lhs. values . clone ( ) . execute :: < PrimitiveArray > ( ctx) ?;
66+ let n_lanes = lhs. n_lanes ;
9867
9968 match_each_native_ptype ! ( values. ptype( ) , |V | {
69+ let offset = lhs. offset;
10070 let indices = indices. as_slice:: <u16 >( ) ;
10171 let values = values. as_slice:: <V >( ) ;
10272 let constant = constant
@@ -108,6 +78,8 @@ impl CompareKernel for Patched {
10878 CompareOperator :: Eq => {
10979 apply:: <V , _>(
11080 & mut bits,
81+ offset,
82+ n_lanes,
11183 lane_offsets,
11284 indices,
11385 values,
@@ -118,6 +90,8 @@ impl CompareKernel for Patched {
11890 CompareOperator :: NotEq => {
11991 apply:: <V , _>(
12092 & mut bits,
93+ offset,
94+ n_lanes,
12195 lane_offsets,
12296 indices,
12397 values,
@@ -128,6 +102,8 @@ impl CompareKernel for Patched {
128102 CompareOperator :: Gt => {
129103 apply:: <V , _>(
130104 & mut bits,
105+ n_lanes,
106+ offset,
131107 lane_offsets,
132108 indices,
133109 values,
@@ -138,6 +114,8 @@ impl CompareKernel for Patched {
138114 CompareOperator :: Gte => {
139115 apply:: <V , _>(
140116 & mut bits,
117+ n_lanes,
118+ offset,
141119 lane_offsets,
142120 indices,
143121 values,
@@ -148,6 +126,8 @@ impl CompareKernel for Patched {
148126 CompareOperator :: Lt => {
149127 apply:: <V , _>(
150128 & mut bits,
129+ n_lanes,
130+ offset,
151131 lane_offsets,
152132 indices,
153133 values,
@@ -158,6 +138,8 @@ impl CompareKernel for Patched {
158138 CompareOperator :: Lte => {
159139 apply:: <V , _>(
160140 & mut bits,
141+ n_lanes,
142+ offset,
161143 lane_offsets,
162144 indices,
163145 values,
@@ -173,11 +155,53 @@ impl CompareKernel for Patched {
173155 }
174156}
175157
158+ #[ allow( clippy:: too_many_arguments) ]
159+ fn apply < V : NativePType , F > (
160+ bits : & mut BitBufferMut ,
161+ offset : usize ,
162+ n_lanes : usize ,
163+ lane_offsets : & [ u32 ] ,
164+ indices : & [ u16 ] ,
165+ values : & [ V ] ,
166+ constant : V ,
167+ cmp : F ,
168+ ) -> VortexResult < ( ) >
169+ where
170+ F : Fn ( V , V ) -> bool ,
171+ {
172+ for index in 0 ..( lane_offsets. len ( ) - 1 ) {
173+ let chunk = index / n_lanes;
174+
175+ let lane_start = lane_offsets[ index] as usize ;
176+ let lane_end = lane_offsets[ index + 1 ] as usize ;
177+
178+ for ( & patch_index, & patch_value) in std:: iter:: zip (
179+ & indices[ lane_start..lane_end] ,
180+ & values[ lane_start..lane_end] ,
181+ ) {
182+ let bit_index = chunk * 1024 + patch_index as usize ;
183+ // Skip any indices < the offset.
184+ if bit_index < offset {
185+ continue ;
186+ }
187+ let bit_index = bit_index - offset;
188+ if cmp ( patch_value, constant) {
189+ bits. set ( bit_index)
190+ } else {
191+ bits. unset ( bit_index)
192+ }
193+ }
194+ }
195+
196+ Ok ( ( ) )
197+ }
198+
176199#[ cfg( test) ]
177200mod tests {
178201 use vortex_buffer:: buffer;
179202 use vortex_error:: VortexResult ;
180203
204+ use crate :: DynArray ;
181205 use crate :: ExecutionCtx ;
182206 use crate :: IntoArray ;
183207 use crate :: LEGACY_SESSION ;
@@ -187,6 +211,7 @@ mod tests {
187211 use crate :: arrays:: PatchedArray ;
188212 use crate :: arrays:: PrimitiveArray ;
189213 use crate :: assert_arrays_eq;
214+ use crate :: optimizer:: ArrayOptimizer ;
190215 use crate :: patches:: Patches ;
191216 use crate :: scalar_fn:: fns:: binary:: CompareKernel ;
192217 use crate :: scalar_fn:: fns:: operators:: CompareOperator ;
@@ -220,6 +245,43 @@ mod tests {
220245 assert_arrays_eq ! ( expected, result) ;
221246 }
222247
248+ #[ test]
249+ fn test_with_offset ( ) {
250+ let lhs = PrimitiveArray :: from_iter ( 0u32 ..512 ) . into_array ( ) ;
251+ let patches = Patches :: new (
252+ 512 ,
253+ 0 ,
254+ buffer ! [ 5u16 , 510 , 511 ] . into_array ( ) ,
255+ buffer ! [ u32 :: MAX ; 3 ] . into_array ( ) ,
256+ None ,
257+ )
258+ . unwrap ( ) ;
259+
260+ let mut ctx = ExecutionCtx :: new ( LEGACY_SESSION . clone ( ) ) ;
261+
262+ let lhs = PatchedArray :: from_array_and_patches ( lhs, & patches, & mut ctx) . unwrap ( ) ;
263+ // Slice the array so that the first patch should be skipped.
264+ let lhs = lhs
265+ . slice ( 10 ..lhs. len ( ) )
266+ . unwrap ( )
267+ . optimize ( )
268+ . unwrap ( )
269+ . try_into :: < Patched > ( )
270+ . unwrap ( ) ;
271+
272+ assert_eq ! ( lhs. len( ) , 502 ) ;
273+
274+ let rhs = ConstantArray :: new ( u32:: MAX , lhs. len ( ) ) . into_array ( ) ;
275+
276+ let result = <Patched as CompareKernel >:: compare ( & lhs, & rhs, CompareOperator :: Eq , & mut ctx)
277+ . unwrap ( )
278+ . unwrap ( ) ;
279+
280+ let expected = BoolArray :: from_indices ( 502 , [ 500 , 501 ] , Validity :: NonNullable ) . into_array ( ) ;
281+
282+ assert_arrays_eq ! ( expected, result) ;
283+ }
284+
223285 #[ test]
224286 fn test_subnormal_f32 ( ) -> VortexResult < ( ) > {
225287 // Subnormal f32 values are smaller than f32::MIN_POSITIVE but greater than 0
0 commit comments