@@ -19,8 +19,6 @@ use vortex_dtype::{
1919use vortex_error:: { VortexExpect , VortexResult , vortex_err, vortex_panic} ;
2020use vortex_fastlanes:: RLEArray ;
2121
22- use crate :: take:: cuda_take_kernel;
23-
2422pub fn cuda_rle_decompress ( rle : & RLEArray , ctx : Arc < CudaContext > ) -> VortexResult < ArrayRef > {
2523 match_each_native_ptype ! ( rle. values( ) . dtype( ) . as_ptype( ) , |V | {
2624 match_each_unsigned_integer_ptype!( rle. values( ) . dtype( ) . as_ptype( ) , |O | {
@@ -56,10 +54,11 @@ fn cuda_rle_decompress_typed<Values, Indices, Offsets>(
5654where
5755 Values : NativePType + DeviceRepr + ValidAsZeroBits ,
5856 Indices : UnsignedPType + DeviceRepr ,
57+ Offsets : UnsignedPType + DeviceRepr ,
5958{
6059 assert_eq ! ( indices. len( ) % 1024 , 0 ) ;
6160
62- let kernel_func = cuda_take_kernel :: < Indices , Values > ( false , ctx. clone ( ) ) ?;
61+ let kernel_func = cuda_rle_kernel :: < Indices , Values , Offsets > ( ctx. clone ( ) ) ?;
6362 let num_chunks =
6463 u32:: try_from ( indices. len ( ) . div_ceil ( 1024 ) ) . vortex_expect ( "num chunks overflow" ) ;
6564 let stream = ctx. default_stream ( ) ;
@@ -107,10 +106,7 @@ where
107106 Ok ( PrimitiveArray :: new ( buffer, Validity :: NonNullable ) . into_array ( ) )
108107}
109108
110- fn cuda_rle_kernel < Indices , Values , Offsets > (
111- mask : bool ,
112- ctx : Arc < CudaContext > ,
113- ) -> VortexResult < CudaFunction >
109+ fn cuda_rle_kernel < Indices , Values , Offsets > ( ctx : Arc < CudaContext > ) -> VortexResult < CudaFunction >
114110where
115111 Indices : UnsignedPType ,
116112 Values : NativePType ,
@@ -122,9 +118,9 @@ where
122118
123119 let kernel_name = format ! (
124120 "rle_decompress_i{}_v{}_o{}" ,
125- if mask { "_masked" } else { "" } ,
126121 & Indices :: PTYPE ,
127- & Values :: PTYPE & Offsets :: PTYPE ,
122+ & Values :: PTYPE ,
123+ & Offsets :: PTYPE ,
128124 ) ;
129125
130126 module
0 commit comments